Add async chat completion

This commit is contained in:
Ethan 2024-12-27 14:45:39 -05:00
parent ba333b928e
commit 90b7511426
4 changed files with 52 additions and 25 deletions

View File

@ -1,34 +1,64 @@
use eframe::egui::{self, Response, RichText, Ui, Widget};
use std::sync::mpsc::{Receiver, Sender};
use crate::groq::GROQ_CLIENT;
use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT};
pub struct ModelResponse {
pub name: String,
pub message: String,
pub tokens: i32,
pub time: f32,
message: String,
tokens: i32,
time: f32,
loading: bool,
tx: Sender<ChatCompletionResponse>,
rx: Receiver<ChatCompletionResponse>,
}
impl ModelResponse {
pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self {
let (tx, rx) = std::sync::mpsc::channel();
Self {
name,
message,
tokens,
time,
loading: false,
tx,
rx,
}
}
pub async fn chat_completion(&mut self, prompt: String) {
let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await;
self.message = response.choices[0].message.content.clone();
self.tokens = response.usage.total_tokens;
self.time = response.usage.total_time;
pub fn start_chat_completion(&mut self, prompt: String) {
self.message = "Loading...".to_string();
self.tokens = 0;
self.time = 0.0;
self.loading = true;
Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone());
}
fn check_completion(&mut self) {
if let Ok(response) = self.rx.try_recv() {
self.message = response.choices[0].message.content.clone();
self.tokens = response.usage.total_tokens;
self.time = response.usage.total_time;
self.loading = false;
}
}
fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender<ChatCompletionResponse>) {
tokio::spawn(async move {
let response = GROQ_CLIENT.chat_completion(model_name, prompt).await;
tx.send(response).expect("Failed to send response");
});
}
}
impl Widget for &ModelResponse {
impl Widget for &mut ModelResponse {
fn ui(self, ui: &mut Ui) -> Response {
if self.loading {
self.check_completion();
}
ui.vertical(|ui| {
egui::Frame::none()
.inner_margin(8.0)
@ -59,6 +89,7 @@ impl Widget for &ModelResponse {
);
});
})
}).response
})
.response
}
}

View File

@ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> {
let label = ui.selectable_label(contained, model);
if label.clicked() && !contained {
self.app_state.selected_models.push(ModelResponse {
name: model.clone(),
message: "No message".to_string(),
tokens: 0,
time: 0.0,
});
self.app_state.selected_models.push(ModelResponse::new(
model.clone(),
"No message".to_string(),
0,
0.0,
));
}
}
});

View File

@ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> {
ui.horizontal(|ui| {
if ui.button("Submit").clicked() {
futures::executor::block_on(self.app_state.handle_submission());
self.app_state.handle_submission();
}
if ui.button("Clear").clicked() {

View File

@ -20,14 +20,10 @@ impl AppState {
}
}
pub async fn handle_submission(&mut self) {
let mut completions = vec![];
pub fn handle_submission(&mut self) {
for model in &mut self.selected_models {
completions.push(model.chat_completion(self.prompt_input.clone()));
model.start_chat_completion(self.prompt_input.clone());
}
futures::future::join_all(completions).await;
}
}
@ -52,7 +48,7 @@ impl eframe::App for AppState {
ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| {
egui::ScrollArea::both().show(ui, |ui| {
ui.horizontal_top(|ui| {
for model in &self.selected_models {
for model in &mut self.selected_models {
ui.add_sized(widget_size, model);
}
});