-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add /llama3 command and fix bugs (#13)
* added groq command & openai endpoint support * Rename OpenAI structs and types & remove test cuz cring * better error handling * removed unused parameters * replaced to string into into * changed function name and reordered parameters * added max tokens * usuń impl From<FormattedText> for CommandError usuń impl std::fmt::Display for CommandError zmień kolejność aliasów na ["llama3", "llama", "groq"] usuń zmienne parts i parts_str usuń komentarz // Convert text parts into a single string dodaj parametr max_tokens do funkcji, wyślij go do api i ustaw na 256 domyślna temperatura * inline some parameters * changed command ordering and changed filename * replaced match with response err * fixed utf 8 panic but untested * i think i fixed the makersuite.rs:110:80: error * uh * o to chodzilo?? * added PDF delay (untested) * 5 second delay for all websites * change case * nitpicks --------- Co-authored-by: jel <25802745+jelni@users.noreply.github.com>
- Loading branch information
1 parent
87fa058
commit 4bc5ee9
Showing
10 changed files
with
172 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
use reqwest::StatusCode; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::commands::CommandError; | ||
use crate::utilities::api_utils::DetectServerError; | ||
|
||
#[derive(Serialize)] | ||
struct Request<'a> { | ||
model: &'static str, | ||
messages: &'a [Message<'a>], | ||
max_tokens: u16, | ||
} | ||
|
||
#[derive(Serialize)] | ||
pub struct Message<'a> { | ||
pub role: &'static str, | ||
pub content: &'a str, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct ChatCompletion { | ||
pub choices: Vec<Choice>, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct Choice { | ||
pub message: MessageResponse, | ||
pub finish_reason: String, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct MessageResponse { | ||
pub content: String, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct ErrorResponse { | ||
pub error: Error, | ||
} | ||
|
||
#[derive(Deserialize)] | ||
pub struct Error { | ||
pub code: String, | ||
pub message: String, | ||
} | ||
|
||
pub async fn chat_completion( | ||
http_client: reqwest::Client, | ||
base_url: &str, | ||
api_key: &str, | ||
model: &'static str, | ||
messages: &[Message<'_>], | ||
) -> Result<Result<ChatCompletion, Error>, CommandError> { | ||
let response = http_client | ||
.post(format!("{base_url}/chat/completions")) | ||
.bearer_auth(api_key) | ||
.json(&Request { model, messages, max_tokens: 256 }) | ||
.send() | ||
.await? | ||
.server_error()?; | ||
|
||
if response.status() == StatusCode::OK { | ||
let response = response.json::<ChatCompletion>().await?; | ||
Ok(Ok(response)) | ||
} else { | ||
let response = response.json::<ErrorResponse>().await?; | ||
Ok(Err(response.error)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
use std::env; | ||
use std::fmt::Write; | ||
|
||
use async_trait::async_trait; | ||
use tdlib::types::FormattedText; | ||
use tdlib::{enums, functions}; | ||
|
||
use super::{CommandError, CommandResult, CommandTrait}; | ||
use crate::apis::openai::{self, Message}; | ||
use crate::utilities::command_context::CommandContext; | ||
use crate::utilities::convert_argument::{ConvertArgument, StringGreedyOrReply}; | ||
use crate::utilities::rate_limit::RateLimiter; | ||
|
||
pub struct Llama; | ||
|
||
#[async_trait] | ||
impl CommandTrait for Llama { | ||
fn command_names(&self) -> &[&str] { | ||
&["llama3", "llama", "groq"] | ||
} | ||
|
||
fn description(&self) -> Option<&'static str> { | ||
Some("ask Llama 3 70B") | ||
} | ||
|
||
fn rate_limit(&self) -> RateLimiter<i64> { | ||
RateLimiter::new(3, 60) | ||
} | ||
|
||
async fn execute(&self, ctx: &CommandContext, arguments: String) -> CommandResult { | ||
let StringGreedyOrReply(prompt) = StringGreedyOrReply::convert(ctx, &arguments).await?.0; | ||
|
||
ctx.send_typing().await?; | ||
|
||
let response = openai::chat_completion( | ||
ctx.bot_state.http_client.clone(), | ||
"https://api.groq.com/openai/v1", | ||
&env::var("GROQ_API_KEY").unwrap(), | ||
"llama3-70b-8192", | ||
&[Message { role: "user", content: &prompt }], | ||
) | ||
.await? | ||
.map_err(|err| CommandError::Custom(format!("error {}: {}", err.code, err.message)))?; | ||
|
||
let choice = response.choices.into_iter().next().unwrap(); | ||
let mut text = choice.message.content; | ||
|
||
if choice.finish_reason != "STOP" { | ||
write!(text, " [{}]", choice.finish_reason).unwrap(); | ||
} | ||
|
||
let enums::FormattedText::FormattedText(formatted_text) = | ||
functions::parse_markdown(FormattedText { text, ..Default::default() }, ctx.client_id) | ||
.await?; | ||
|
||
ctx.reply_formatted_text(formatted_text).await?; | ||
|
||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters