Skip to content

Commit

Permalink
adding post processing data
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackson Barbosa committed Dec 7, 2023
1 parent b1bf5a5 commit 97f132a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
19 changes: 19 additions & 0 deletions bothub/api/v2/zeroshot/usecases/format_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .format_prompt import FormatPrompt

class FormatClassification:
def __init__(self, classification_data):
self.classification_data = classification_data

def get_classify(self, options, language):
classification = self.classification_data.get("text")[0].strip().lower()
formatter = FormatPrompt()
classify = {"other": False, "classification": ""}
for option in options:
option_class = option.get("class").strip().lower()
if classification == option_class:
classify["classification"] = option.get("class")
break
if len(classify["classification"]) == 0:
classify["classification"] = formatter.get_none_class(language=language)
classify["other"] = True
return classify
28 changes: 11 additions & 17 deletions bothub/api/v2/zeroshot/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

from .usecases.format_prompt import FormatPrompt
from .usecases.format_classification import FormatClassification

from bothub.api.v2.zeroshot.permissions import ZeroshotTokenPermission

Expand Down Expand Up @@ -85,8 +86,8 @@ class ZeroShotFastPredictAPIView(APIView):
def post(self, request):

data = request.data
formatter = FormatPrompt()
prompt = formatter.generate_prompt(data.get("language"), data)
prompt_formatter = FormatPrompt()
prompt = prompt_formatter.generate_prompt(data.get("language"), data)

body = {
"input": {
Expand Down Expand Up @@ -121,25 +122,18 @@ def post(self, request):
json=body
)

response = {}
other = False
classification = None

response = {"output": {}}
if response_nlp.status_code == 200:
classification_data = response_nlp.json().get("output")
classification = classification_data.get("text")[0].strip()
other = formatter.get_none_class(language=data.get("language")) in classification
response = {
"output": {
"classification": classification,
"other": other
}
}
formatted_classification = FormatClassification(
response_nlp.json().get("output")
).get_classify(language=data.get("language"), options=data.get("options"))

response["output"] = formatted_classification

ZeroshotLogs.objects.create(
text=data.get("text"),
classification=classification,
other=other,
classification=response["output"].get("classification"),
other=response["output"].get("other", False),
options=data.get("options"),
nlp_log=str(response_nlp.json()),
language=data.get("language")
Expand Down

0 comments on commit 97f132a

Please sign in to comment.