From 8b8f48a743f1660043a05f4a0da6012aa4b409cb Mon Sep 17 00:00:00 2001 From: Alex Perez Date: Sun, 13 Oct 2024 05:43:44 +1300 Subject: [PATCH] fix: update gemini's `safety_settings` (#1057) Support gemini's new safety settings --- instructor/utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/instructor/utils.py b/instructor/utils.py index eac88d004..8bd012f1b 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -347,12 +347,20 @@ def update_gemini_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]: # minimize gemini safety related errors - model is highly prone to false alarms from google.generativeai.types import HarmCategory, HarmBlockThreshold # type: ignore - kwargs["safety_settings"] = kwargs.get("safety_settings", {}) | { + kwargs["safety_settings"] = kwargs.get("safety_settings", {}) + + fallback_safety_settings = { HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH, } + # Update or add fallback settings, respecting stricter existing ones + for category, fallback_threshold in fallback_safety_settings.items(): + current_threshold = kwargs["safety_settings"].get(category) + if current_threshold is None or current_threshold < fallback_threshold: + kwargs["safety_settings"][category] = fallback_threshold + return kwargs