-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils.py
83 lines (69 loc) · 2.22 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
import logging
import os
import shutil
import matplotlib.pyplot as plt
class Params:
    """Class that loads hyperparameters from a json file.

    Each top-level key of the JSON object is exposed as an instance
    attribute, so values can be read and assigned with dot access.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5
    ```
    """

    def __init__(self, json_path):
        """Create the instance from the parameters stored in ``json_path``."""
        self.__dict__.update(self._read_json(json_path))

    @staticmethod
    def _read_json(json_path):
        """Return the parsed contents of the JSON file at ``json_path``.

        The file is decoded explicitly as UTF-8 (the encoding JSON text is
        required to use) instead of the platform's locale default, which can
        differ (e.g. on Windows) and mis-decode non-ASCII values.
        """
        with open(json_path, encoding="utf-8") as f:
            return json.load(f)

    def save(self, json_path):
        """Write all current attributes to ``json_path`` as indented JSON."""
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Loads parameters from json file, overwriting keys that already exist."""
        self.__dict__.update(self._read_json(json_path))

    @property
    def dict(self):
        """Gives dict-like access to Params instance by params.dict['learning_rate']"""
        return self.__dict__
def image_grid(
    images,
    rows=None,
    cols=None,
    fill: bool = True,
    show_axes: bool = False,
    rgb: bool = True,
):
    """
    A util function for plotting a grid of images.

    Args:
        images: (N, H, W, 4) array of RGBA images
        rows: number of rows in the grid
        cols: number of columns in the grid
        fill: boolean indicating if the space between images should be filled
        show_axes: boolean indicating if the axes of the plots should be visible
        rgb: boolean, If True, only RGB channels are plotted.
            If False, only the alpha channel is plotted.

    Returns:
        None

    Raises:
        ValueError: if exactly one of ``rows`` and ``cols`` is given.
    """
    if (rows is None) != (cols is None):
        raise ValueError("Specify either both rows and cols or neither.")

    if rows is None:
        # Default layout: a single column, one image per row.
        rows = len(images)
        cols = 1

    gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
    # squeeze=False always yields a 2-D ndarray of Axes. With the default
    # squeeze=True a 1x1 grid returns a bare Axes object, which has no
    # .ravel() and made single-image grids crash.
    fig, axarr = plt.subplots(
        rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9), squeeze=False
    )
    bleed = 0
    fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))

    axes = axarr.ravel()
    for ax, im in zip(axes, images):
        if rgb:
            # only render RGB channels
            ax.imshow(im[..., :3])
        else:
            # only render Alpha channel
            ax.imshow(im[..., 3])

    if not show_axes:
        # Hide every cell, including ones left empty when the grid has more
        # cells than images; otherwise those still draw a framed blank axis.
        for ax in axes:
            ax.set_axis_off()