Skip to content

Commit

Permalink
allow missing safetyRatings
Browse files Browse the repository at this point in the history
  • Loading branch information
jelni committed Dec 27, 2023
1 parent 6e68537 commit bdd0e9e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 71 deletions.
94 changes: 42 additions & 52 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions src/apis/makersuite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct GenerationConfig {
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentResponse {
#[serde(default)]
pub candidates: Vec<Candidate>,
pub prompt_feedback: Option<PromptFeedback>,
}
Expand Down Expand Up @@ -96,13 +97,14 @@ pub struct CitationSource {
#[serde(rename_all = "camelCase")]
pub struct PromptFeedback {
pub block_reason: Option<String>,
pub safety_ratings: Vec<SafetyRating>,
pub safety_ratings: Option<Vec<SafetyRating>>,
}

#[derive(Debug, Deserialize)]
pub struct SafetyRating {
pub category: String,
pub blocked: Option<bool>,
#[serde(default)]
pub blocked: bool,
}

#[derive(Deserialize)]
Expand Down
29 changes: 12 additions & 17 deletions src/commands/makersuite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,18 @@ impl CommandTrait for GoogleGemini {
};

if let Some(prompt_feedback) = response.prompt_feedback {
if let Some(block_reason) = prompt_feedback.block_reason {
if block_reason == "SAFETY" {
let reasons = prompt_feedback
.safety_ratings
.into_iter()
.filter(|safety_rating| safety_rating.blocked.unwrap_or_default())
.map(|safety_rating| safety_rating.category)
.collect::<Vec<_>>()
.join(", ");

return Err(CommandError::Custom(format!(
"request blocked by Google: {reasons}."
)));
};

return Err(CommandError::Custom("request blocked by Google.".into()));
};
if let Some(safety_ratings) = prompt_feedback.safety_ratings {
let reasons = safety_ratings
.into_iter()
.filter(|safety_rating| safety_rating.blocked)
.map(|safety_rating| safety_rating.category)
.collect::<Vec<_>>()
.join(", ");

return Err(CommandError::Custom(format!("request blocked by Google: {reasons}.")));
}

return Err(CommandError::Custom("request blocked by Google.".into()));
}

let Some(candidate) = response.candidates.into_iter().next() else {
Expand Down

0 comments on commit bdd0e9e

Please sign in to comment.