Add async chat completion
This commit is contained in:
parent
ba333b928e
commit
90b7511426
@ -1,34 +1,64 @@
|
|||||||
use eframe::egui::{self, Response, RichText, Ui, Widget};
|
use eframe::egui::{self, Response, RichText, Ui, Widget};
|
||||||
|
use std::sync::mpsc::{Receiver, Sender};
|
||||||
|
|
||||||
use crate::groq::GROQ_CLIENT;
|
use crate::groq::{types::ChatCompletionResponse, GROQ_CLIENT};
|
||||||
|
|
||||||
pub struct ModelResponse {
|
pub struct ModelResponse {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub message: String,
|
message: String,
|
||||||
pub tokens: i32,
|
tokens: i32,
|
||||||
pub time: f32,
|
time: f32,
|
||||||
|
loading: bool,
|
||||||
|
tx: Sender<ChatCompletionResponse>,
|
||||||
|
rx: Receiver<ChatCompletionResponse>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelResponse {
|
impl ModelResponse {
|
||||||
pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self {
|
pub fn new(name: String, message: String, tokens: i32, time: f32) -> Self {
|
||||||
|
let (tx, rx) = std::sync::mpsc::channel();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
name,
|
name,
|
||||||
message,
|
message,
|
||||||
tokens,
|
tokens,
|
||||||
time,
|
time,
|
||||||
|
loading: false,
|
||||||
|
tx,
|
||||||
|
rx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn chat_completion(&mut self, prompt: String) {
|
pub fn start_chat_completion(&mut self, prompt: String) {
|
||||||
let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await;
|
self.message = "Loading...".to_string();
|
||||||
|
self.tokens = 0;
|
||||||
|
self.time = 0.0;
|
||||||
|
self.loading = true;
|
||||||
|
Self::spawn_chat_completion_task(prompt, self.name.clone(), self.tx.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_completion(&mut self) {
|
||||||
|
if let Ok(response) = self.rx.try_recv() {
|
||||||
self.message = response.choices[0].message.content.clone();
|
self.message = response.choices[0].message.content.clone();
|
||||||
self.tokens = response.usage.total_tokens;
|
self.tokens = response.usage.total_tokens;
|
||||||
self.time = response.usage.total_time;
|
self.time = response.usage.total_time;
|
||||||
|
self.loading = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Widget for &ModelResponse {
|
fn spawn_chat_completion_task(prompt: String, model_name: String, tx: Sender<ChatCompletionResponse>) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let response = GROQ_CLIENT.chat_completion(model_name, prompt).await;
|
||||||
|
tx.send(response).expect("Failed to send response");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Widget for &mut ModelResponse {
|
||||||
fn ui(self, ui: &mut Ui) -> Response {
|
fn ui(self, ui: &mut Ui) -> Response {
|
||||||
|
if self.loading {
|
||||||
|
self.check_completion();
|
||||||
|
}
|
||||||
|
|
||||||
ui.vertical(|ui| {
|
ui.vertical(|ui| {
|
||||||
egui::Frame::none()
|
egui::Frame::none()
|
||||||
.inner_margin(8.0)
|
.inner_margin(8.0)
|
||||||
@ -59,6 +89,7 @@ impl Widget for &ModelResponse {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
}).response
|
})
|
||||||
|
.response
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,12 +30,12 @@ impl<'a> Widget for ModelSelection<'a> {
|
|||||||
let label = ui.selectable_label(contained, model);
|
let label = ui.selectable_label(contained, model);
|
||||||
|
|
||||||
if label.clicked() && !contained {
|
if label.clicked() && !contained {
|
||||||
self.app_state.selected_models.push(ModelResponse {
|
self.app_state.selected_models.push(ModelResponse::new(
|
||||||
name: model.clone(),
|
model.clone(),
|
||||||
message: "No message".to_string(),
|
"No message".to_string(),
|
||||||
tokens: 0,
|
0,
|
||||||
time: 0.0,
|
0.0,
|
||||||
});
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -23,7 +23,7 @@ impl<'a> Widget for PromptInput<'a> {
|
|||||||
|
|
||||||
ui.horizontal(|ui| {
|
ui.horizontal(|ui| {
|
||||||
if ui.button("Submit").clicked() {
|
if ui.button("Submit").clicked() {
|
||||||
futures::executor::block_on(self.app_state.handle_submission());
|
self.app_state.handle_submission();
|
||||||
}
|
}
|
||||||
|
|
||||||
if ui.button("Clear").clicked() {
|
if ui.button("Clear").clicked() {
|
||||||
|
@ -20,14 +20,10 @@ impl AppState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_submission(&mut self) {
|
pub fn handle_submission(&mut self) {
|
||||||
let mut completions = vec![];
|
|
||||||
|
|
||||||
for model in &mut self.selected_models {
|
for model in &mut self.selected_models {
|
||||||
completions.push(model.chat_completion(self.prompt_input.clone()));
|
model.start_chat_completion(self.prompt_input.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
futures::future::join_all(completions).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,7 +48,7 @@ impl eframe::App for AppState {
|
|||||||
ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| {
|
ui.allocate_ui_with_layout(size, Layout::left_to_right(egui::Align::Min), |ui| {
|
||||||
egui::ScrollArea::both().show(ui, |ui| {
|
egui::ScrollArea::both().show(ui, |ui| {
|
||||||
ui.horizontal_top(|ui| {
|
ui.horizontal_top(|ui| {
|
||||||
for model in &self.selected_models {
|
for model in &mut self.selected_models {
|
||||||
ui.add_sized(widget_size, model);
|
ui.add_sized(widget_size, model);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user