From 809ab9a979012525061ff1eaaac6a219d9b7665b Mon Sep 17 00:00:00 2001 From: Ethan Date: Fri, 27 Dec 2024 02:27:22 -0500 Subject: [PATCH] Add API communication --- .env.example | 1 + .gitignore | 2 + rust-learning/Cargo.toml | 6 ++ rust-learning/src/config.rs | 18 ++++ rust-learning/src/groq.rs | 104 +++++++++++++++++++++++ rust-learning/src/groq/types.rs | 41 +++++++++ rust-learning/src/gui.rs | 27 +----- rust-learning/src/gui/model_response.rs | 19 +++-- rust-learning/src/gui/model_selection.rs | 14 ++- rust-learning/src/gui/prompt_input.rs | 17 ++-- rust-learning/src/gui/state.rs | 37 +++++++- rust-learning/src/main.rs | 13 ++- 12 files changed, 247 insertions(+), 52 deletions(-) create mode 100644 .env.example create mode 100644 rust-learning/src/config.rs create mode 100644 rust-learning/src/groq.rs create mode 100644 rust-learning/src/groq/types.rs diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..e90a1db --- /dev/null +++ b/.env.example @@ -0,0 +1 @@ +GROQ_API_KEY=your-api-key \ No newline at end of file diff --git a/.gitignore b/.gitignore index 3ca43ae..09ca525 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb +# Environment Variables +.env diff --git a/rust-learning/Cargo.toml b/rust-learning/Cargo.toml index b77fc13..c34bfcf 100644 --- a/rust-learning/Cargo.toml +++ b/rust-learning/Cargo.toml @@ -4,4 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] +dotenv = "0.15.0" eframe = "0.30.0" +futures = "0.3.31" +reqwest = { version = "0.12", features = ["json"] } +serde = "1.0.216" +serde_json = "1.0.134" +tokio = { version = "1", features = ["full"] } diff --git a/rust-learning/src/config.rs b/rust-learning/src/config.rs new file mode 100644 index 0000000..2120861 --- /dev/null +++ b/rust-learning/src/config.rs @@ -0,0 +1,18 @@ +use dotenv::dotenv; +use std::sync::LazyLock; + +pub struct Config { + pub groq_api_key: String, +} + +impl Config { + pub fn new() -> Self { + dotenv().ok(); + + Self { + groq_api_key: std::env::var("GROQ_API_KEY").expect("GROQ_API_KEY must be set"), + } + } +} + +pub static CONFIG: LazyLock = LazyLock::new(|| Config::new()); diff --git a/rust-learning/src/groq.rs b/rust-learning/src/groq.rs new file mode 100644 index 0000000..8a08b38 --- /dev/null +++ b/rust-learning/src/groq.rs @@ -0,0 +1,104 @@ +use std::sync::LazyLock; + +use types::{ + ChatCompletionChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionUsage, ListModelsResponse, +}; + +use crate::config::CONFIG; + +pub mod types; + +pub struct Groq { + api_key: String, + url: String, + http_client: reqwest::Client, +} + +impl Groq { + pub fn new(api_key: String) -> Self { + Self { + api_key, + url: "https://api.groq.com/openai/v1".to_string(), + http_client: reqwest::Client::new(), + } + } + + pub async fn list_models(&self) -> Vec { + let response = self + .http_client + .get(format!("{}/models", self.url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .send() + .await; + + if response.is_err() { + print!("Request Error"); + return vec![]; + } + + let response_models = response.unwrap().json::().await; + + if response_models.is_err() { + print!("Response Parsing Error"); + return vec![]; + } + + let model_names = response_models + .unwrap() + .data + .iter() + .map(|model| model.id.clone()) + .collect(); + + model_names + } + + pub async fn chat_completion(&self, model: String, message: String) -> ChatCompletionResponse { + let mut error_response = ChatCompletionResponse { + model: "Error".to_string(), + choices: vec![ChatCompletionChoice { + message: ChatCompletionMessage { + role: "system".to_string(), + content: String::new(), + }, + }], + usage: ChatCompletionUsage { + total_tokens: 0, + total_time: 0.0, + }, + }; + + let request = ChatCompletionRequest { + model, + messages: vec![ChatCompletionMessage { + role: "user".to_string(), + content: message, + }], + }; + + let response = self + .http_client + .post(format!("{}/chat/completions", self.url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&request) + .send() + .await; + + if response.is_err() { + error_response.choices[0].message.content = "Request Error".to_string(); + return error_response; + } + + let response_completion = response.unwrap().json::().await; + + if response_completion.is_err() { + error_response.choices[0].message.content = "Response Parsing Error".to_string(); + return error_response; + } + + response_completion.unwrap() + } +} + +pub static GROQ_CLIENT: LazyLock = LazyLock::new(|| Groq::new(CONFIG.groq_api_key.clone())); diff --git a/rust-learning/src/groq/types.rs b/rust-learning/src/groq/types.rs new file mode 100644 index 0000000..413fedd --- /dev/null +++ b/rust-learning/src/groq/types.rs @@ -0,0 +1,41 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize)] +pub struct RetrieveModelResponse { + pub id: String, +} + +#[derive(Deserialize)] +pub struct ListModelsResponse { + pub data: Vec, +} + +#[derive(Serialize, Deserialize)] +pub struct ChatCompletionMessage { + pub role: String, + pub content: String, +} + +#[derive(Deserialize)] +pub struct ChatCompletionChoice { + pub message: ChatCompletionMessage, +} + +#[derive(Deserialize)] +pub struct ChatCompletionUsage { + pub total_tokens: i32, + pub total_time: f32, +} + +#[derive(Deserialize)] +pub struct ChatCompletionResponse { + pub model: String, + pub choices: Vec, + pub usage: ChatCompletionUsage, +} + +#[derive(Serialize)] +pub struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, +} diff --git a/rust-learning/src/gui.rs b/rust-learning/src/gui.rs index fd44618..d8d7be2 100644 --- a/rust-learning/src/gui.rs +++ b/rust-learning/src/gui.rs @@ -1,4 +1,3 @@ -use eframe::egui; use state::AppState; pub mod model_response; @@ -6,7 +5,7 @@ pub mod model_selection; pub mod prompt_input; pub mod state; -pub fn main() -> eframe::Result { +pub fn main(models_available: Vec) -> eframe::Result { let options = eframe::NativeOptions { viewport: eframe::egui::ViewportBuilder::default().with_inner_size([900.0, 600.0]), ..Default::default() @@ -15,28 +14,6 @@ pub fn main() -> eframe::Result { eframe::run_native( "Groq Model Comparison", options, - Box::new(|cc| Ok(Box::new(AppState::new(cc)))), + Box::new(|cc| Ok(Box::new(AppState::new(cc, models_available)))), ) } - -impl eframe::App for AppState { - fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { - egui::CentralPanel::default().show(ctx, |ui| { - ui.heading("Groq Model Comparison"); - ui.label("Compare Groq models with ease!"); - - ui.add(model_selection::ModelSelection::new( - self, - vec!["Model 1".to_string(), "Model 2".to_string()], - )); - - ui.add(prompt_input::PromptInput::new(self)); - - ui.horizontal(|ui| { - for model in &self.selected_models { - ui.vertical(|ui| ui.add(model)); - } - }); - }); - } -} diff --git a/rust-learning/src/gui/model_response.rs b/rust-learning/src/gui/model_response.rs index 19227ca..f5ada50 100644 --- a/rust-learning/src/gui/model_response.rs +++ b/rust-learning/src/gui/model_response.rs @@ -1,14 +1,16 @@ -use eframe::egui::{self, Response, Ui, Widget}; +use eframe::egui::{self, Response, RichText, Ui, Widget}; + +use crate::groq::GROQ_CLIENT; pub struct ModelResponse { pub name: String, pub message: String, pub status: i32, - pub time: i32, + pub time: f32, } impl ModelResponse { - pub fn new(name: String, message: String, status: i32, time: i32) -> Self { + pub fn new(name: String, message: String, status: i32, time: f32) -> Self { Self { name, message, @@ -16,6 +18,13 @@ impl ModelResponse { time, } } + + pub async fn chat_completion(&mut self, prompt: String) { + let response = GROQ_CLIENT.chat_completion(self.name.clone(), prompt).await; + self.message = response.choices[0].message.content.clone(); + self.status = 200; + self.time = response.usage.total_time; + } } impl Widget for &ModelResponse { @@ -25,9 +34,9 @@ impl Widget for &ModelResponse { .rounding(4.0) .fill(egui::Color32::DARK_GRAY) .show(ui, |ui| { - ui.horizontal(|ui| { - ui.label(&self.name); + ui.label(RichText::new(&self.name).strong()); + ui.horizontal(|ui| { ui.horizontal(|ui| { ui.label("Status:"); ui.label(self.status.to_string()); diff --git a/rust-learning/src/gui/model_selection.rs b/rust-learning/src/gui/model_selection.rs index 1a455c0..221f189 100644 --- a/rust-learning/src/gui/model_selection.rs +++ b/rust-learning/src/gui/model_selection.rs @@ -4,15 +4,11 @@ use super::{model_response::ModelResponse, state::AppState}; pub struct ModelSelection<'a> { app_state: &'a mut AppState, - pub models_available: Vec, } impl<'a> ModelSelection<'a> { - pub fn new(app_state: &'a mut AppState, models_available: Vec) -> Self { - Self { - app_state, - models_available, - } + pub fn new(app_state: &'a mut AppState) -> Self { + Self { app_state } } } @@ -22,7 +18,7 @@ impl<'a> Widget for ModelSelection<'a> { egui::ComboBox::from_label("") .selected_text("Models") .show_ui(ui, |ui| { - for model in self.models_available { + for model in &self.app_state.models_available { let selected_models = self .app_state .selected_models @@ -31,14 +27,14 @@ impl<'a> Widget for ModelSelection<'a> { .collect::>(); let contained = selected_models.contains(&model); - let label = ui.selectable_label(contained, &model); + let label = ui.selectable_label(contained, model); if label.clicked() && !contained { self.app_state.selected_models.push(ModelResponse { name: model.clone(), message: "No message".to_string(), status: 0, - time: 0, + time: 0.0, }); } } diff --git a/rust-learning/src/gui/prompt_input.rs b/rust-learning/src/gui/prompt_input.rs index c093de3..9a99365 100644 --- a/rust-learning/src/gui/prompt_input.rs +++ b/rust-learning/src/gui/prompt_input.rs @@ -14,16 +14,19 @@ impl<'a> PromptInput<'a> { impl<'a> Widget for PromptInput<'a> { fn ui(self, ui: &mut Ui) -> Response { - ui.horizontal(|ui| { - let label = ui.label("Prompt:"); - ui.text_edit_singleline(&mut self.app_state.prompt_input) - .labelled_by(label.id); + let label = ui.label("Prompt:"); + ui.text_edit_multiline(&mut self.app_state.prompt_input) + .labelled_by(label.id); + ui.horizontal(|ui| { if ui.button("Submit").clicked() { + futures::executor::block_on(self.app_state.handle_submission()); + } + + if ui.button("Clear").clicked() { self.app_state.prompt_input.clear(); } - }); - - ui.label(&self.app_state.prompt_input) + }) + .response } } diff --git a/rust-learning/src/gui/state.rs b/rust-learning/src/gui/state.rs index 205b5ea..8cf4157 100644 --- a/rust-learning/src/gui/state.rs +++ b/rust-learning/src/gui/state.rs @@ -1,17 +1,48 @@ -use eframe::CreationContext; +use eframe::{egui, CreationContext}; -use super::model_response::ModelResponse; +use super::{model_response::ModelResponse, model_selection, prompt_input}; pub struct AppState { pub selected_models: Vec, pub prompt_input: String, + pub models_available: Vec, } impl AppState { - pub fn new(_cc: &CreationContext<'_>) -> Self { + pub fn new(_cc: &CreationContext<'_>, models_available: Vec) -> Self { Self { selected_models: vec![], prompt_input: String::new(), + models_available, } } + + pub async fn handle_submission(&mut self) { + let mut completions = vec![]; + + for model in &mut self.selected_models { + completions.push(model.chat_completion(self.prompt_input.clone())); + } + + futures::future::join_all(completions).await; + } +} + +impl eframe::App for AppState { + fn update(&mut self, ctx: &egui::Context, _frame: &mut eframe::Frame) { + egui::CentralPanel::default().show(ctx, |ui| { + ui.heading("Groq Model Comparison"); + ui.label("Compare Groq models with ease!"); + + ui.add(model_selection::ModelSelection::new(self)); + + ui.add(prompt_input::PromptInput::new(self)); + + ui.horizontal(|ui| { + for model in &self.selected_models { + ui.vertical(|ui| ui.add(model)); + } + }); + }); + } } diff --git a/rust-learning/src/main.rs b/rust-learning/src/main.rs index 8a5ebda..b771a5e 100644 --- a/rust-learning/src/main.rs +++ b/rust-learning/src/main.rs @@ -1,5 +1,12 @@ -pub mod gui; +use groq::GROQ_CLIENT; -fn main() -> eframe::Result { - gui::main() +pub mod groq; +pub mod gui; +pub mod config; + +#[tokio::main] +async fn main() -> eframe::Result { + let models = GROQ_CLIENT.list_models().await; + + gui::main(models) }