predict.py
import fastai.text as fatext
import numpy as np
import sentencepiece as sp
import torch
import argparse
import sys
import readline  # imported for its side effect: line editing/history for input()
import re
from typing import List, Tuple, Any

def loadVocab(filename):
    # Each line of the SentencePiece .vocab file is "<piece>\t<score>";
    # keep only the piece and wrap the list in a fastai Vocab.
    with open(filename, "r") as f:
        tokens = [l.strip().split()[0] for l in f]
    return fatext.transform.Vocab(tokens)

class Models:
    """A weighted ensemble of fastai language-model learners."""

    def __init__(self, vocab, learners: List[Tuple[float, Any]]):
        self.vocab = vocab
        self.learners = learners  # list of (weight, learner) pairs
        self.n = None             # max tokens to generate (None = unbounded)
        self.end = "▁xx"          # stop token; generation ends when it is sampled
        self.allowed = None       # optional regex restricting the vocabulary
        self.temperature = 0.7
        self.repetition_penalty = 0.7
        self.excluded_tokens = ["<unk>"]
        self.promoted_tokens = []
    def weightedPredict(self, tokens: List[str]):
        # Reset every learner's hidden state before starting a new sequence.
        for _, learner in self.learners:
            learner.model.reset()
        # Build a 0/1 mask over the vocabulary from the /allow regex, if set.
        # (stoi preserves itos order in fastai's Vocab, so the enumeration
        # index is the token id.)
        if self.allowed:
            allow = torch.zeros([len(self.vocab.stoi)])
            for i, s in enumerate(self.vocab.stoi):
                if self.allowed.search(s):
                    allow[i] = 1.
        else:
            allow = torch.ones([len(self.vocab.stoi)])
        xb = torch.tensor([self.vocab.numericalize(tokens or [""])])
        history = []
        i_x = 0
        while True:
            # Weighted sum of each learner's next-token distribution
            # (pred_batch applies softmax), taken at the last position.
            res = sum(w * learner.pred_batch(batch=(xb, torch.tensor([0])))[0][-1]
                      for w, learner in self.learners)
            for token in self.excluded_tokens:
                res[self.vocab.stoi[token]] = 0.
            for token in self.promoted_tokens:
                res[self.vocab.stoi[token]] *= 10.
            # Damp recently emitted tokens; the penalty decays exponentially
            # with distance from the end of the history.
            if self.repetition_penalty > 0.:
                for i, token_id in enumerate(reversed(history)):
                    res[token_id] *= 1.0 - self.repetition_penalty * 2 ** (-i * .1)
            res.mul_(allow)
            if self.temperature != 1.:
                res.pow_(1 / self.temperature)
            # torch.multinomial renormalizes, so res need not sum to 1.
            idx = torch.multinomial(res, 1).item()
            tok = self.vocab.itos[idx]
            if tok == self.end:
                break
            yield tok
            history.append(idx)
            # Feed only the sampled token back in; state lives in the model.
            xb = xb.new_tensor([idx])[None]
            i_x += 1
            if i_x == self.n:
                break
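
# A minimal sketch (not part of the original script) of driving
# weightedPredict directly, assuming `vocab` and `learners` were assembled as
# in main() below; the seed piece "▁hello" is purely illustrative:
#
#   models = Models(vocab, learners)
#   models.n, models.end = 20, None   # stop after 20 tokens instead of "▁xx"
#   print(" ".join(models.weightedPredict(["▁hello"])))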

def model(s):
    # argparse type for the "type,weight,model" triples on the command line.
    try:
        t, w, model = s.split(',')
        return t, float(w), model
    except ValueError:
        raise argparse.ArgumentTypeError("Models must be type,weight,model")
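
# Example invocation (paths and model names are illustrative): vocab_prefix
# is expanded to <prefix>.vocab / <prefix>.model, and each model argument is
# "type,weight,name", where type "txl" selects TransformerXL and anything
# else selects AWD_LSTM:
#
#   python predict.py data/sp lstm,0.6,awd_best txl,0.4,txl_best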

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("vocab_prefix")
    parser.add_argument("models", nargs="+", type=model)
    args = parser.parse_args()
    vocab = loadVocab(args.vocab_prefix + ".vocab")
    spm = sp.SentencePieceProcessor()
    spm.Load(args.vocab_prefix + ".model")
    learners = []
    for t, w, model_name in args.models:
        # A dummy one-token DataBunch: only the vocab matters here, since the
        # learner is used purely for prediction.
        db = fatext.data.TextLMDataBunch.from_ids(".", vocab, np.array([[0]]), np.array([[0]]))
        learner = fatext.learner.language_model_learner(
            db,
            fatext.models.AWD_LSTM if t != "txl" else fatext.models.TransformerXL,
            pretrained=False)
        learner.load(model_name)
        learners.append((w, learner))
    models = Models(vocab, learners)
    while True:
        try:
            text = input("> ").lower()
        except EOFError:
            break
        # "/..." lines adjust the sampling knobs; anything else is a prompt.
        if text.startswith("/n "):
            models.n = int(text.split(" ")[1])
            models.end = None
        elif text.startswith("/temp "):
            models.temperature = float(text.split(" ")[1])
        elif text.startswith("/repe "):
            models.repetition_penalty = float(text.split(" ")[1])
        elif text.startswith("/end "):
            models.end = text.split(" ")[1].replace("_", "▁")
            models.n = None
        elif text.startswith("/allow "):
            models.allowed = re.compile(text[text.index(" ") + 1:])
        else:
            tokens = spm.EncodeAsPieces(text)
            for i, token in enumerate(models.weightedPredict(tokens)):
                if token == "▁br":
                    print("")
                else:
                    # Alternate plain/underlined so adjacent pieces stay visible.
                    print("\x1b[" + ("0m" if i % 2 == 0 else "4m") + token.replace("▁", " "), end="")
                sys.stdout.flush()
            print("\x1b[0m")


if __name__ == "__main__":
    main()
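
# Illustrative session (the continuation text is made up; commands shown are
# the ones handled in main() above):
#
#   $ python predict.py data/sp lstm,1.0,awd_best
#   > /temp 0.9
#   > /end _xx
#   > once upon a time
#   once upon a time there was ...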