-
Notifications
You must be signed in to change notification settings - Fork 60
/
Copy pathselflow_model.py
130 lines (105 loc) · 6.37 KB
/
selflow_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import
import tensorflow as tf
import numpy as np
import os
import sys
import time
import cv2
from six.moves import xrange
from scipy import misc, io
from tensorflow.contrib import slim
import matplotlib.pyplot as plt
from network import pyramid_processing
from datasets import BasicDataset
from utils import average_gradients, lrelu, occlusion, rgb_bgr
from data_augmentation import flow_resize
from flowlib import flow_to_color, write_flo
from warp import tf_warp
class SelFlowModel(object):
    """Training/inference wrapper for the SelFlow optical-flow network.

    On construction it records hyper-parameters and creates the on-disk
    layout ``save_dir/{checkpoints,sample}/<model_name>`` and
    ``save_dir/summary/{train,test}`` used by training and evaluation.
    """

    def __init__(self, batch_size=8, iter_steps=1000000, initial_learning_rate=1e-4, decay_steps=2e5,
                 decay_rate=0.5, is_scale=True, num_input_threads=4, buffer_size=5000,
                 beta1=0.9, num_gpus=1, save_checkpoint_interval=5000, write_summary_interval=200,
                 display_log_interval=50, allow_soft_placement=True, log_device_placement=False,
                 regularizer_scale=1e-4, cpu_device='/cpu:0', save_dir='KITTI', checkpoint_dir='checkpoints',
                 model_name='model', sample_dir='sample', summary_dir='summary', training_mode="no_distillation",
                 is_restore_model=False, restore_model='./models/KITTI/no_census_no_occlusion',
                 dataset_config=None, self_supervision_config=None):
        """Store hyper-parameters and create the output directory tree.

        Args:
            batch_size: total batch size; must be divisible by ``num_gpus``.
            num_gpus: number of GPUs to shard the batch across.
            save_dir: root directory for checkpoints, samples and summaries.
            dataset_config: dict of dataset options (``None`` -> ``{}``).
            self_supervision_config: dict of self-supervision options
                (``None`` -> ``{}``).
            (Remaining arguments are stored verbatim as attributes of the
            same name and consumed by the training/eval methods.)

        Raises:
            AssertionError: if ``batch_size`` is not divisible by ``num_gpus``.
        """
        self.batch_size = batch_size
        self.iter_steps = iter_steps
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.is_scale = is_scale
        self.num_input_threads = num_input_threads
        self.buffer_size = buffer_size
        self.beta1 = beta1
        self.num_gpus = num_gpus
        self.save_checkpoint_interval = save_checkpoint_interval
        self.write_summary_interval = write_summary_interval
        self.display_log_interval = display_log_interval
        self.allow_soft_placement = allow_soft_placement
        self.log_device_placement = log_device_placement
        self.regularizer_scale = regularizer_scale
        self.training_mode = training_mode
        self.is_restore_model = is_restore_model
        self.restore_model = restore_model
        # None-sentinel instead of mutable `{}` defaults: a shared default
        # dict would be aliased across every instance that omits the arg.
        self.dataset_config = {} if dataset_config is None else dataset_config
        self.self_supervision_config = {} if self_supervision_config is None else self_supervision_config

        # With a single GPU, variables can live on it directly; with several,
        # they are shared from the CPU (parameter-server style).
        self.shared_device = '/gpu:0' if self.num_gpus == 1 else cpu_device
        assert np.mod(batch_size, num_gpus) == 0
        self.batch_size_per_gpu = int(batch_size / np.maximum(num_gpus, 1))

        # Output layout.  '/'-joined paths preserved from the original code
        # (values are also used as TF summary paths elsewhere).
        self.save_dir = save_dir
        self.checkpoint_dir = '/'.join([self.save_dir, checkpoint_dir])
        self.model_name = model_name
        self.sample_dir = '/'.join([self.save_dir, sample_dir])
        self.summary_dir = '/'.join([self.save_dir, summary_dir])
        # Check-then-create kept (rather than makedirs(exist_ok=True)) for
        # the Python 2 compatibility this file still targets (six/xrange).
        for directory in [self.save_dir,
                          self.checkpoint_dir,
                          '/'.join([self.checkpoint_dir, self.model_name]),
                          self.sample_dir,
                          '/'.join([self.sample_dir, self.model_name]),
                          self.summary_dir,
                          '/'.join([self.summary_dir, 'train']),
                          '/'.join([self.summary_dir, 'test'])]:
            if not os.path.exists(directory):
                os.makedirs(directory)

    def test(self, restore_model, save_dir, is_normalize_img=True):
        """Run inference over the configured dataset and dump flow results.

        For each triple of frames, writes the forward/backward flow both as
        color-coded PNGs and as raw ``.flo`` files into ``save_dir``.

        Args:
            restore_model: checkpoint path passed to ``tf.train.Saver.restore``.
            save_dir: output directory (created if missing).
            is_normalize_img: forwarded to ``BasicDataset``; presumably
                controls input image normalization — confirm in datasets.py.
        """
        dataset = BasicDataset(data_list_file=self.dataset_config['data_list_file'], img_dir=self.dataset_config['img_dir'], is_normalize_img=is_normalize_img)
        save_name_list = dataset.data_list[:, -1]
        iterator = dataset.create_one_shot_iterator(dataset.data_list, num_parallel_calls=self.num_input_threads)
        batch_img0, batch_img1, batch_img2 = iterator.get_next()

        # The pyramid network needs spatial dims divisible by 64: round the
        # input resolution up to the next multiple of 64 before inference.
        img_shape = tf.shape(batch_img0)
        h = img_shape[1]
        w = img_shape[2]
        new_h = tf.where(tf.equal(tf.mod(h, 64), 0), h, (tf.to_int32(tf.floor(h / 64) + 1)) * 64)
        new_w = tf.where(tf.equal(tf.mod(w, 64), 0), w, (tf.to_int32(tf.floor(w / 64) + 1)) * 64)
        batch_img0 = tf.image.resize_images(batch_img0, [new_h, new_w], method=1, align_corners=True)
        batch_img1 = tf.image.resize_images(batch_img1, [new_h, new_w], method=1, align_corners=True)
        batch_img2 = tf.image.resize_images(batch_img2, [new_h, new_w], method=1, align_corners=True)

        flow_fw, flow_bw = pyramid_processing(batch_img0, batch_img1, batch_img2, train=False, trainable=False, is_scale=True)
        # Resize the predicted flow back to the original resolution.
        flow_fw['full_res'] = flow_resize(flow_fw['full_res'], [h, w], method=1)
        flow_bw['full_res'] = flow_resize(flow_bw['full_res'], [h, w], method=1)

        flow_fw_color = flow_to_color(flow_fw['full_res'], mask=None, max_flow=256)
        flow_bw_color = flow_to_color(flow_bw['full_res'], mask=None, max_flow=256)

        restore_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        saver = tf.train.Saver(var_list=restore_vars)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # Context manager guarantees the session is released even if
        # inference or file writing raises (original leaked the session).
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(iterator.initializer)
            saver.restore(sess, restore_model)
            for i in range(dataset.data_num):
                np_flow_fw, np_flow_bw, np_flow_fw_color, np_flow_bw_color = sess.run([flow_fw['full_res'], flow_bw['full_res'], flow_fw_color, flow_bw_color])
                # NOTE(review): scipy.misc.imsave is deprecated and removed in
                # SciPy >= 1.2; pin scipy or migrate to imageio.imwrite.
                misc.imsave('%s/flow_fw_color_%s.png' % (save_dir, save_name_list[i]), np_flow_fw_color[0])
                misc.imsave('%s/flow_bw_color_%s.png' % (save_dir, save_name_list[i]), np_flow_bw_color[0])
                write_flo('%s/flow_fw_%s.flo' % (save_dir, save_name_list[i]), np_flow_fw[0])
                write_flo('%s/flow_bw_%s.flo' % (save_dir, save_name_list[i]), np_flow_bw[0])
                print('Finish %d/%d' % (i+1, dataset.data_num))