-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
75 lines (61 loc) · 2.57 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from collections import namedtuple
import json
import argparse
class Config(object):
    """Configuration module.

    Loads a JSON configuration file and exposes its sections as attributes:

    - ``clients``: namedtuple (total, per_round, label_distribution,
      display_data_distribution) built from the ``clients`` section.
    - ``data``: namedtuple (loading, partition, IID, bias, shard,
      server_split) built from the ``data`` section.
    - ``loader``: ``'basic'``, ``'bias'`` or ``'shard'``, chosen by which
      one of ``data.IID`` / ``data.bias`` / ``data.shard`` is set.
    - ``fl``: namedtuple (rounds, target_accuracy, aggregation,
      prune_level_setter) built from the ``federated_learning`` section.
    - ``server``: the raw ``server`` section, unmodified.
    - ``lottery``: the raw ``lottery`` section; ``lottery_args`` exposes the
      same keys as attributes of an ``argparse.Namespace``.
    - ``paths``: namedtuple (data, model, reports) built from the ``paths``
      section, with ``'/' + lottery.model_name`` appended to ``model``.

    Fields missing from a section fall back to the defaults listed in
    ``extract``; a missing section raises ``KeyError``.
    """

    def __init__(self, config, log_level):
        """Read the JSON file at ``config`` and populate the section attributes.

        Args:
            config: Path to the JSON configuration file.
            log_level: Stored verbatim as ``self.log_level``; not used here.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the file is not valid JSON (``json.JSONDecodeError``)
                or the section values are inconsistent (see ``extract``).
            KeyError: If a required section is missing.
        """
        self.paths = ""
        self.config_path = config
        self.log_level = log_level

        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        # A distinct name avoids shadowing the ``config`` path argument.
        with open(config, 'r', encoding='utf-8') as config_file:
            self.config = json.load(config_file)

        self.extract()

    @staticmethod
    def _section(config, key, typename, fields, defaults):
        """Return a ``namedtuple(typename, fields)`` from ``config[key]``.

        Each field missing from the section takes the default at the same
        position in ``defaults``. Keys not listed in ``fields`` are ignored.
        """
        section = config[key]
        values = [section.get(field, default)
                  for field, default in zip(fields, defaults)]
        return namedtuple(typename, fields)(*values)

    def extract(self):
        """Populate the section attributes from the already-loaded ``self.config``.

        Raises:
            KeyError: If a required section is missing.
            ValueError: If ``clients.per_round`` exceeds ``clients.total``, or
                if not exactly one of ``data.IID``, ``data.bias`` and
                ``data.shard`` is set.
        """
        config = self.config

        # -- Clients --
        self.clients = self._section(
            config, 'clients', 'clients',
            ['total', 'per_round', 'label_distribution',
             'display_data_distribution'],
            (0, 0, 'uniform', False))
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if self.clients.per_round > self.clients.total:
            raise ValueError(
                "clients.per_round (%r) must not exceed clients.total (%r)"
                % (self.clients.per_round, self.clients.total))

        # -- Data --
        self.data = self._section(
            config, 'data', 'data',
            ['loading', 'partition', 'IID', 'bias', 'shard', 'server_split'],
            ('static', 0, True, None, None, 0.1))

        # Determine correct data loader. Exactly one partitioning scheme may
        # be active: a three-way XOR would also accept all three being set.
        data = self.data
        active = [bool(data.IID), bool(data.bias), bool(data.shard)]
        if sum(active) != 1:
            raise ValueError(
                "Exactly one of data.IID, data.bias and data.shard must be set "
                "(got IID=%r, bias=%r, shard=%r)"
                % (data.IID, data.bias, data.shard))
        if data.IID:
            self.loader = 'basic'
        elif data.bias:
            self.loader = 'bias'
        else:
            self.loader = 'shard'

        # -- Federated learning --
        self.fl = self._section(
            config, 'federated_learning', 'fl',
            ['rounds', 'target_accuracy', 'aggregation', 'prune_level_setter'],
            (0, None, "normal", "greedy"))

        # -- Server --
        self.server = config['server']
        self.lottery = config['lottery']

        # -- Lottery --
        # Expose the lottery settings as attributes (``args.model_name``, ...).
        self.lottery_args = argparse.Namespace(**self.lottery)

        # -- Paths --
        paths = self._section(
            config, 'paths', 'paths',
            ['data', 'model', 'reports'],
            ('./data', './models', None))
        # Set specific model path: <model dir>/<lottery model_name>.
        self.paths = paths._replace(
            model=paths.model + '/' + self.lottery_args.model_name)