# finetuningcode.py
def fine_tune_model(model, tokenizer):
    import json
    import torch
    from torch.optim import AdamW
    from torch.utils.data import Dataset, DataLoader
    from transformers import get_linear_schedule_with_warmup

    # Load collected feedback; each record is expected to contain
    # 'context', 'question', 'answer', and 'overall_rating' fields.
    with open('feedback_data.json', 'r') as f:
        data = json.load(f)

    # Keep only high-quality examples (overall_rating >= 4).
    high_quality_data = [item for item in data if item['overall_rating'] >= 4]

    class FeedbackDataset(Dataset):
        def __init__(self, data):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            item = self.data[idx]
            input_text = f"Context: {item['context']}\nQuestion: {item['question']}\nAnswer: {item['answer']}"
            target_text = item['question']
            inputs = tokenizer(input_text, return_tensors="pt", padding="max_length",
                               truncation=True, max_length=512)
            targets = tokenizer(target_text, return_tensors="pt", padding="max_length",
                                truncation=True, max_length=128)
            labels = targets.input_ids.squeeze()
            # Mask padding positions so they are ignored by the loss.
            labels[labels == tokenizer.pad_token_id] = -100
            return {
                "input_ids": inputs.input_ids.squeeze(),
                "attention_mask": inputs.attention_mask.squeeze(),
                "labels": labels,
            }

    dataset = FeedbackDataset(high_quality_data)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    num_epochs = 3
    optimizer = AdamW(model.parameters(), lr=5e-5)
    # The schedule must span every optimizer step: batches per epoch * epochs.
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0,
        num_training_steps=len(dataloader) * num_epochs)

    device = next(model.parameters()).device
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            outputs = model(input_ids=batch["input_ids"].to(device),
                            attention_mask=batch["attention_mask"].to(device),
                            labels=batch["labels"].to(device))
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            scheduler.step()

    model.eval()
    return model


# Call this function periodically, or once a certain amount of new
# feedback has been collected:
# model = fine_tune_model(model, tokenizer)
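

# Usage sketch (illustrative assumptions, not part of the original file):
# the function above takes any model whose forward pass accepts
# input_ids/attention_mask/labels, so a seq2seq model such as T5 fits the
# input/target split used in FeedbackDataset. The checkpoint name and save
# path below are placeholders, and feedback_data.json is assumed to hold
# records shaped like:
#   {"context": "...", "question": "...", "answer": "...", "overall_rating": 5}
if __name__ == "__main__":
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    model = fine_tune_model(model, tokenizer)

    # Persist the updated weights alongside the tokenizer for later reloading.
    model.save_pretrained("finetuned_question_generator")
    tokenizer.save_pretrained("finetuned_question_generator")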