Add async chat completion
This commit is contained in:
		| @ -1,34 +1,64 @@ | |||||||
| use eframe::egui::{self, Response, RichText, Ui, Widget}; | use eframe::egui::{self, Response, RichText, Ui, Widget}; | ||||||
|  | use std::sync::mpsc::{Receiver, Sender}; | ||||||
|  |  | ||||||
| use crate::groq::GROQ_CLIENT; | use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT}; | ||||||
|  |  | ||||||
| pub struct ModelResponse { | pub struct ModelResponse { | ||||||
|     pub name: String, |     pub name: String, | ||||||
|     pub message: String, |     message: String, | ||||||
|     pub tokens: i32, |     tokens: i32, | ||||||
|     pub time: f32, |     time: f32, | ||||||
|  |     loading: bool, | ||||||
|  |     tx: Sender<ChatCompletionResponse>, | ||||||
|  |     rx: Receiver<ChatCompletionResponse>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl ModelResponse { | impl ModelResponse { | ||||||
|     pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self { |     pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self { | ||||||
|  |         let (tx, rx) = std::sync::mpsc::channel(); | ||||||
|  |  | ||||||
|         Self { |         Self { | ||||||
|             name, |             name, | ||||||
|             message, |             message, | ||||||
|             tokens, |             tokens, | ||||||
|             time, |             time, | ||||||
|  |             loading: false, | ||||||
|  |             tx, | ||||||
|  |             rx, | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub async fn chat_completion(&mut self, prompt: String) { |     pub fn start_chat_completion(&mut self, prompt: String) { | ||||||
|         let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await; |         self.message = "Loading...".to_string(); | ||||||
|  |         self.tokens = 0; | ||||||
|  |         self.time = 0.0; | ||||||
|  |         self.loading = true; | ||||||
|  |         Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone()); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     fn check_completion(&mut self) { | ||||||
|  |         if let Ok(response) = self.rx.try_recv() { | ||||||
|             self.message = response.choices[0].message.content.clone(); |             self.message = response.choices[0].message.content.clone(); | ||||||
|             self.tokens = response.usage.total_tokens; |             self.tokens = response.usage.total_tokens; | ||||||
|             self.time = response.usage.total_time; |             self.time = response.usage.total_time; | ||||||
|  |             self.loading = false; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| impl Widget for &ModelResponse { |     fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender<ChatCompletionResponse>) { | ||||||
|  |         tokio::spawn(async move { | ||||||
|  |             let response = GROQ_CLIENT.chat_completion(model_name, prompt).await; | ||||||
|  |             tx.send(response).expect("Failed to send response"); | ||||||
|  |         }); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl Widget for &mut ModelResponse { | ||||||
|     fn ui(self, ui: &mut Ui) -> Response { |     fn ui(self, ui: &mut Ui) -> Response { | ||||||
|  |         if self.loading { | ||||||
|  |             self.check_completion(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|         ui.vertical(|ui| { |         ui.vertical(|ui| { | ||||||
|             egui::Frame::none() |             egui::Frame::none() | ||||||
|                 .inner_margin(8.0) |                 .inner_margin(8.0) | ||||||
| @ -59,6 +89,7 @@ impl Widget for &ModelResponse { | |||||||
|                             ); |                             ); | ||||||
|                         }); |                         }); | ||||||
|                 }) |                 }) | ||||||
|         }).response |         }) | ||||||
|  |         .response | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  | |||||||
| @ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> { | |||||||
|                         let label = ui.selectable_label(contained, model); |                         let label = ui.selectable_label(contained, model); | ||||||
|  |  | ||||||
|                         if label.clicked() && !contained { |                         if label.clicked() && !contained { | ||||||
|                             self.app_state.selected_models.push(ModelResponse { |                             self.app_state.selected_models.push(ModelResponse::new( | ||||||
|                                 name: model.clone(), |                                 model.clone(), | ||||||
|                                 message: "No message".to_string(), |                                 "No message".to_string(), | ||||||
|                                 tokens: 0, |                                 0, | ||||||
|                                 time: 0.0, |                                 0.0, | ||||||
|                             }); |                             )); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 }); |                 }); | ||||||
|  | |||||||
| @ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> { | |||||||
|  |  | ||||||
|         ui.horizontal(|ui| { |         ui.horizontal(|ui| { | ||||||
|             if ui.button("Submit").clicked() { |             if ui.button("Submit").clicked() { | ||||||
|                 futures::executor::block_on(self.app_state.handle_submission()); |                 self.app_state.handle_submission(); | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             if ui.button("Clear").clicked() { |             if ui.button("Clear").clicked() { | ||||||
|  | |||||||
| @ -20,14 +20,10 @@ impl AppState { | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     pub async fn handle_submission(&mut self) { |     pub fn handle_submission(&mut self) { | ||||||
|         let mut completions = vec![]; |  | ||||||
|  |  | ||||||
|         for model in &mut self.selected_models { |         for model in &mut self.selected_models { | ||||||
|             completions.push(model.chat_completion(self.prompt_input.clone())); |             model.start_chat_completion(self.prompt_input.clone()); | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         futures::future::join_all(completions).await; |  | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -52,7 +48,7 @@ impl eframe::App for AppState { | |||||||
|             ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| { |             ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| { | ||||||
|                 egui::ScrollArea::both().show(ui, |ui| { |                 egui::ScrollArea::both().show(ui, |ui| { | ||||||
|                     ui.horizontal_top(|ui| { |                     ui.horizontal_top(|ui| { | ||||||
|                         for model in &self.selected_models { |                         for model in &mut self.selected_models { | ||||||
|                             ui.add_sized(widget_size, model); |                             ui.add_sized(widget_size, model); | ||||||
|                         } |                         } | ||||||
|                     }); |                     }); | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user