-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
29 lines (24 loc) · 785 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Run an example.
#
# Loads the dataset through data.Mnist, trains example2l.Net on it while
# cProfile is recording, then writes the profiling report (sorted by call
# count) to tests/profiler.txt.
import cProfile
import io
import os
import pstats

from ann import data, example2l

# Profiling covers everything between enable() and disable(): dataset
# loading as well as training.
profiler = cProfile.Profile()
profiler.enable()

print('Reading dataset')
#uncomment to use sql database
# X,Y,X_test,Y_test = data.Mnist(path='data').from_sql('data/oopd_train.db', 'data/oopd_test.db')
X, Y, X_test, Y_test = data.Mnist(path='data').data()
print("Finished reading the dataset")

print("Training started")
# NOTE(review): presumably input / hidden / output layer sizes — confirm
# against example2l.Net.
i_d, m_d, o_d = 784, 100, 10
act = 'relu'
# NOTE(review): presumably learning rate and number of epochs — confirm
# against example2l.Net.
lr, epoch = 0.00001, 1
net = example2l.Net(i_d, m_d, o_d, act, lr, epoch, X, Y, X_test, Y_test)
net.train(step_u=64, step=1000)
print("Finished training")

profiler.disable()

# Render the report into memory first, then write it out in one call.
s = io.StringIO()
stats = pstats.Stats(profiler, stream=s).sort_stats('ncalls')
stats.print_stats()

# Create the output directory if needed so the report is not lost after a
# full training run just because tests/ is missing.
report_path = 'tests/profiler.txt'
os.makedirs(os.path.dirname(report_path), exist_ok=True)
with open(report_path, 'w', encoding='utf-8') as f:
    f.write(s.getvalue())