config.gin
import trax.layers
import trax.models
import trax.optimizers
import trax.data.inputs
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
train.model = @trax.models.ReformerLM
# Our model has 4 layers: three use local self-attention within a fixed
# context window, and the third layer uses the LSH attention proposed in the
# Reformer paper.
n_layers = 4
n_heads = 8
attn_type = [
    @trax.layers.SelfAttention,
    @trax.layers.SelfAttention,
    @LSHSelfAttention,
    @trax.layers.SelfAttention,
]
share_qk = False # LSH attention ignores this flag and always shares q & k
vocab_size = 32768
attn_kv = 128
dropout = 0.2
n_tokens = 2048
# Parameters for multifactor:
# ==============================================================================
multifactor.constant = 0.03125
multifactor.warmup_steps = 2000
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
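# With factors 'constant * linear_warmup * rsqrt_decay', the learning rate at
# step t is roughly 0.03125 * min(t / 2000, 1) / sqrt(max(t, 2000)):
# linear warmup over the first 2000 steps, then inverse-square-root decay.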
# Parameters for Adam:
# ==============================================================================
Adam.weight_decay_rate = 0.0
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9
# Parameters for SelfAttention:
# ==============================================================================
trax.layers.SelfAttention.attention_dropout = %dropout
trax.layers.SelfAttention.n_chunks_after = 0
trax.layers.SelfAttention.chunk_len = 256
trax.layers.SelfAttention.n_chunks_before = 1
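# With chunk_len = 256, n_chunks_before = 1 and n_chunks_after = 0, each token
# attends causally within its own 256-token chunk plus the preceding chunk,
# i.e. a lookback of roughly 256-512 tokens.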
# Parameters for LSHSelfAttention:
# ==============================================================================
LSHSelfAttention.attention_dropout = 0.0
LSHSelfAttention.chunk_len = 256
LSHSelfAttention.n_buckets = 512
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 4
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 256
LSHSelfAttention.predict_mem_len = 1024
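# LSH attention hashes queries/keys into n_buckets = 512 buckets per round,
# sorts by bucket, and attends within chunks of chunk_len = 256 plus one
# preceding chunk; the n_hashes = 4 rounds are then combined. predict_drop_len
# and predict_mem_len only affect the cached state in 'predict' (decoding) mode.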
# Parameters for ReformerLM:
# ==============================================================================
ReformerLM.vocab_size = %vocab_size
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 1024
ReformerLM.d_ff = 4096
ReformerLM.dropout = %dropout
ReformerLM.ff_activation = @trax.layers.Relu
ReformerLM.max_len = %n_tokens
ReformerLM.mode = 'train'
ReformerLM.n_heads = %n_heads
ReformerLM.n_layers = %n_layers
# Axial position embeddings: the product of axial_pos_shape should match
# max_len (32 * 64 = 2048), and the entries of d_axial_pos_embs must sum to
# d_model (256 + 768 = 1024).
ReformerLM.axial_pos_shape = (32, 64)
ReformerLM.d_axial_pos_embs = (256, 768)
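
Below is a minimal usage sketch (not part of config.gin) showing how this file
might be loaded with gin and used with Trax; the output directory is a
placeholder, and the training loop additionally expects a data pipeline
(train.inputs) to be configured elsewhere.

# usage_sketch.py -- assumes trax and gin-config are installed.
import gin
import trax

# Register the bindings above; gin substitutes the %macros and @references.
gin.parse_config_file('config.gin')

# Option 1: build the gin-configured model directly. vocab_size, n_layers,
# attention_type, etc. all come from the ReformerLM.* bindings.
model = trax.models.ReformerLM(mode='train')

# Option 2: run the gin-configured training loop, which picks up
# train.model = @trax.models.ReformerLM from this file.
trax.supervised.trainer_lib.train(output_dir='/tmp/reformer_lm')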