Add async chat completion
This commit is contained in:
parent
ba333b928e
commit
90b7511426
@ -1,34 +1,64 @@
|
||||
use eframe::egui::{self, Response, RichText, Ui, Widget};
|
||||
use std::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
use crate::groq::GROQ_CLIENT;
|
||||
use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT};
|
||||
|
||||
pub struct ModelResponse {
|
||||
pub name: String,
|
||||
pub message: String,
|
||||
pub tokens: i32,
|
||||
pub time: f32,
|
||||
message: String,
|
||||
tokens: i32,
|
||||
time: f32,
|
||||
loading: bool,
|
||||
tx: Sender<ChatCompletionResponse>,
|
||||
rx: Receiver<ChatCompletionResponse>,
|
||||
}
|
||||
|
||||
impl ModelResponse {
|
||||
pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self {
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
|
||||
Self {
|
||||
name,
|
||||
message,
|
||||
tokens,
|
||||
time,
|
||||
loading: false,
|
||||
tx,
|
||||
rx,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn chat_completion(&mut self, prompt: String) {
|
||||
let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await;
|
||||
self.message = response.choices[0].message.content.clone();
|
||||
self.tokens = response.usage.total_tokens;
|
||||
self.time = response.usage.total_time;
|
||||
pub fn start_chat_completion(&mut self, prompt: String) {
|
||||
self.message = "Loading...".to_string();
|
||||
self.tokens = 0;
|
||||
self.time = 0.0;
|
||||
self.loading = true;
|
||||
Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone());
|
||||
}
|
||||
|
||||
fn check_completion(&mut self) {
|
||||
if let Ok(response) = self.rx.try_recv() {
|
||||
self.message = response.choices[0].message.content.clone();
|
||||
self.tokens = response.usage.total_tokens;
|
||||
self.time = response.usage.total_time;
|
||||
self.loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender<ChatCompletionResponse>) {
|
||||
tokio::spawn(async move {
|
||||
let response = GROQ_CLIENT.chat_completion(model_name, prompt).await;
|
||||
tx.send(response).expect("Failed to send response");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Widget for &ModelResponse {
|
||||
impl Widget for &mut ModelResponse {
|
||||
fn ui(self, ui: &mut Ui) -> Response {
|
||||
if self.loading {
|
||||
self.check_completion();
|
||||
}
|
||||
|
||||
ui.vertical(|ui| {
|
||||
egui::Frame::none()
|
||||
.inner_margin(8.0)
|
||||
@ -59,6 +89,7 @@ impl Widget for &ModelResponse {
|
||||
);
|
||||
});
|
||||
})
|
||||
}).response
|
||||
})
|
||||
.response
|
||||
}
|
||||
}
|
||||
|
@ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> {
|
||||
let label = ui.selectable_label(contained, model);
|
||||
|
||||
if label.clicked() && !contained {
|
||||
self.app_state.selected_models.push(ModelResponse {
|
||||
name: model.clone(),
|
||||
message: "No message".to_string(),
|
||||
tokens: 0,
|
||||
time: 0.0,
|
||||
});
|
||||
self.app_state.selected_models.push(ModelResponse::new(
|
||||
model.clone(),
|
||||
"No message".to_string(),
|
||||
0,
|
||||
0.0,
|
||||
));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> {
|
||||
|
||||
ui.horizontal(|ui| {
|
||||
if ui.button("Submit").clicked() {
|
||||
futures::executor::block_on(self.app_state.handle_submission());
|
||||
self.app_state.handle_submission();
|
||||
}
|
||||
|
||||
if ui.button("Clear").clicked() {
|
||||
|
@ -20,14 +20,10 @@ impl AppState {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_submission(&mut self) {
|
||||
let mut completions = vec![];
|
||||
|
||||
pub fn handle_submission(&mut self) {
|
||||
for model in &mut self.selected_models {
|
||||
completions.push(model.chat_completion(self.prompt_input.clone()));
|
||||
model.start_chat_completion(self.prompt_input.clone());
|
||||
}
|
||||
|
||||
futures::future::join_all(completions).await;
|
||||
}
|
||||
}
|
||||
|
||||
@ -52,7 +48,7 @@ impl eframe::App for AppState {
|
||||
ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| {
|
||||
egui::ScrollArea::both().show(ui, |ui| {
|
||||
ui.horizontal_top(|ui| {
|
||||
for model in &self.selected_models {
|
||||
for model in &mut self.selected_models {
|
||||
ui.add_sized(widget_size, model);
|
||||
}
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user