-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathmain.py
106 lines (85 loc) · 3.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import dataclasses
import json
import pathlib
import time
from graph import Graph
from spsa import Param, SpsaParams, SpsaTuner
from cutechess import CutechessMan
import copy
def cutechess_from_config(config_path: str) -> CutechessMan:
    """Build a ``CutechessMan`` from the JSON file at ``config_path``.

    The parsed JSON value is unpacked as keyword arguments to the
    ``CutechessMan`` constructor, so the file must contain a JSON object
    whose keys are accepted by that constructor.
    """
    with open(config_path) as handle:
        settings = json.loads(handle.read())
    return CutechessMan(**settings)
def params_from_config(config_path: str) -> list[Param]:
    """Read tunable parameters from the JSON file at ``config_path``.

    The file is expected to hold a JSON object. Each key becomes the first
    positional argument of ``Param`` and the associated value (itself an
    object) is unpacked as the remaining keyword arguments. The returned
    list follows the key order of the parsed object.
    """
    with open(config_path) as handle:
        entries = json.load(handle)

    params = []
    for name, fields in entries.items():
        params.append(Param(name, **fields))
    return params
def spsa_from_config(config_path: str):
    """Build an ``SpsaParams`` from the JSON file at ``config_path``.

    The parsed JSON object is unpacked as keyword arguments to the
    ``SpsaParams`` constructor.
    """
    with open(config_path) as handle:
        raw = handle.read()
    return SpsaParams(**json.loads(raw))
def save_state(spsa: SpsaTuner):
    """Persist the tuner's progress to ``./tuner/state.json``.

    The written JSON object has three keys: ``"t"`` (``spsa.t``),
    ``"spsa_params"`` (``spsa.spsa`` converted with ``dataclasses.asdict``)
    and ``"uci_params"`` (each element of ``spsa.uci_params`` converted the
    same way).

    The previous state file is only replaced once the new one has been
    written completely, so a failure or interruption while saving does not
    destroy the last good checkpoint.
    """
    state_path = pathlib.Path("./tuner/state.json")

    # Serialize before touching the filesystem: if asdict() raises, the
    # existing state file must stay intact rather than be truncated.
    payload = {
        "t": spsa.t,
        "spsa_params": dataclasses.asdict(spsa.spsa),
        "uci_params": [dataclasses.asdict(uci_param)
                       for uci_param in spsa.uci_params],
    }

    # A fresh checkout may not have the directory yet.
    state_path.parent.mkdir(parents=True, exist_ok=True)

    # Write to a sibling temp file and swap it in, so readers never see a
    # half-written state.json.
    tmp_path = state_path.with_name(state_path.name + ".tmp")
    with open(tmp_path, "w") as tmp_file:
        json.dump(payload, tmp_file)
    tmp_path.replace(state_path)
def _print_progress(spsa, cutechess, elapsed, start_t):
    """Print iteration and game counters with average timings for this run.

    ``elapsed`` is the total time spent in ``spsa.step()`` since the run
    started and ``start_t`` is the value of ``spsa.t`` at that moment.
    """
    games_played = spsa.t - start_t
    # No game has completed in this run yet (e.g. the very first step was
    # interrupted): report 0.00 instead of dividing by zero.
    per_game = elapsed / games_played if games_played else 0.0
    print(
        f"iterations: {int(spsa.t / cutechess.games)} ({(per_game * cutechess.games):.2f}s per iter)")
    print(
        f"games: {spsa.t} ({per_game:.2f}s per game)")


def main():
    """Run the SPSA tuning loop until it is interrupted or fails.

    If ``./tuner/state.json`` exists, the parameters, SPSA settings and the
    counter ``t`` are restored from it; otherwise parameters come from
    ``config.json`` and SPSA settings from ``spsa.json``. The cutechess
    settings are always read from ``cutechess.json``.

    Each loop pass performs one ``spsa.step()``, refreshes ``graph.png``,
    saves the state whenever the iteration count is a multiple of
    ``cutechess.save_rate``, and prints progress. The loop has no exit
    condition of its own; on any exit (including Ctrl-C) the state is saved
    and a final summary is printed.
    """
    state_path = pathlib.Path("./tuner/state.json")
    t = 0
    if state_path.is_file():
        print("hey")
        with open(state_path) as state:
            state_dict = json.load(state)
        params = [Param(cfg["name"], cfg["value"], cfg["min_value"], cfg["max_value"], cfg["step"])
                  for cfg in state_dict["uci_params"]]
        spsa_params = SpsaParams(**state_dict["spsa_params"])
        t = state_dict["t"]
    else:
        params = params_from_config("config.json")
        spsa_params = spsa_from_config("spsa.json")

    cutechess = cutechess_from_config("cutechess.json")
    spsa = SpsaTuner(spsa_params, params, cutechess)
    spsa.t = t
    graph = Graph()

    # Total wall-clock time spent inside spsa.step() during this run.
    elapsed = 0.0
    start_t = t

    print("Initial state: ")
    for param in spsa.params:
        print(param)
    print()

    try:
        while True:
            start = time.time()
            spsa.step()
            elapsed += time.time() - start

            # Deep-copy so the graph keeps a snapshot rather than a
            # reference to the live parameter objects.
            graph.update(spsa.t, copy.deepcopy(spsa.params))
            graph.save("graph.png")

            if ((spsa.t / cutechess.games) % cutechess.save_rate) == 0:
                print("Saving state...")
                save_state(spsa)

            _print_progress(spsa, cutechess, elapsed, start_t)
            for param in spsa.params:
                print(param)
            print()
    finally:
        print("Saving state...")
        save_state(spsa)
        print("Final results: ")
        # Safe even if no step completed; see _print_progress.
        _print_progress(spsa, cutechess, elapsed, start_t)
        print("Final parameters: ")
        for param in spsa.params:
            print(param)
# Start the tuner only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()