# arguments.py
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import TrainingArguments

@dataclass
class BaseTrainingArguments:
    experiment_prefix: str = field(
        metadata={"help": "A unique 'name' of this experiment, used to store metadata on the DHT"}
    )
    initial_peers: List[str] = field(
        default_factory=list,
        metadata={
            "help": "Multiaddrs of the peers that will welcome you into the existing collaboration. "
            "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/udp/7777/quic/p2p/YYYY"
        },
    )
    use_ipfs: bool = field(
        default=False,
        metadata={
            "help": "Use IPFS to find initial_peers. If enabled, you only need to provide the /p2p/XXXX part of the "
            "multiaddrs for the initial_peers (no need to specify a particular IPv4/IPv6 address and port)"
        },
    )
    host_maddrs: List[str] = field(
        default_factory=lambda: ["/ip4/0.0.0.0/tcp/0"],
        metadata={
            "help": "Multiaddrs to listen for external connections from other p2p instances. "
            "Defaults to all IPv4 interfaces with the TCP protocol: /ip4/0.0.0.0/tcp/0"
        },
    )
    announce_maddrs: List[str] = field(
        default_factory=list,
        metadata={"help": "Visible multiaddrs the host announces for external connections from other p2p instances"},
    )

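# Sketch (an assumption, not part of the original file): in a typical training
# script, these fields are forwarded to a hivemind DHT node. Here the parsed
# dataclass is bound to a hypothetical `collaboration_args` variable (see the
# parser sketch at the end of this file):
#
#     import hivemind
#
#     dht = hivemind.DHT(
#         start=True,
#         initial_peers=collaboration_args.initial_peers,
#         use_ipfs=collaboration_args.use_ipfs,
#         host_maddrs=collaboration_args.host_maddrs,
#         announce_maddrs=collaboration_args.announce_maddrs,
#     )
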
@dataclass
class AveragerArguments:
    averaging_expiration: float = field(
        default=30.0, metadata={"help": "Averaging group will wait for stragglers for at most this many seconds"}
    )
    averaging_timeout: float = field(
        default=120.0, metadata={"help": "Give up on an averaging step after this many seconds"}
    )
    min_refresh_period: float = field(
        default=0.5, metadata={"help": "Wait for at least this many seconds before fetching new collaboration state"}
    )
    max_refresh_period: float = field(
        default=30.0, metadata={"help": "Wait for at most this many seconds before fetching new collaboration state"}
    )
    default_refresh_period: float = field(
        default=3.0, metadata={"help": "Attempt to fetch collaboration state this often (in seconds) until successful"}
    )
    expected_drift_peers: float = field(
        default=3.0, metadata={"help": "Trainer assumes that this many new peers can join per step"}
    )
    expected_drift_rate: float = field(
        default=0.2, metadata={"help": "Trainer assumes that this fraction of the current collaboration size can join per step"}
    )
    performance_ema_alpha: float = field(
        default=0.1, metadata={"help": "Alpha coefficient for the exponential moving average of samples per second"}
    )
    target_group_size: int = field(default=256, metadata={"help": "Maximum group size for all-reduce"})
    metadata_expiration: float = field(
        default=120.0, metadata={"help": "Peer's metadata will be removed if not updated in this many seconds"}
    )

@dataclass
class CollaborativeOptimizerArguments:
    target_batch_size: int = field(
        default=16384,
        metadata={"help": "Perform an optimizer step after all peers collectively accumulate this many samples"},
    )
    client_mode: bool = field(
        default=False,
        metadata={"help": "If True, runs training without incoming connections, in a firewall-compatible mode"},
    )
    batch_size_lead: int = field(
        default=0,
        metadata={"help": "Optional: begin looking for an averaging group this many samples before reaching target_batch_size"},
    )
    bandwidth: float = field(
        default=100.0,
        metadata={"help": "Available network bandwidth, in Mbps (used for load balancing in all-reduce)"},
    )

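# Sketch (an assumption, not part of the original file): together with
# AveragerArguments, these fields are typically unpacked into hivemind's
# collaborative optimizer; `opt`, `dht`, `batch_size_per_step`, and the parsed
# dataclasses are hypothetical names from the surrounding training script:
#
#     from dataclasses import asdict
#
#     collaborative_optimizer = hivemind.CollaborativeOptimizer(
#         opt=opt,
#         dht=dht,
#         prefix=collaboration_args.experiment_prefix,
#         target_batch_size=collaboration_args.target_batch_size,
#         batch_size_per_step=batch_size_per_step,
#         start=True,
#         **asdict(averager_args),
#     )
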
@dataclass
class CollaborationArguments(CollaborativeOptimizerArguments, BaseTrainingArguments):
    statistics_expiration: float = field(
        default=120.0, metadata={"help": "Statistics will be removed if not updated in this many seconds"}
    )
    backup_every_steps: int = field(
        default=10, metadata={"help": "In case of NaN, fall back to a backup saved once every this many global steps"}
    )

@dataclass
class DatasetArguments:
    tokenizer_path: Optional[str] = field(default="tokenizer/tokenizer", metadata={"help": "Path to the tokenizer"})
    config_path: Optional[str] = field(default="model.json", metadata={"help": "Path to the model config"})
    cache_dir: Optional[str] = field(default="cache", metadata={"help": "Path to the cache"})
    train_key: Optional[str] = field(default=None, metadata={"help": "Optional subset name of the training dataset"})

@dataclass
class AlbertTrainingArguments(TrainingArguments):
    dataloader_num_workers: int = 0  # for streaming compatibility
    per_device_train_batch_size: int = 4
    per_device_eval_batch_size: int = 4
    gradient_accumulation_steps: int = 1
    seq_length: int = 512
    pad_to_multiple_of: int = 8
    learning_rate: float = 0.005
    total_steps: int = 15625  # total number of collaborative SGD updates, used for the learning rate schedule
    warmup_steps: int = 3125
    adam_epsilon: float = 1e-6
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    clamp_value: float = 10000.0
    fp16: bool = False
    fp16_opt_level: str = "O2"
    do_train: bool = True
    logging_steps: int = 100
    max_steps: int = 10 ** 20  # effectively unlimited
    save_steps: int = 10 ** 20  # effectively unlimited
    save_total_limit: int = 2
    output_dir: str = "outputs"
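

if __name__ == "__main__":
    # Usage sketch (an assumption, not part of the original file): the `field`
    # metadata above maps directly onto transformers.HfArgumentParser, which
    # turns every dataclass field into a CLI flag with the given help text.
    from transformers import HfArgumentParser

    parser = HfArgumentParser(
        (AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments)
    )
    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
    print(f"Parsed arguments for experiment {collaboration_args.experiment_prefix}")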