Skip to content

Commit

Permalink
Add /llama3 command and fix bugs (#13)
Browse files Browse the repository at this point in the history
* added groq command & openai endpoint support

* Rename OpenAI structs and types & remove test cuz cring

* better error handling

* removed unused parameters

* replaced to string into into

* changed function name and reordered parameters

* added max tokens

* usuń impl From<FormattedText> for CommandError
usuń impl std::fmt::Display for CommandError
zmień kolejność aliasów na ["llama3", "llama", "groq"]
usuń zmienne parts i parts_str
usuń komentarz // Convert text parts into a single string
dodaj parametr max_tokens do funkcji, wyślij go do api i ustaw na 256
domyślna temperatura

* inline some parameters

* changed command ordering and changed filename

* replaced match with response err

* fixed utf 8 panic but untested

* i think i fixed the makersuite.rs:110:80: error

* uh

* o to chodzilo??

* added PDF delay (untested)

* 5 second delay for all websites

* change case

* nitpicks

---------

Co-authored-by: jel <25802745+jelni@users.noreply.github.com>
  • Loading branch information
DuckyBlender and jelni authored May 27, 2024
1 parent 87fa058 commit 4bc5ee9
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 34 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ DB_ENCRYPTION_KEY=12345678
STABLEHORDE_TOKEN=0000000000
STABLEHORDE_CLIENT=name:version:contact
MAKERSUITE_API_KEY=YOUR_API_KEY
GROQ_API_KEY=YOUR_API_KEY
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ services:
STABLEHORDE_TOKEN: ${STABLEHORDE_TOKEN}
STABLEHORDE_CLIENT: ${STABLEHORDE_CLIENT}
MAKERSUITE_API_KEY: ${MAKERSUITE_API_KEY}
GROQ_API_KEY: ${GROQ_API_KEY}
volumes:
- craiyon-bot:/app

Expand Down
1 change: 1 addition & 0 deletions src/apis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod makersuite;
pub mod mathjs;
pub mod microlink;
pub mod moveit;
pub mod openai;
pub mod poligon;
pub mod stablehorde;
pub mod translate;
Expand Down
17 changes: 9 additions & 8 deletions src/apis/makersuite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ pub async fn stream_generate_content(
return;
}

let mut buffer = String::new();
let mut buffer = Vec::new();
let mut stream = response.bytes_stream();

while let Some(part) = stream.next().await {
Expand All @@ -187,23 +187,24 @@ pub async fn stream_generate_content(
}
};

buffer.push_str(&String::from_utf8(part.to_vec()).unwrap());
buffer.extend(&part);

if let Some(stripped) = buffer.strip_prefix('[') {
if let Some(stripped) = buffer.strip_prefix(b"[") {
buffer = stripped.into();
}

if let Some(stripped) = buffer.strip_suffix("\n]") {
if let Some(stripped) = buffer.strip_suffix(b"\n]") {
buffer = stripped.into();
}

while let Some((first, rest)) = buffer.split_once("\n,\r\n") {
tx.send(Ok(serde_json::from_str(first).unwrap())).unwrap();
buffer = rest.into();
while let Some(index) = buffer.windows(4).position(|window| window == b"\n,\r\n") {
let (first, rest) = buffer.split_at(index);
tx.send(Ok(serde_json::from_str(&String::from_utf8_lossy(first)).unwrap())).unwrap();
buffer = rest[4..].into();
}
}

tx.send(Ok(serde_json::from_str(&buffer).unwrap())).unwrap();
tx.send(Ok(serde_json::from_str(&String::from_utf8_lossy(&buffer)).unwrap())).unwrap();
}

#[derive(Serialize)]
Expand Down
2 changes: 1 addition & 1 deletion src/apis/microlink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ pub async fn screenshot(
"https://api.microlink.io/",
[
("url", url.as_str()),
("adblock", "false"),
("color_scheme", "dark"),
("ping", "false"),
("prerender", "true"),
("screenshot", "true"),
("timeout", "1m"),
("viewport.width", "1280"),
("viewport.height", "640"),
("wait_for_timeout", "5s"),
("wait_until", "load"),
],
)
Expand Down
69 changes: 69 additions & 0 deletions src/apis/openai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};

use crate::commands::CommandError;
use crate::utilities::api_utils::DetectServerError;

/// Request body for the OpenAI-compatible `POST /chat/completions` endpoint.
#[derive(Serialize)]
struct Request<'a> {
    // Model identifier, e.g. "llama3-70b-8192" (static: chosen at compile time by each command).
    model: &'static str,
    // Conversation history; a single user message for simple prompts.
    messages: &'a [Message<'a>],
    // Upper bound on generated tokens (set by `chat_completion`).
    max_tokens: u16,
}

/// One chat message sent to the API.
#[derive(Serialize)]
pub struct Message<'a> {
    // Message author role, e.g. "user", "system", "assistant".
    pub role: &'static str,
    pub content: &'a str,
}

/// Successful response payload of `/chat/completions`.
#[derive(Deserialize)]
pub struct ChatCompletion {
    // Completion candidates; with default request parameters the API returns one.
    pub choices: Vec<Choice>,
}

/// A single completion candidate.
#[derive(Deserialize)]
pub struct Choice {
    pub message: MessageResponse,
    // Why generation ended, e.g. "stop" (natural end) or "length" (hit max_tokens).
    pub finish_reason: String,
}

/// The generated assistant message.
#[derive(Deserialize)]
pub struct MessageResponse {
    pub content: String,
}

/// Error response envelope returned on non-200 statuses.
#[derive(Deserialize)]
pub struct ErrorResponse {
    pub error: Error,
}

/// API error details surfaced to the user by callers.
#[derive(Deserialize)]
pub struct Error {
    pub code: String,
    pub message: String,
}

pub async fn chat_completion(
http_client: reqwest::Client,
base_url: &str,
api_key: &str,
model: &'static str,
messages: &[Message<'_>],
) -> Result<Result<ChatCompletion, Error>, CommandError> {
let response = http_client
.post(format!("{base_url}/chat/completions"))
.bearer_auth(api_key)
.json(&Request { model, messages, max_tokens: 256 })
.send()
.await?
.server_error()?;

if response.status() == StatusCode::OK {
let response = response.json::<ChatCompletion>().await?;
Ok(Ok(response))
} else {
let response = response.json::<ErrorResponse>().await?;
Ok(Err(response.error))
}
}
1 change: 1 addition & 0 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod craiyon_search;
pub mod delete;
pub mod dice_reply;
pub mod different_dimension_me;
pub mod groq;
pub mod kebab;
pub mod kiwifarms;
pub mod makersuite;
Expand Down
60 changes: 60 additions & 0 deletions src/commands/groq.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use std::env;
use std::fmt::Write;

use async_trait::async_trait;
use tdlib::types::FormattedText;
use tdlib::{enums, functions};

use super::{CommandError, CommandResult, CommandTrait};
use crate::apis::openai::{self, Message};
use crate::utilities::command_context::CommandContext;
use crate::utilities::convert_argument::{ConvertArgument, StringGreedyOrReply};
use crate::utilities::rate_limit::RateLimiter;

/// `/llama3` command: asks Llama 3 70B through Groq's OpenAI-compatible endpoint.
pub struct Llama;

#[async_trait]
impl CommandTrait for Llama {
    fn command_names(&self) -> &[&str] {
        &["llama3", "llama", "groq"]
    }

    fn description(&self) -> Option<&'static str> {
        Some("ask Llama 3 70B")
    }

    fn rate_limit(&self) -> RateLimiter<i64> {
        // 3 invocations per 60 seconds.
        RateLimiter::new(3, 60)
    }

    async fn execute(&self, ctx: &CommandContext, arguments: String) -> CommandResult {
        // Take the command argument, or the replied-to message when no argument was given.
        let StringGreedyOrReply(prompt) = StringGreedyOrReply::convert(ctx, &arguments).await?.0;

        ctx.send_typing().await?;

        let response = openai::chat_completion(
            ctx.bot_state.http_client.clone(),
            "https://api.groq.com/openai/v1",
            &env::var("GROQ_API_KEY").unwrap(),
            "llama3-70b-8192",
            &[Message { role: "user", content: &prompt }],
        )
        .await?
        .map_err(|err| CommandError::Custom(format!("error {}: {}", err.code, err.message)))?;

        let choice = response.choices.into_iter().next().unwrap();
        let mut text = choice.message.content;

        // OpenAI-compatible APIs (including Groq) report a lowercase "stop" finish reason.
        // The previous uppercase "STOP" comparison (copied from the MakerSuite command,
        // whose API does use "STOP") never matched, so every normal reply was suffixed
        // with " [stop]". Compare case-insensitively so only abnormal terminations
        // (e.g. "length" when max_tokens is hit) are surfaced to the user.
        if !choice.finish_reason.eq_ignore_ascii_case("stop") {
            write!(text, " [{}]", choice.finish_reason).unwrap();
        }

        let enums::FormattedText::FormattedText(formatted_text) =
            functions::parse_markdown(FormattedText { text, ..Default::default() }, ctx.client_id)
                .await?;

        ctx.reply_formatted_text(formatted_text).await?;

        Ok(())
    }
}
53 changes: 28 additions & 25 deletions src/commands/makersuite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,34 @@ impl CommandTrait for GoogleGemini {
let mut message = Option::<Message>::None;

loop {
let (update_message, finished) =
if let Ok(response) = tokio::time::timeout_at(next_update, rx.recv()).await {
match response {
Some(response) => {
let response = response?;

match progress.as_mut() {
Some(progress) => {
progress.update(response)?;
}
None => {
progress = Some(GenerationProgress::new(
response.candidates.into_iter().next().unwrap(),
));
let (update_message, finished) = if let Ok(response) =
tokio::time::timeout_at(next_update, rx.recv()).await
{
match response {
Some(response) => {
let response = response?;

match progress.as_mut() {
Some(progress) => {
progress.update(response)?;
changed_after_last_update = true;
}
None => {
if let Some(candidate) = response.candidates.into_iter().next() {
progress = Some(GenerationProgress::new(candidate));
changed_after_last_update = true;
}
}

changed_after_last_update = true;
(false, false)
}
None => (true, true),

(false, false)
}
} else {
next_update = Instant::now() + Duration::from_secs(5);
(true, false)
};
None => (true, true),
}
} else {
next_update = Instant::now() + Duration::from_secs(5);
(true, false)
};

if update_message && changed_after_last_update {
let text = match progress.as_ref() {
Expand Down Expand Up @@ -284,9 +286,10 @@ impl GenerationProgress {
if let Some(content) = candidate.content {
self.parts.extend(content.parts);

if let Some(citation_metadata) = candidate.citation_metadata {
self.citation_sources.extend(citation_metadata.citation_sources);
}
self.citation_sources = candidate
.citation_metadata
.map(|citation_metadata| citation_metadata.citation_sources)
.unwrap_or_default();
}

self.finish_reason = candidate.finish_reason;
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async fn main() {
bot.add_command(commands::different_dimension_me::DifferentDimensionMe);
bot.add_command(commands::makersuite::GoogleGemini);
bot.add_command(commands::makersuite::GooglePalm);
bot.add_command(commands::groq::Llama);
bot.add_command(commands::translate::Translate);
bot.add_command(commands::badtranslate::BadTranslate);
bot.add_command(commands::trollslate::Trollslate);
Expand Down

0 comments on commit 4bc5ee9

Please sign in to comment.