import RecurrentNeuralNetwork.recurrentNeuralNetwork as RNN
import Training.teacher as Teacher
import Logger.consoleLogger as Logger
import numpy as np
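
# Recurrent network built from the layer-size list [2, 4, 1]
# (presumably 2 inputs, 4 hidden units, 1 output).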
rnn = RNN.RecurrentNeuralNetwork([2,4,1])
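
# Training data: an 8-step input sequence of 2x1 column vectors and the
# matching sequence of 1x1 target outputs.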
seq = []
out = []
seq.append(np.array([[0],[1]]))
seq.append(np.array([[0],[1]]))
seq.append(np.array([[0],[0]]))
seq.append(np.array([[1],[0]]))
seq.append(np.array([[1],[0]]))
seq.append(np.array([[1],[1]]))
seq.append(np.array([[1],[1]]))
seq.append(np.array([[0],[1]]))
out.append(np.array([[0]]))
out.append(np.array([[1]]))
out.append(np.array([[0]]))
out.append(np.array([[1]]))
out.append(np.array([[0]]))
out.append(np.array([[1]]))
out.append(np.array([[1]]))
out.append(np.array([[0]]))
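
# Console logger for reporting progress; the same single sequence is passed
# as both the training and the validation set.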
l = Logger.ConsoleLogger(rnn,([seq],[out]),([seq],[out]))
T = Teacher.Teacher(rnn, l)
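# Plain gradient descent with factor -0.05; momentum and weight decay
# updates are available but left commented out.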
T.add_weight_update(-0.05, Teacher.gradient_descent)
#T.add_weight_update(0.1, Teacher.momentum)
#T.add_weight_update(0.0001, Teacher.weight_decay)

# Train for 100000 epochs, logging progress every 500 epochs.
for i in range(100000):
    T.train([seq], [out])
    if i % 500 == 0:
        l.log_training(i)

# Compare the network's predictions on the training sequence with the targets.
print(rnn.predict(seq))
print(out)