-
Notifications
You must be signed in to change notification settings - Fork 481
/
Copy pathrbm_minibatch.dml
161 lines (127 loc) · 5.92 KB
/
rbm_minibatch.dml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# THIS SCRIPT IMPLEMENTS A TRAINING ALGORITHM FOR RESTRICTED BOLTZMANN MACHINES
#
# The training algorithm uses one-step Contrastive Divergence as suggested in [Bengio, Y., Learning Deep
# Architectures for AI, Foundations and Trends in Machine Learning, Volume 2 Issue 1, January 2009, p. 1–127].
# This implementation is slightly modified and uses mini-batches to speed up the learning process.
#
# INPUT PARAMETERS:
# --------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# --------------------------------------------------------------------------------------------
# X String ---- Location (on HDFS) to read the matrix X of feature vectors
# W String ---- Location to store the estimated RBM weights
# A String ---- Location to store the visible units bias vector
# B String ---- Location to store the hidden units bias vector
# batchsize Int 100 Number of observations in a single batch
# vis Int ---- Number of units in the visible layer. If not specified this will be set
# to the number of features in X
# hid Int 2 Number of units in the hidden layer
# epochs Int 10 Number of training epochs
# alpha Double 0.1 Learning rate
# fmt String text Matrix output format for W,A, and B, usually "text" or "csv"
# --------------------------------------------------------------------------------------------
# OUTPUT: Matrices containing weights and biases for the trained network
#
# OUTPUT SIZE: OUTPUT CONTENTS:
# W vis x hid RBM weights
# A 1 x vis Visible units biases
# B 1 x hid Hidden units biases
#
# To run an observation through the trained network use the rbm_run.dml script.
# Alternatively, you can simply compute P(h=1|v) using 1.0 / (1 + exp(-(X %*% W + B))) and
# get the values for the hidden layer by sampling from P(h=1|v).
#
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemDS.jar -f rbm_minibatch.dml -nvargs X=$INPUT_DIR/X W=$OUTPUT_DIR/w A=$OUTPUT_DIR/a B=$OUTPUT_DIR/b batchsize=100 vis=10 hid=2 epochs=10 alpha=0.1
# -------------------------------------------- START OF RBM_MINIBATCH.DML --------------------------------------------
# Default values for the optional named arguments
learning_rate = ifdef($alpha, 0.1)
hidden_units_count = ifdef($hid, 2)
epoch_count = ifdef($epochs, 10)
batch_size = ifdef($batchsize, 100)
output_format = ifdef($fmt, "text")
print("BEGIN RBM MINIBATCH TRAINING SCRIPT")
print("Reading X...")
X = read($X)
# If number of visible units is not set,
# set it to the number of features in X
visible_units_count = ifdef ($vis, ncol(X))
# Get the total number of observations
N = nrow(X)
# Rescale to interval [0,1] so the data matches the range of the
# sigmoid visible units
minX = min(X)
maxX = max(X)
if (minX < 0.0 | maxX > 1.0) {
  print("X not in [0,1]. Rescaling...")
  rangeX = maxX - minX
  if (rangeX > 0.0) {
    X = (X - minX) / rangeX
  } else {
    # Constant-valued X: a plain (X-minX)/(maxX-minX) would divide by
    # zero and fill X with NaNs; shift all entries to 0 instead
    X = X - minX
  }
}
# Initialize the weight matrix (visible_units_count x hidden_units_count)
# with small uniform random values to break symmetry between hidden units
w = 0.1 * rand(rows = visible_units_count, cols = hidden_units_count, pdf = "uniform")
# Setup biases and initialize with 0
a = matrix(0, rows = 1, cols = visible_units_count) # Visible units bias
b = matrix(0, rows = 1, cols = hidden_units_count) # Hidden units bias
# Train the RBM with one-step Contrastive Divergence (CD-1) over mini-batches.
# Per batch: up-pass (sample hidden units), down-pass (reconstruct visible
# probabilities), second up-pass, then a gradient step on w, a, b.
print("Training starts...")
for (epoch in 1:epoch_count) {
  # Cumulative reconstruction error over all batches of this epoch
  epoch_err = 0
  # Loop over each batch; batch_start walks 1, 1+batch_size, 1+2*batch_size, ...
  for (batch_start in seq(1, N, batch_size)) {
    # Clamp the batch end so the final (possibly smaller) batch stays in range
    batch_end = min(batch_start + batch_size - 1, N)
    # Present the current minibatch to the visible layer
    v_1 = X[batch_start:batch_end,]
    # Actual number of observations in this batch (<= batch_size for the last one)
    batch_N = nrow(v_1)
    # POSITIVE PHASE
    # Compute the probabilities P(h=1|v) = sigmoid(v*w + b)
    p_h1_given_v = 1.0 / (1 + exp(-(v_1 %*% w + b)))
    # Sample binary hidden states from P(h=1|v)
    h1 = p_h1_given_v > rand(rows = batch_N, cols = hidden_units_count)
    # NEGATIVE PHASE
    # Reconstruct the visible layer: P(v2=1|h1) = sigmoid(h1*w' + a)
    p_v2_given_h1 = 1.0 / (1 + exp(-(h1 %*% t(w) + a)))
    # Second up-pass on the reconstruction: P(h2=1|v2)
    # (probabilities, not samples, are used in the updates below to reduce
    # sampling noise in the gradient estimate)
    p_h2_given_v2 = 1.0 / (1 + exp(-(p_v2_given_h1 %*% w + b)))
    # UPDATE WEIGHTS AND BIASES
    # CD-1 gradient: positive associations minus negative associations,
    # averaged over the batch
    w = w + learning_rate * ((t(v_1) %*% p_h1_given_v - t(p_v2_given_h1) %*% p_h2_given_v2) / batch_N)
    a = a + (learning_rate / batch_N) * (colSums(v_1) - colSums(p_v2_given_h1))
    b = b + (learning_rate / batch_N) * (colSums(p_h1_given_v) - colSums(p_h2_given_v2))
    # COMPUTE ERROR
    # Mean squared reconstruction error for this batch
    minibatch_err = sum((v_1 - p_v2_given_h1) ^ 2) / batch_N
    epoch_err = epoch_err + minibatch_err
  }
  print("Epoch " + epoch + " completed. Cumulative error is " + epoch_err)
}
# Save the RBM model: weights w (vis x hid), visible biases a (1 x vis)
# and hidden biases b (1 x hid), in the format chosen via $fmt
print("Saving RBM weights...")
write(w, $W, format=output_format)
print("Saving RBM biases...")
write(a, $A, format=output_format)
write(b, $B, format=output_format)
print("END OF RBM MINIBATCH TRAINING SCRIPT")
# -------------------------------------------- END OF RBM_MINIBATCH.DML --------------------------------------------