-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackward.py
63 lines (51 loc) · 2.14 KB
/
backward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
import forward
import os
import ImgHandle as IMG
import random
# --- Training hyper-parameters ----------------------------------------------
BATCH_SIZE = 20        # samples fed to the graph per optimisation step
STEPS = 10000          # optimisation steps run by each call to backward()
REGULARIZER = 0.001    # passed to forward.forward(); presumably an L2 weight -- confirm
# Decay handed to tf.train.ExponentialMovingAverage.
# NOTE(review): 0.01 is far below the usual ~0.99; with it the shadow
# variables follow the raw weights almost exactly. Confirm this is intended.
MOVING_AVERAGE_DECAY = 0.01

# --- Checkpoint / data locations ---------------------------------------------
MODEL_SAVE_PATH = "./model/"   # directory checkpoints are restored from and saved to
MODEL_NAME = "train_model"     # checkpoint file-name prefix inside MODEL_SAVE_PATH
# Not referenced by the code visible here; presumably read by another module -- verify.
FILE_NAME = "Classification.xlsx"
def backward(data, label, learning_rate=0.001):
    """Train the network from ``forward`` on (data, label) and checkpoint it.

    Builds a TF1 graph:
      * placeholders ``x`` and ``y_``;
      * logits from ``forward.forward(x, REGULARIZER)``;
      * a loss equal to mean sparse softmax cross-entropy plus every tensor
        in the ``'losses'`` collection;
      * plain gradient descent;
      * an exponential moving average over all trainable variables.

    If a checkpoint exists in MODEL_SAVE_PATH it is restored first, so
    training resumes from it. The loop then runs STEPS mini-batch updates,
    cycling through ``data`` in order. Every 100 steps it prints the loss and
    saves a checkpoint, and it saves once more after the final step.

    Args:
        data: sliceable sequence of input rows, fed to a placeholder of width
            ``forward.INPUT_NODE``.
        label: sequence of label rows aligned with ``data``. ``argmax`` along
            axis 1 is used as the class index (presumably one-hot -- confirm
            against ImgHandle).
        learning_rate: gradient-descent step size. The default 0.001 is the
            value that used to be hard-coded.

    Raises:
        ValueError: if ``data`` is empty or ``data`` and ``label`` differ in
            length.
    """
    n_samples = len(data)
    if n_samples == 0:
        # Without this guard the batch offset below divides by zero.
        raise ValueError("backward() needs at least one training sample")
    if len(label) != n_samples:
        raise ValueError(
            "data and label lengths differ: %d vs %d" % (n_samples, len(label)))

    x = tf.placeholder(tf.float32, shape=(None, forward.INPUT_NODE))
    y_ = tf.placeholder(tf.float32, shape=(None, forward.OUTPUT_NODE))
    y = forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # The sparse op wants integer class ids; argmax recovers them from y_.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    # Regularization terms are expected in the 'losses' collection (presumably
    # added by forward.forward -- confirm; add_n fails if it is empty).
    loss = cem + tf.add_n(tf.get_collection('losses'))
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    # A single op that runs both the gradient step and the shadow update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    save_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
    # Saver.save refuses to write when the parent directory is missing.
    os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for i in range(STEPS):
            start = (i * BATCH_SIZE) % n_samples
            end = start + BATCH_SIZE  # slicing clamps, so a tail batch may be short
            _, loss_value, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={x: data[start:end], y_: label[start:end]})
            if i % 100 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, save_path, global_step=global_step)

        # The periodic save skips the trailing updates after the last multiple
        # of 100; persist the final weights so that work is not lost.
        saver.save(sess, save_path, global_step=global_step)
def main():
    """Load the image dataset, shuffle it, and train on it.

    Reads ``(data, label)`` from ``ImgHandle.img_handle()``. It applies one
    shared random permutation to both, so each sample keeps its label. It then
    prints the two lengths and passes the shuffled sequences to ``backward``.

    Raises:
        ValueError: if ``img_handle`` returns sequences of different lengths.
    """
    data, label = IMG.img_handle()
    if len(data) != len(label):
        raise ValueError(
            "img_handle returned %d samples but %d labels" % (len(data), len(label)))

    # random.shuffle is an unbiased Fisher-Yates shuffle; n random pair swaps
    # are not uniform. Rebuilding both sequences from one shuffled index list
    # also avoids the temp-swap idiom, which silently duplicates rows if the
    # sequences are NumPy arrays (row indexing returns a view, so the
    # temporary changes when the row is overwritten).
    order = list(range(len(data)))
    random.shuffle(order)
    data = [data[k] for k in order]
    label = [label[k] for k in order]

    print(len(data), len(label))
    backward(data, label)


if __name__ == "__main__":
    main()