-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathlaunch.py
72 lines (64 loc) · 2.33 KB
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from dataclasses import asdict
from datetime import datetime
import argparse
import os
import equinox as eqx
import gcsfs
import jax
import json
import wandb
from jax.experimental.multihost_utils import sync_global_devices
from src.train import train
# Command-line interface: a required --config name, an optional --rundir
# override, and two boolean switches.
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, type=str)
parser.add_argument("--rundir", type=str)
for _switch in ("--debug", "--multihost"):
    parser.add_argument(_switch, action="store_true")
cmd_args = parser.parse_args()

# Multi-host runs bring up JAX's distributed runtime first, before the
# process index is queried further down.
if cmd_args.multihost:
    jax.distributed.initialize()
# Look the config up by name: import the `src.configs` package (passing the
# requested name through `fromlist`), fetch the attribute called
# `cmd_args.config` from it, and use that object's `config` attribute.
_configs_pkg = __import__("src.configs", fromlist=[cmd_args.config])
config = getattr(_configs_pkg, cmd_args.config).config
# Decide where this run writes its outputs.
#   * An explicit --rundir always wins.
#   * Otherwise, outside debug mode, use a timestamped directory under
#     "outputs" derived from the local clock.
#   * In debug mode without --rundir, config.rundir is left untouched.
if cmd_args.rundir is not None:
    config.rundir = cmd_args.rundir
elif not cmd_args.debug:
    if cmd_args.multihost:
        # Every host would compute its own timestamp here. Report this as a
        # usage error instead of an `assert`, which `python -O` strips.
        parser.error("Multihost must prespecify rundir.")
    config.rundir = os.path.join(
        "outputs", datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    )

if cmd_args.debug:
    config.debug = True
if jax.process_index() == 0:  # Wandb and config setup
    # In debug mode nothing is read from or written to the run directory,
    # so wandb_id stays None when it reaches wandb.init below.
    wandb_id = None
    config_dict = asdict(config)
    if not cmd_args.debug:
        print(f"Writing to {config.rundir}")
        # Choose filesystem primitives by rundir scheme so the steps below
        # read the same for GCS and local runs.
        if config.rundir.startswith("gs://"):
            print("Using GCS filesystem")
            fs = gcsfs.GCSFileSystem()
            fopen, exists = fs.open, fs.exists
        else:
            print("Using local filesystem")
            # NOTE(review): config_dict was captured above, before this
            # normalization, so any rundir recorded in it is the
            # un-normalized value -- confirm this is intended.
            config.rundir = os.path.abspath(config.rundir)
            fs, fopen, exists = os, open, os.path.exists
        # make sure the directory exists
        fs.makedirs(config.rundir, exist_ok=True)
        # write config as json
        with fopen(os.path.join(config.rundir, "config.json"), "w") as f:
            f.write(json.dumps(config_dict))

        # Load wandb id or write it, for proper wandb resuming.
        wandb_id_path = os.path.join(config.rundir, "wandb_id.txt")
        if exists(wandb_id_path):
            with fopen(wandb_id_path, "r") as f:
                # Tolerate stray whitespace, and treat an empty file (e.g.
                # left by an interrupted earlier write) as "no id stored".
                wandb_id = f.read().strip() or None
        if wandb_id is None:
            # No usable id on disk: create one and persist it so a later
            # launch with the same rundir reuses it.
            wandb_id = wandb.util.generate_id()
            with fopen(wandb_id_path, "w") as f:
                f.write(wandb_id)
    wandb.init(project="midgpt", id=wandb_id, resume="allow", config=config_dict)
# Multi-host only: synchronize all hosts at a named barrier before
# continuing, so none proceeds past the process-0 setup section early.
if cmd_args.multihost:
    sync_global_devices("end_wandb_init")

# Pretty-print the final configuration, then hand it to the training entry
# point.
eqx.tree_pprint(config)
train(config)