-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
193 lines (157 loc) · 5.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import argparse

# Command-line interface for this fine-tuning script.
parser = argparse.ArgumentParser()
parser.add_argument("-t",
                    "--multi_task_names",
                    type=str,
                    nargs="+",
                    required=True,
                    help="which datasets to use")
parser.add_argument("-m",
                    "--multi_task_type",
                    type=str,
                    choices=["mix", "continual"],
                    # Fixed garbled help text; omitting this flag means a
                    # single-task run (enforced right after imports).
                    help="whether this is a multi-task baseline and, if so, which type it is (mix or continual)")
parser.add_argument("--check_point",
                    type=str,
                    help="path to a saved Trainer checkpoint to resume training from")
parser.add_argument("--continue_training",
                    action="store_true",
                    help="reload the previously fine-tuned model from the run's output directory instead of the base model")
args = parser.parse_args()
import os
import torch
import torch.nn as nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import get_peft_model, PeftModel
import re
import string
import json
import random
from datasets import load_dataset, Dataset
from tqdm import tqdm
from trl import SFTTrainer
# The wildcard imports below supply (at least) set_path_constants,
# generate_training_data, print_trainable_parameters, lora_config,
# TEST_SET_RATIO / VAL_SET_RATIO / RANDOM_SEED and hyperparameters,
# all used later in this script -- confirm against those modules.
from useful_functions import *
from eval_utils import *
from config import *
from prompt import prompt_dict
from postprocessors import postprocessor_dict
from load_model_and_tokenizer import load_model_and_tokenizer
# Pin the run to the first GPU.
# NOTE(review): this is set AFTER `import torch`; it only takes effect if
# CUDA has not yet been initialized -- consider moving it above the torch
# import to be safe.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
""" ====== CHECK INPUT ARGUMENTS ====== """
print("Check Input Arguments ...")
if args.multi_task_type == None:
if len(args.multi_task_names) > 1:
raise ValueError("Expected multi-task type (mix or continual) when having multiple task names")
""" ====== PATH ====== """
print("Set Path Constants ...")
(
TASK_DATA_PATHS,
FILTERED_DATASET_PATH,
OUTPUT_DIR,
TRAIN_SET_PATH,
TEST_SET_PATH,
VAL_SET_PATH,
ADAPTER_CONFIG_PATH,
EVAL_BEF_PATH,
RESULT_PATH
) = set_path_constants(args.multi_task_names, args.multi_task_type)
""" ====== Load Model and Tokenizer ====== """
print("Load Model and Tokenizer ...")
model_path = OUTPUT_DIR if args.continue_training else None
model, tokenizer = load_model_and_tokenizer(model_path)
""" ====== Tokenizer Demo ====== """
print("Tokenizer Demo ...")
""" ====== Load Dataset ====== """
print("Load Dataset and Tokenized ...")
with open(TRAIN_SET_PATH, "r", encoding = "utf-8") as f:
train_instances = json.load(f)
with open(VAL_SET_PATH, "r", encoding = "utf-8") as f:
val_instances = json.load(f)
train_data = [generate_training_data(tokenizer, data) for data in train_instances]
val_data = [generate_training_data(tokenizer, data) for data in val_instances[:0]]
print("train_instances sample", train_data[0])
check_input_ids = train_data[0]["input_ids"]
check_tokens = [tokenizer.convert_ids_to_tokens(_id_) for _id_ in check_input_ids]
print("tokens splitted", check_tokens)
from itertools import takewhile
print("user prompt token length", len([label for label in takewhile(lambda x: x == -100, train_data[0]["labels"])]))
print("id length", len([_id for _id in train_data[0]["input_ids"] if _id != 2]) + 1)
print("attention mask length", len([label for label in train_data[0]["attention_mask"] if label != 0]))
""" ====== Training tokens Calculate ====== """
print("Training tokens Calculate ...")
token_count = 0
for data in train_data:
for i in data["input_ids"]:
if i < 0:
break
token_count += 1
print("total fine tuning tokens number", token_count)
""" ====== Add Adapter layer (LoRA) ====== """
print("Add Adapter layer (LoRA) ...")
model = get_peft_model(model, lora_config)
print_trainable_parameters(model)
print(model.targeted_module_names)
""" ====== NaN solution ====== """
'''
print("NaN solution ...")
'''
""" ====== Training ====== """
print("Training ...")
warmup_steps = 0
num_train_epochs = 3
learning_rate = 1e-4
max_steps = 375
training_args = TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=warmup_steps,
num_train_epochs=num_train_epochs,
learning_rate=learning_rate,
fp16=True,
logging_steps=20,
output_dir=OUTPUT_DIR,
optim="paged_adamw_32bit",
# evaluation_strategy="steps",
# eval_steps=3000,
save_steps=500,
seed=RANDOM_SEED,
# max_steps=max_steps
)
trainer = Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=training_args,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
# Run fine-tuning, optionally resuming from a saved Trainer checkpoint
# (Trainer.train's first positional parameter is resume_from_checkpoint).
checkpoint = args.check_point
if checkpoint:
    trainer.train(checkpoint)
else:
    trainer.train()
# Persist the trained weights (for a PEFT-wrapped model this saves the
# adapter and its config) to the run's output directory.
model.save_pretrained(OUTPUT_DIR)
""" ====== Record Results ====== """
print("Record Results ...")
with open(ADAPTER_CONFIG_PATH, "r", encoding = "utf-8") as f:
adapter_config = json.load(f)
result = adapter_config
result["task_names"] = str(args.multi_task_names)
result["TEST_SET_RATIO"] = TEST_SET_RATIO
result["VAL_SET_RATIO"] = VAL_SET_RATIO
result["RANDOM_SEED"] = RANDOM_SEED
result["warmup_steps"] = warmup_steps
result["num_train_epochs"] = num_train_epochs
result["max_steps"] = max_steps
result["learning_rate"] = learning_rate
result["trainable_parameters"] = print_trainable_parameters(model)
result["hyperparameters"] = hyperparameters
result["log_history"] = trainer.state.log_history
with open(RESULT_PATH, "w", encoding = "utf-8") as f:
json.dump(result, f, indent = 2, ensure_ascii = False)
print("Training Stage Finished!")