diff --git a/rust-learning/src/gui/model_response.rs b/rust-learning/src/gui/model_response.rs index 36d5a11..1f2bb0e 100644 --- a/rust-learning/src/gui/model_response.rs +++ b/rust-learning/src/gui/model_response.rs @@ -1,34 +1,64 @@ use eframe::egui::{self, Response, RichText, Ui, Widget}; +use std::sync::mpsc::{Receiver, Sender}; -use crate::groq::GROQ_CLIENT; +use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT}; pub struct ModelResponse { pub name: String, - pub message: String, - pub tokens: i32, - pub time: f32, + message: String, + tokens: i32, + time: f32, + loading: bool, + tx: Sender, + rx: Receiver, } impl ModelResponse { pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self { + let (tx, rx) = std::sync::mpsc::channel(); + Self { name, message, tokens, time, + loading: false, + tx, + rx, } } - pub async fn chat_completion(&mut self, prompt: String) { - let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await; - self.message = response.choices[0].message.content.clone(); - self.tokens = response.usage.total_tokens; - self.time = response.usage.total_time; + pub fn start_chat_completion(&mut self, prompt: String) { + self.message = "Loading...".to_string(); + self.tokens = 0; + self.time = 0.0; + self.loading = true; + Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone()); + } + + fn check_completion(&mut self) { + if let Ok(response) = self.rx.try_recv() { + self.message = response.choices[0].message.content.clone(); + self.tokens = response.usage.total_tokens; + self.time = response.usage.total_time; + self.loading = false; + } + } + + fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender) { + tokio::spawn(async move { + let response = GROQ_CLIENT.chat_completion(model_name, prompt).await; + tx.send(response).expect("Failed to send response"); + }); } } -impl Widget for &ModelResponse { +impl Widget for &mut ModelResponse { fn ui(self, ui: &mut Ui) -> Response { + if self.loading { + self.check_completion(); + } + ui.vertical(|ui| { egui::Frame::none() .inner_margin(8.0) @@ -59,6 +89,7 @@ impl Widget for &ModelResponse { ); }); }) - }).response + }) + .response } } diff --git a/rust-learning/src/gui/model_selection.rs b/rust-learning/src/gui/model_selection.rs index 3d6b9c3..8aefc16 100644 --- a/rust-learning/src/gui/model_selection.rs +++ b/rust-learning/src/gui/model_selection.rs @@ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> { let label = ui.selectable_label(contained, model); if label.clicked() && !contained { - self.app_state.selected_models.push(ModelResponse { - name: model.clone(), - message: "No message".to_string(), - tokens: 0, - time: 0.0, - }); + self.app_state.selected_models.push(ModelResponse::new( + model.clone(), + "No message".to_string(), + 0, + 0.0, + )); } } }); diff --git a/rust-learning/src/gui/prompt_input.rs b/rust-learning/src/gui/prompt_input.rs index 7204c8e..3d257fa 100644 --- a/rust-learning/src/gui/prompt_input.rs +++ b/rust-learning/src/gui/prompt_input.rs @@ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> { ui.horizontal(|ui| { if ui.button("Submit").clicked() { - futures::executor::block_on(self.app_state.handle_submission()); + self.app_state.handle_submission(); } if ui.button("Clear").clicked() { diff --git a/rust-learning/src/gui/state.rs b/rust-learning/src/gui/state.rs index c32aea3..b6abf4c 100644 --- a/rust-learning/src/gui/state.rs +++ b/rust-learning/src/gui/state.rs @@ -20,14 +20,10 @@ impl AppState { } } - pub async fn handle_submission(&mut self) { - let mut completions = vec![]; - + pub fn handle_submission(&mut self) { for model in &mut self.selected_models { - completions.push(model.chat_completion(self.prompt_input.clone())); + model.start_chat_completion(self.prompt_input.clone()); } - - futures::future::join_all(completions).await; } } @@ -52,7 +48,7 @@ impl eframe::App for AppState { ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| { egui::ScrollArea::both().show(ui, |ui| { ui.horizontal_top(|ui| { - for model in &self.selected_models { + for model in &mut self.selected_models { ui.add_sized(widget_size, model); } });