-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_learning_experiment.py
103 lines (78 loc) · 2.64 KB
/
test_learning_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
""" Test suite for the learning_experiment.py file """
import os
import tempfile

from learning_experiment import LearningExp
def test_LearningExp():
    """The per-phase subsets of the sample log add up to the whole log.

    Loads the sample log and checks that the row counts returned by the
    exposure, guessing, production, memorization-test and
    regularization-test getters sum to the total number of rows in
    ``e1.data``.
    """
    e1 = LearningExp.load("data/LearningExp_190501_S5_001_log.txt")
    n_rows = len(e1.data)
    phase_subsets = [
        e1.get_exposure_data(),
        e1.get_guessing_data(),
        # Production rows must come from their own getter; reusing
        # get_guessing_data() here would count guessing rows twice and
        # production rows never.
        # NOTE(review): assumes LearningExp exposes get_production_data(),
        # as the original variable name `all_production_data` implies.
        e1.get_production_data(),
        e1.get_memorization_test_data(),
        e1.get_regularization_test_data(),
    ]
    assert sum(len(d) for d in phase_subsets) == n_rows
def test_LearningExp_empty_like():
    """``LearningExp.empty_like`` copies the table layout but none of the rows.

    The empty copy must have the same column labels and index level names as
    the source, keep the source's ``info`` under the ``"_orig_info_"`` key,
    and contain zero rows.
    """
    source = LearningExp.load("data/LearningExp_190501_S5_001_log.txt")
    blank = LearningExp.empty_like(source)

    orig_table = source.data
    new_table = blank.data
    assert (new_table.columns == orig_table.columns).all()
    assert new_table.index.names == orig_table.index.names
    assert blank.info["_orig_info_"] == source.info
    assert len(new_table) == 0
def test_LearningExp_save_load():
    """Saving and reloading a ``LearningExp`` round-trips ``info`` and ``data``.

    The file is written into a fresh temporary directory rather than a fixed
    ``/tmp`` path. This keeps the test portable and free of collisions between
    concurrent runs, and nothing is left on disk afterwards.
    """
    e1 = LearningExp.load("data/LearningExp_190501_S5_001_log.txt")
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "TMP-learningexp-saveload-test.txt")
        e1.save(path)
        e1_loaded = LearningExp.load(path)

        # Compare while the file still exists, in case loading is lazy.
        assert e1_loaded.info == e1.info
        assert len(e1_loaded.data) == len(e1.data)
        for col in e1.data.columns:
            # NaN != NaN, so missing cells are filled with a sentinel before
            # comparing; otherwise any column with a gap would fail spuriously.
            assert (
                e1_loaded.data[col].fillna("XXX") == e1.data[col].fillna("XXX")
            ).all(), f"column {col!r} differs after save/load"
def test_LearningExp_append_results():
    """Appending dummy results to an empty log adds exactly the appended rows.

    Starting from ``LearningExp.empty_like`` of the sample log, this appends
    dummy memorization-test and regularization-test results twice each
    (first positional argument 100, then 101). It then checks that the log
    holds twice the memorization rows plus twice the regularization rows.
    """
    source = LearningExp.load("data/LearningExp_190501_S5_001_log.txt")
    log = LearningExp.empty_like(source)

    mem_data = source.get_memorization_test_data()
    reg_data = source.get_regularization_test_data()

    # Each entry: (phase name, input rows, dummy messages, correct_messages).
    # Regularization results are appended with correct_messages=None.
    batches = [
        (
            "MemorizationTest",
            mem_data,
            ["wuseldusel"] * len(mem_data),
            [0] * len(mem_data),
        ),
        (
            "RegularizationTest",
            reg_data,
            ["hupiflup"] * len(reg_data),
            None,
        ),
    ]

    # Call order: 100/Memorization, 100/Regularization,
    # 101/Memorization, 101/Regularization.
    for run in (100, 101):
        for phase, rows, messages, correct in batches:
            log.append_results(run, phase, rows, messages, correct_messages=correct)

    expected_rows = 2 * len(mem_data) + 2 * len(reg_data)
    assert len(log.data) == expected_rows