-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathschema.py
51 lines (42 loc) · 1.42 KB
/
schema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from quinine import (tstring,
tinteger,
tfloat,
tboolean,
stdict,
tdict,
default,
required,
allowed,
nullable,
)
from funcy import merge
# Validation rules for the "model" section of a config (wrapped under the
# "model" key of `schema` below). Every key is marked mandatory by merging
# its type fragment with `required`; each entry gets its own merged dict.
model_schema = {
    # Integer-typed dimension settings.
    **{
        dim_key: merge(tinteger, required)
        for dim_key in (
            "input_dim",
            "psi_hidden_dim",
            "rho_hidden_dim",
            "psi_output_dim",
            "rho_output_dim",
            "xi_hidden_dim",
            "output_dim",
        )
    },
    # Scalar training settings: a float and an integer.
    "lr": merge(tfloat, required),
    "num_epochs": merge(tinteger, required),
    # List-valued settings; the nested "schema" rule constrains the type of
    # every element of the list.
    **{
        list_key: merge({"type": "list", "schema": element_rule}, required)
        for list_key, element_rule in (
            ("activations", tstring),
            ("batch_norms", tboolean),
            ("layer_norms", tboolean),
            ("dropouts", tfloat),
        )
    },
    "architecture": merge(tstring, required),
}
# Validation rules for the "data" section of a config (wrapped under the
# "data" key of `schema` below). All keys are mandatory.
data_schema = {
    # Integer-typed settings.
    **{
        int_key: merge(tinteger, required)
        for int_key in (
            "num_examples",
            "data_dim",
            "prompt_max_len_train",
            "prompt_max_len_test",
        )
    },
    # Boolean flag.
    "noise": merge(tboolean, required),
}
# Top-level config schema: two nested sections plus two mandatory
# string-valued entries.
schema = {
    # Nested sections, each wrapped with quinine's `stdict`.
    **{
        section: stdict(section_rules)
        for section, section_rules in (
            ("model", model_schema),
            ("data", data_schema),
        )
    },
    # Required string settings at the top level.
    **{
        str_key: merge(tstring, required)
        for str_key in ("out_dir", "wandb_entity")
    },
}