# The following class constructs a partially input convex neural network (PICNN). It takes its inputs as a data tuple of the following three items: a tensor of shape [batch_size, n] whose n entries
# per row are the randomly generated components of the 'x' vector, a tensor of shape [batch_size, k] whose k entries per row are the components of the 'y' vector, and a tensor of shape
# [batch_size, 1] holding the output of the underlying function. The network's structure enforces convexity of the output with respect to each component of the y input vector. Its current
# configuration uses a softplus activation function for every layer, and its architecture is described in this paper:
# https://proceedings.mlr.press/v70/amos17b/amos17b.pdf
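#
# For reference, the PICNN recurrence from the paper above, which the forward() method below mirrors
# (u_0 = x, z_0 = 0, so the z-term drops out of the first layer; '*' denotes the Hadamard product and
# [.]_+ is an elementwise clamp at zero used by the paper on the z-scaling term):
#     u_{i+1} = g~_i( W~_i u_i + b~_i )
#     z_{i+1} = g_i( W_i^(z) (z_i * [W_i^(zu) u_i + b_i^(z)]_+) + W_i^(y) (y * (W_i^(yu) u_i + b_i^(y))) + W_i^(u) u_i + b_i )
#     f(x, y) = z_k
# Convexity in y holds when every W_i^(z) is elementwise non-negative and every g_i is convex and nondecreasing.
# In the class below, u_map_to_size_z and u_map_to_size_y play the roles of W^(zu) and W^(yu).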
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import numpy as np
'''The MPS device enables high-performance training on GPU for MacOS devices with Metal programming framework.
It introduces a new device to map Machine Learning computational graphs and primitives on highly efficient
Metal Performance Shaders Graph framework and tuned kernels provided by Metal Performance Shaders framework respectively.'''
# Check that MPS is available
# if not torch.backends.mps.is_available():
#     if not torch.backends.mps.is_built():
#         print("MPS not available because the current PyTorch install was not "
#               "built with MPS enabled.")
#     else:
#         print("MPS not available because the current MacOS version is not 12.3+ "
#               "and/or you do not have an MPS-enabled device on this machine.")
# else:
#     mps_device = torch.device("mps")
#     print("Running on the GPU")
class PICNN(nn.Module):
    def __init__(self, x_input_size, y_input_size, u_hl_dim, z_hl_dim, num_hidden_layers, output_size):
        super(PICNN, self).__init__()
        self.x_input_size = x_input_size
        self.y_input_size = y_input_size
        self.u_hl_dim = u_hl_dim
        self.z_hl_dim = z_hl_dim
        self.num_hidden_layers = num_hidden_layers
        self.output_size = output_size
        self.weight_initialization_description = ""  # TODO: describe the weight initialization scheme (need someone to assist here)
        # We create a simple forward linear flow through the z_i layers.
        # nn.Linear applies a linear transformation to incoming data: output = input @ weight.T + bias
        # Based on the parameterization given in the paper, we set bias=False for each of these layers.
        # Additionally, based on the paper, these layers (indices i=0, ..., k-1) take in a Hadamard product of z_i and the output of u_map_to_size_z (see below for this variable and its use in forward())
        z_flow = list()
        z_flow.append(nn.Linear(y_input_size, z_hl_dim, bias=False))
        for layer in range(num_hidden_layers - 1):
            z_flow.append(nn.Linear(z_hl_dim, z_hl_dim, bias=False))
        z_flow.append(nn.Linear(z_hl_dim, output_size, bias=False))
        self.z_traversal = nn.ModuleList(z_flow)
        self.relu_weights()  # Ensures that the z-flow weights are non-negative, a constraint from Proposition 1 of the ICNN paper
        # The connection between the input vector 'y' and each of the z_i for i=2, ..., k.
        # Based on the parameterization given in the paper, we set bias=False for each of these layers.
        # Additionally, based on the paper, these layers (indices i=0, ..., k-1) take in a Hadamard product of y and the output of u_map_to_size_y (see below for this variable and its use in forward())
        y_flow = list()
        for layer in range(num_hidden_layers - 1):
            y_flow.append(nn.Linear(y_input_size, z_hl_dim, bias=False))
        y_flow.append(nn.Linear(y_input_size, output_size, bias=False))
        self.y_traversal = nn.ModuleList(y_flow)
        # u_for_the_zlayer_flow is a list of linear layers representing the forward linear connections between u_i and z_{i+1} for all i <= k-1, where, by convention, u_0 = x
        u_for_the_zlayer_flow = list()
        u_for_the_zlayer_flow.append(nn.Linear(x_input_size, z_hl_dim, bias=True))
        for layer in range(num_hidden_layers - 1):
            u_for_the_zlayer_flow.append(nn.Linear(u_hl_dim, z_hl_dim, bias=True))
        u_for_the_zlayer_flow.append(nn.Linear(u_hl_dim, output_size, bias=True))
        self.u_for_the_zlayer_traversal = nn.ModuleList(u_for_the_zlayer_flow)
        # u_map_to_size_y is from the paper and constructed for the Hadamard product with y
        u_map_to_size_y = list()
        u_map_to_size_y.append(nn.Linear(x_input_size, y_input_size, bias=True))
        for layer in range(num_hidden_layers):
            u_map_to_size_y.append(nn.Linear(u_hl_dim, y_input_size, bias=True))
        self.u_map_to_size_y_traversal = nn.ModuleList(u_map_to_size_y)
        # u_map_to_size_z is from the paper and constructed for the Hadamard product with z_i
        # (the first entry maps to y_input_size because the first z-layer acts on y, since z_0 = 0)
        u_map_to_size_z = list()
        u_map_to_size_z.append(nn.Linear(x_input_size, y_input_size, bias=True))
        for layer in range(num_hidden_layers):
            u_map_to_size_z.append(nn.Linear(u_hl_dim, z_hl_dim, bias=True))
        self.u_map_to_size_z_traversal = nn.ModuleList(u_map_to_size_z)
        # u_flow is simply the forward flow of the u's, the upper track that runs independently of the z_traversal
        u_flow = list()
        u_flow.append(nn.Linear(x_input_size, u_hl_dim, bias=True))
        for layer in range(num_hidden_layers - 1):
            u_flow.append(nn.Linear(u_hl_dim, u_hl_dim, bias=True))
        u_flow.append(nn.Linear(u_hl_dim, 1, bias=False))  # Placeholder layer whose output is never used; it only keeps the zip() in forward() aligned
        self.u_traversal = nn.ModuleList(u_flow)
        self.activation_function = nn.Softplus()  # Softplus is convex and nondecreasing, and thus satisfies Proposition 1's criteria for network convexity
        # self.initialize_weights()
    def forward(self, x, y):  # The forward method follows the PICNN recurrence described in the paper
        # First layer: z_0 = 0, so only the y-Hadamard term and the u-term (with u_0 = x) contribute
        z = self.activation_function(self.z_traversal[0](y * self.u_map_to_size_z_traversal[0](x)) + self.u_for_the_zlayer_traversal[0](x))
        u = self.activation_function(self.u_traversal[0](x))
        for z_flow, y_flow, u_for_z_flow, u_map_to_y, u_map_to_z, u_flow in zip(self.z_traversal[1:], self.y_traversal[:], self.u_for_the_zlayer_traversal[1:], self.u_map_to_size_y_traversal[1:], self.u_map_to_size_z_traversal[1:], self.u_traversal[1:]):
            z = self.activation_function(z_flow(z * u_map_to_z(u)) + y_flow(y * u_map_to_y(u)) + u_for_z_flow(u))
            u = self.activation_function(u_flow(u))
        return z
    def relu_weights(self):
        # Clamp any negative z-flow weights to a tiny positive value so the convexity constraint holds
        for z_flow in self.z_traversal:
            with torch.no_grad():
                z_flow.weight[z_flow.weight < 0] = 0.0000001
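# A minimal usage sketch (illustrative only; the names _demo_net and _demo_out are not part of the
# original script): build a small PICNN and run one forward pass to confirm that the output has
# shape [batch_size, output_size].
_demo_net = PICNN(x_input_size=2, y_input_size=3, u_hl_dim=6, z_hl_dim=6, num_hidden_layers=4, output_size=1)
_demo_out = _demo_net(torch.rand(16, 2), torch.rand(16, 3))
assert _demo_out.shape == (16, 1)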
# picnn = PICNN(x_input_size, y_input_size, u_hl_dim, z_hl_dim, num_hidden_layers, output_size)
picnn = PICNN(2, 3, 6, 6, 4, 1)
num_of_batches = 4096
batch_size = 16
trainset = list()
x_dim = 2
y_dim = 3
for _ in range(num_of_batches):
    X = np.random.rand(batch_size, x_dim)
    Y = np.random.rand(batch_size, y_dim)
    # Zero-pad the narrower of X and Y so the two arrays can be added elementwise below
    if x_dim > y_dim:
        Y_new = np.append(Y, np.zeros([batch_size, x_dim - y_dim]), axis=1)
        X_new = X
    elif y_dim > x_dim:
        X_new = np.append(X, np.zeros([batch_size, y_dim - x_dim]), axis=1)
        Y_new = Y
    else:
        X_new = X
        Y_new = Y
    actual_output_without_noise = np.array([(X_new**2 + Y_new**2).sum(axis=1)]).T  # The squared norm of the x and y row vectors stitched together
    actual_output = actual_output_without_noise + np.random.randn(batch_size, 1) * 0.01  # Add a small amount of Gaussian noise to each target
    data_tuple = X, Y, actual_output
    trainset.append(data_tuple)
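# Quick sanity check on the last generated batch (illustrative only, not part of training): the
# noiseless target equals the squared Euclidean norm of the concatenated (x, y) row vector.
assert np.allclose(actual_output_without_noise[:, 0], (np.hstack([X, Y]) ** 2).sum(axis=1))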
import matplotlib.pyplot as plt
mse_loss = nn.MSELoss()
# The learning rate remains an open tuning question
optimizer = optim.Adam(picnn.parameters(), lr=0.0005)
batch_num = 0
batch_graph =[]
for data in trainset:
    optimizer.zero_grad()  # Reset gradients so each batch's backward pass starts from zero
    output = picnn.forward(torch.tensor(data[0]).float(), torch.tensor(data[1]).float())
    loss = mse_loss(output, torch.tensor(data[-1]).float())
    loss.backward()
    optimizer.step()
    picnn.relu_weights()  # Re-clamp the z-flow weights to non-negative values after each update
    # Record one point per batch: the x-value is the batch number and the y-value is the loss
    batch_num += 1
    batch_graph.append([batch_num, loss.item()])
# Plot the loss for each batch and display it
plt.figure(figsize=(9, 9))
batch_plot_X = [point[0] for point in batch_graph]
batch_plot_Y = [float(point[1]) for point in batch_graph]
plt.plot(batch_plot_X, batch_plot_Y, label='Loss over batches in Training Set')
plt.legend()
plt.show()
print(batch_plot_Y)
# Interface for hyperparameter tuning
from tabulate import tabulate
print('********** Test Number: ', np.random.randint(0,100000), ' **********\n\n')
print('batch size = ', batch_size, ' # of batches = ', num_of_batches, ' Total Training Examples = ', batch_size * num_of_batches)
print('Loss Function: ', mse_loss)
print('Optimizer: ', optimizer)
print('Weight Initialization: ', picnn.weight_initialization_description)
print('Network Design: ', picnn)
# Verify the convexity constraint: every z-flow weight must be non-negative
for z_flow in picnn.z_traversal:
    assert all(a >= 0 for row in z_flow.weight for a in row)
print('Network is indeed convex, its structure is as described in the paper')
# Below we generate the test data and the desired output of the network (the squared norm of the concatenated (x, y) vector)
test_batch_size = 1000
X = np.random.rand(test_batch_size, x_dim)
Y = np.random.rand(test_batch_size, y_dim)
# Zero-pad the narrower of X and Y, as in the training-data generation above
if x_dim > y_dim:
    Y_new = np.append(Y, np.zeros([test_batch_size, x_dim - y_dim]), axis=1)
    X_new = X
elif y_dim > x_dim:
    X_new = np.append(X, np.zeros([test_batch_size, y_dim - x_dim]), axis=1)
    Y_new = Y
else:
    X_new = X
    Y_new = Y
actual_test_output = np.array([(X_new**2 + Y_new**2).sum(axis=1)]).T  # The squared norm of the x and y row vectors stitched together
test_output = picnn.forward(torch.tensor(X).float(), torch.tensor(Y).float())
table = np.zeros([test_batch_size, x_dim + y_dim + 1])
table[:, :x_dim] = X
table[:, x_dim:x_dim + y_dim] = Y
table[:, -1] = test_output.detach().numpy().flatten()  # Reformat the network output so it can be placed in the table
header = ['x%d' % (i + 1) for i in range(x_dim)] + ['y%d' % (i + 1) for i in range(y_dim)] + ['network output']
tablelist = [header] + table.tolist()
print(tabulate(tablelist, headers="firstrow", tablefmt="fancy_grid"))
loss = mse_loss(test_output, torch.tensor(actual_test_output).float())
print('\nLoss is:', loss)
for z_flow in picnn.z_traversal:
    print(z_flow.weight)
# Monte Carlo estimate of E[||v||^2] for a vector v drawn uniformly from [0, 1]^dim
# (here dim = 5 = x_dim + y_dim, so this approximates the mean of the training targets)
sample_size = 100000
dim = 5
mat = np.random.rand(sample_size, dim)
norms = (mat ** 2).sum(axis=1)
expected_value = norms.sum() / sample_size
print(expected_value)
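# Analytic reference for the estimate above (added for comparison): for u ~ Uniform[0, 1],
# E[u^2] = 1/3, so E[||v||^2] = dim / 3 for v ~ Uniform[0, 1]^dim.
print(dim / 3)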
# picnn = PICNN(x_input_size, y_input_size, u_hl_dim, z_hl_dim, num_hidden_layers, output_size)
u = PICNN(1, 4, 1, 8, 4, 1)