-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
43 lines (34 loc) · 1.32 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Base Hugging Face checkpoint that the emotion classifier was fine-tuned from.
model_ckpt = "distilbert/distilbert-base-uncased"
# Fine-tuned weights are loaded from "<base checkpoint>-finetuned-emotions-shixm";
# presumably the output directory of the fine-tuning run (TODO confirm).
model_name = model_ckpt + "-finetuned-emotions-shixm"
# Loaded once at import time; inferExec below reuses these module-level objects.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
def predict(text, model, tokenizer):
    """Return softmax class probabilities for *text*.

    Args:
        text: Input handed unchanged to ``tokenizer`` (a string, or any batch
            form the tokenizer accepts).
        model: Classifier called with the tokenizer's output as keyword
            arguments; its result must expose a ``.logits`` tensor.
        tokenizer: Callable invoked with ``return_tensors="pt"``,
            ``padding=True`` and ``truncation=True``.

    Returns:
        ``torch.Tensor`` of probabilities, normalised over the last dimension.
    """
    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: disable autograd tracking for the forward pass.
    with torch.no_grad():
        logits = model(**encoded).logits
    return logits.softmax(dim=-1)
from flask import Flask, request, jsonify
import torch
# WSGI application object; the /predict route below is registered on it via @app.route.
app = Flask(import_name=__name__)
# Class id -> display label (in Chinese) for the six output classes:
# sadness, joy, love, anger, fear, surprise.
_EMOTION_LABELS = {0: '悲伤', 1: '喜悦', 2: '爱', 3: '愤怒', 4: '恐惧', 5: '惊喜'}


def inferExec(inputs):
    """Classify *inputs* with the module-level model and return the emotion label.

    Args:
        inputs: Text forwarded to ``predict`` -- normally a single string.

    Returns:
        The label string from ``_EMOTION_LABELS`` for a single input row. If
        *inputs* produces several rows (e.g. a list of strings, assuming the
        tokenizer batches it), a list of labels is returned, one per row.

    Raises:
        KeyError: if the model predicts a class id missing from ``_EMOTION_LABELS``.
    """
    probs = predict(inputs, model, tokenizer)
    # Argmax over the class dimension gives one id per input row; a flat argmax
    # would index into the concatenated rows and overflow the 6-entry map for batches.
    label_ids = torch.argmax(probs, dim=-1)
    # Debug trace of the probabilities and chosen class ids (stdout, every call).
    print(f"resu:{probs},max:{label_ids}")
    labels = [_EMOTION_LABELS[label_id] for label_id in label_ids.reshape(-1).tolist()]
    # One row (the usual single-string case) keeps the scalar return type.
    return labels[0] if len(labels) == 1 else labels
@app.route('/predict', methods=['POST'])
def predict_text():
    """Handle ``POST /predict``: classify the ``text`` field of a JSON body.

    Expects a JSON object such as ``{"text": "..."}``.

    Returns:
        On success, the predicted label serialised as JSON.
        On bad input, ``({"error": ...}, 400)`` when ``text`` is missing or
        empty, when the body is not a JSON object, or when ``text`` is not a
        string.
    """
    data = request.get_json()
    # A body like `null`, a list or a number is valid JSON but not an object;
    # calling .get on it would raise and surface as a 500.
    text = data.get('text') if isinstance(data, dict) else None
    if not text:
        return jsonify({'error': 'No text provided'}), 400
    if not isinstance(text, str):
        # Numbers, lists and objects are not valid model input.
        return jsonify({'error': "'text' must be a string"}), 400
    return jsonify(inferExec(text))
if __name__ == '__main__':
    # Sanity check before serving: classify two sample strings and print their labels.
    for sample_text in ("fuck!@@@", "iamhappy!@@@"):
        print(inferExec(sample_text))
    # Starts Flask's built-in server in debug mode -- intended for local development only.
    app.run(debug=True)