generate_training_data.py
"""dimension annotation
n: step
b: batch
t: token position
d: gpt d_model
v: gpt vocab size
l: SAE n latent
k: topk
"""
import argparse
from pathlib import Path
from functools import partial
import torch
import numpy as np
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from tqdm import tqdm
from openwebtext import load_owt, sample
torch.set_grad_enabled(False)
torch.set_float32_matmul_precision("high")
seq_len = 64  # default sequence length for all experiments, per the paper
d_model = 768  # residual-stream width of GPT-2 small
seed = 64
rng = np.random.default_rng(seed=seed)  # explicit control of the random seed
data_dir = Path("data")
data_dir.mkdir(parents=True, exist_ok=True)
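
# Hook: at each step, keep the residual-stream activation at the final token
# position, mean-center and unit-normalize each (d_model,) vector, and write
# the resulting (batch, d_model) slice into row `step` of the memmap.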
def hook_fn_save_act(act_btd, hook, step, mmap_act_nbd):
    act_bd = act_btd[:, -1].detach().cpu().numpy()
    act_bd -= act_bd.mean(axis=-1, keepdims=True)
    act_bd /= np.linalg.norm(act_bd, axis=-1, keepdims=True)
    mmap_act_nbd[step] = act_bd

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_layer", type=int, default=8)
    parser.add_argument("--n_step", type=int, default=10_000)
    parser.add_argument("--batch_size", type=int, default=2048)
    args = parser.parse_args()

    ds = load_owt()
    gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)

    output_path = data_dir / f"act_nbd_layer_{args.target_layer}_n_{args.n_step}_bs_{args.batch_size}.bin"
    mmap_act_nbd = np.memmap(
        str(output_path),
        dtype=np.float32,
        mode="w+",
        shape=(args.n_step, args.batch_size, d_model),
    )
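    # np.memmap writes straight to disk, so the full activation tensor never
    # has to fit in RAM; at the defaults the file is
    # 10_000 * 2048 * 768 * 4 bytes ≈ 63 GB.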
    tpb = args.batch_size * seq_len  # tokens per batch
    total = args.n_step * tpb
    print(f"start data generation: {args.n_step=:,}, {tpb=:,}, {total=:,}")
    for i in tqdm(range(args.n_step), desc="generating data", unit="step"):
        hook_fn = partial(hook_fn_save_act, step=i, mmap_act_nbd=mmap_act_nbd)
        batch_bt = sample(ds, args.batch_size, rng=rng)
        gpt2.run_with_hooks(
            batch_bt,
            return_type=None,  # logits are discarded; the hook does the saving
            fwd_hooks=[(utils.get_act_name("resid_post", layer=args.target_layer), hook_fn)],
        )
    mmap_act_nbd.flush()  # ensure all buffered activations reach disk
    print(f"data saved to {output_path}")