diff --git a/bothub/api/v2/zeroshot/usecases/format_classification.py b/bothub/api/v2/zeroshot/usecases/format_classification.py new file mode 100644 index 00000000..90c7c45d --- /dev/null +++ b/bothub/api/v2/zeroshot/usecases/format_classification.py @@ -0,0 +1,19 @@ +from .format_prompt import FormatPrompt + +class FormatClassification: + def __init__(self, classification_data): + self.classification_data = classification_data + + def get_classify(self, options, language): + classification = self.classification_data.get("text")[0].strip().lower() + formatter = FormatPrompt() + classify = {"other": False, "classification": ""} + for option in options: + option_class = option.get("class").strip().lower() + if classification == option_class: + classify["classification"] = option.get("class") + break + if len(classify["classification"]) == 0: + classify["classification"] = formatter.get_none_class(language=language) + classify["other"] = True + return classify diff --git a/bothub/api/v2/zeroshot/views.py b/bothub/api/v2/zeroshot/views.py index 154b06c4..0afa2c33 100644 --- a/bothub/api/v2/zeroshot/views.py +++ b/bothub/api/v2/zeroshot/views.py @@ -17,6 +17,7 @@ ) from .usecases.format_prompt import FormatPrompt +from .usecases.format_classification import FormatClassification from bothub.api.v2.zeroshot.permissions import ZeroshotTokenPermission @@ -85,8 +86,8 @@ class ZeroShotFastPredictAPIView(APIView): def post(self, request): data = request.data - formatter = FormatPrompt() - prompt = formatter.generate_prompt(data.get("language"), data) + prompt_formatter = FormatPrompt() + prompt = prompt_formatter.generate_prompt(data.get("language"), data) body = { "input": { @@ -121,25 +122,18 @@ def post(self, request): json=body ) - response = {} - other = False - classification = None - + response = {"output": {}} if response_nlp.status_code == 200: - classification_data = response_nlp.json().get("output") - classification = classification_data.get("text")[0].strip() - other = formatter.get_none_class(language=data.get("language")) in classification - response = { - "output": { - "classification": classification, - "other": other - } - } + formatted_classification = FormatClassification( + response_nlp.json().get("output") + ).get_classify(language=data.get("language"), options=data.get("options")) + + response["output"] = formatted_classification ZeroshotLogs.objects.create( text=data.get("text"), - classification=classification, - other=other, + classification=response["output"].get("classification"), + other=response["output"].get("other", False), options=data.get("options"), nlp_log=str(response_nlp.json()), language=data.get("language")