-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path helper.py
33 lines (27 loc) · 1.16 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
def text_to_token_ids(text, tokenizer):
    """Encode ``text`` into a (1, num_tokens) tensor of token IDs.

    The leading batch dimension of size 1 is added so the result can be fed
    straight into a batched model.
    """
    token_ids = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    return torch.tensor(token_ids).unsqueeze(0)
def token_ids_to_text(token_ids, tokenizer):
    """Decode a (1, num_tokens) tensor of token IDs back into a string."""
    ids = token_ids.squeeze(0).tolist()  # drop the batch dimension
    return tokenizer.decode(ids)
def generate_text_simple(model, idx, max_new_tokens, context_size):
    """Greedily extend ``idx`` by ``max_new_tokens`` tokens.

    Args:
        model: Callable mapping a (batch, seq) tensor of token IDs to
            logits of shape (batch, seq, vocab_size).
        idx: (batch, seq) tensor of starting token IDs.
        max_new_tokens: Number of tokens to append.
        context_size: Maximum context window; only the last ``context_size``
            tokens are fed to the model at each step.

    Returns:
        (batch, seq + max_new_tokens) tensor of token IDs.
    """
    for _ in range(max_new_tokens):
        # Crop the running sequence to the model's supported context window.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Only the last position's logits predict the next token.
        logits = logits[:, -1, :]
        # NOTE: the original applied softmax before argmax; softmax is
        # strictly monotonic, so argmax over raw logits selects the same
        # token — the redundant normalization is dropped.
        idx_next = torch.argmax(logits, dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx
def generate_and_print_sample(model, tokenizer, device, start_context):
    """Generate 50 greedy tokens from ``start_context`` and print them on one line.

    Temporarily switches the model to eval mode; restores train mode afterwards.
    """
    model.eval()
    # Maximum context length is read off the positional-embedding table size.
    max_context = model.pos_emb.weight.shape[0]
    idx = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        sampled = generate_text_simple(model, idx, 50, context_size=max_context)
    # Strip newlines so the sample prints as a single line.
    print(token_ids_to_text(sampled, tokenizer).replace("\n", ""))
    model.train()