-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathdataset.py
59 lines (54 loc) · 2.16 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import tensorflow as tf
from tensorflow.python.platform import gfile
import numpy as np
from PIL import Image
# Size (height, width) that input images are resized to in DataSet.csv_inputs.
IMAGE_HEIGHT = 228
IMAGE_WIDTH = 304
# Size (height, width) that depth maps are resized to in DataSet.csv_inputs.
TARGET_HEIGHT = 55
TARGET_WIDTH = 74
class DataSet:
    """Builds batched image / depth input tensors from a CSV listing."""

    def __init__(self, batch_size):
        """Store the number of examples each produced batch contains."""
        self.batch_size = batch_size

    def csv_inputs(self, csv_file_path):
        """Create the batched input tensors for the examples listed in a CSV.

        Every CSV line is parsed as two string columns: the path of a JPEG
        image and the path of its PNG depth annotation.

        Args:
            csv_file_path: Path of the CSV file to read lines from.

        Returns:
            A tuple ``(images, depths, invalid_depths)`` of batched tensors:
            the float32 images resized to IMAGE_HEIGHT x IMAGE_WIDTH, the
            float32 depth maps divided by 255 and resized to
            TARGET_HEIGHT x TARGET_WIDTH, and the element-wise sign of those
            resized depth maps.
        """
        path_queue = tf.train.string_input_producer([csv_file_path], shuffle=True)
        line_reader = tf.TextLineReader()
        _, csv_line = line_reader.read(path_queue)
        column_defaults = [["path"], ["annotation"]]
        image_path, depth_path = tf.decode_csv(csv_line, column_defaults)

        # Input: three-channel JPEG decoded and converted to float32.
        rgb = tf.cast(
            tf.image.decode_jpeg(tf.read_file(image_path), channels=3),
            tf.float32,
        )

        # Target: single-channel PNG converted to float32 and divided by 255.
        depth_map = tf.cast(
            tf.image.decode_png(tf.read_file(depth_path), channels=1),
            tf.float32,
        )
        depth_map = tf.div(depth_map, [255.0])

        # Bring both tensors to the fixed sizes declared at module level.
        rgb = tf.image.resize_images(rgb, (IMAGE_HEIGHT, IMAGE_WIDTH))
        depth_map = tf.image.resize_images(depth_map, (TARGET_HEIGHT, TARGET_WIDTH))

        # Sign of the resized depth: 0 where the depth is 0, 1 where positive.
        depth_sign = tf.sign(depth_map)

        # Assemble batches of the three tensors.
        queue_capacity = 50 + 3 * self.batch_size
        image_batch, depth_batch, sign_batch = tf.train.batch(
            [rgb, depth_map, depth_sign],
            batch_size=self.batch_size,
            num_threads=4,
            capacity=queue_capacity,
        )
        return image_batch, depth_batch, sign_batch
def output_predict(depths, images, output_dir):
    """Save every image and its depth map as PNG files in ``output_dir``.

    For the i-th (image, depth) pair two files are written:
    ``<output_dir>/<i as 5 digits>_org.png`` with the image, and
    ``<output_dir>/<i as 5 digits>.png`` with the first channel of the depth
    map as an 8-bit grayscale picture. The depth map is rescaled so that its
    maximum maps to 255. The directory is created when it does not exist.

    Args:
        depths: Iterable of arrays supporting ``transpose(2, 0, 1)``
            (presumably height x width x channel NumPy arrays -- confirm
            against the caller).
        images: Iterable of array-likes convertible to ``uint8`` image data,
            paired with ``depths`` by position.
        output_dir: Directory that receives the PNG files.
    """
    print("output predict into %s" % output_dir)
    if not gfile.Exists(output_dir):
        gfile.MakeDirs(output_dir)
    for index, (image, depth) in enumerate(zip(images, depths)):
        original_name = "%s/%05d_org.png" % (output_dir, index)
        Image.fromarray(np.uint8(image)).save(original_name)

        channels_first = depth.transpose(2, 0, 1)
        peak = np.max(channels_first)
        if peak > 0:
            # Normalise so the largest value becomes 255.
            scaled = (channels_first / peak) * 255.0
        else:
            # No positive value to normalise by; dividing by a zero or
            # negative maximum would produce meaningless intensities.
            scaled = channels_first * 255.0
        # Casting floats outside 0..255 to uint8 wraps around (and is
        # undefined for negatives), so clamp to the valid range first.
        scaled = np.clip(scaled, 0.0, 255.0)

        depth_name = "%s/%05d.png" % (output_dir, index)
        Image.fromarray(np.uint8(scaled[0]), mode="L").save(depth_name)