Add async chat completion

This commit is contained in:
Ethan 2024-12-27 14:45:39 -05:00
parent ba333b928e
commit 90b7511426
4 changed files with 52 additions and 25 deletions

View File

@ -1,34 +1,64 @@
use eframe::egui::{self, Response, RichText, Ui, Widget}; use eframe::egui::{self, Response, RichText, Ui, Widget};
use std::sync::mpsc::{Receiver, Sender};
use crate::groq::GROQ_CLIENT; use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT};
pub struct ModelResponse { pub struct ModelResponse {
pub name: String, pub name: String,
pub message: String, message: String,
pub tokens: i32, tokens: i32,
pub time: f32, time: f32,
loading: bool,
tx: Sender<ChatCompletionResponse>,
rx: Receiver<ChatCompletionResponse>,
} }
impl ModelResponse { impl ModelResponse {
pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self { pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self {
let (tx, rx) = std::sync::mpsc::channel();
Self { Self {
name, name,
message, message,
tokens, tokens,
time, time,
loading: false,
tx,
rx,
} }
} }
pub async fn chat_completion(&mut self, prompt: String) { pub fn start_chat_completion(&mut self, prompt: String) {
let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await; self.message = "Loading...".to_string();
self.tokens = 0;
self.time = 0.0;
self.loading = true;
Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone());
}
fn check_completion(&mut self) {
if let Ok(response) = self.rx.try_recv() {
self.message = response.choices[0].message.content.clone(); self.message = response.choices[0].message.content.clone();
self.tokens = response.usage.total_tokens; self.tokens = response.usage.total_tokens;
self.time = response.usage.total_time; self.time = response.usage.total_time;
self.loading = false;
} }
} }
impl Widget for &ModelResponse { fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender<ChatCompletionResponse>) {
tokio::spawn(async move {
let response = GROQ_CLIENT.chat_completion(model_name, prompt).await;
tx.send(response).expect("Failed to send response");
});
}
}
impl Widget for &mut ModelResponse {
fn ui(self, ui: &mut Ui) -> Response { fn ui(self, ui: &mut Ui) -> Response {
if self.loading {
self.check_completion();
}
ui.vertical(|ui| { ui.vertical(|ui| {
egui::Frame::none() egui::Frame::none()
.inner_margin(8.0) .inner_margin(8.0)
@ -59,6 +89,7 @@ impl Widget for &ModelResponse {
); );
}); });
}) })
}).response })
.response
} }
} }

View File

@ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> {
let label = ui.selectable_label(contained, model); let label = ui.selectable_label(contained, model);
if label.clicked() && !contained { if label.clicked() && !contained {
self.app_state.selected_models.push(ModelResponse { self.app_state.selected_models.push(ModelResponse::new(
name: model.clone(), model.clone(),
message: "No message".to_string(), "No message".to_string(),
tokens: 0, 0,
time: 0.0, 0.0,
}); ));
} }
} }
}); });

View File

@ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> {
ui.horizontal(|ui| { ui.horizontal(|ui| {
if ui.button("Submit").clicked() { if ui.button("Submit").clicked() {
futures::executor::block_on(self.app_state.handle_submission()); self.app_state.handle_submission();
} }
if ui.button("Clear").clicked() { if ui.button("Clear").clicked() {

View File

@ -20,14 +20,10 @@ impl AppState {
} }
} }
pub async fn handle_submission(&mut self) { pub fn handle_submission(&mut self) {
let mut completions = vec![];
for model in &mut self.selected_models { for model in &mut self.selected_models {
completions.push(model.chat_completion(self.prompt_input.clone())); model.start_chat_completion(self.prompt_input.clone());
} }
futures::future::join_all(completions).await;
} }
} }
@ -52,7 +48,7 @@ impl eframe::App for AppState {
ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| { ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| {
egui::ScrollArea::both().show(ui, |ui| { egui::ScrollArea::both().show(ui, |ui| {
ui.horizontal_top(|ui| { ui.horizontal_top(|ui| {
for model in &self.selected_models { for model in &mut self.selected_models {
ui.add_sized(widget_size, model); ui.add_sized(widget_size, model);
} }
}); });