batchNormalizationBackward.m
function [dX,dBeta,dGamma] = batchNormalizationBackward(dZ, X, gamma, ...
epsilon, batchMean, invSqrtVarPlusEps, channelDim) %#ok<INUSL>
% Back-propagation through a batch normalization layer on the host
% NB: invSqrtVarPlusEps is 1./sqrt(var(X) + epsilon)
% Copyright 2016-2018 The MathWorks, Inc.
% We need to take statistics over all dimensions except the channel
% dimension (the third dimension for a 4-D array, the fourth dimension for a 5-D array)
reduceDims = [1:channelDim-1 channelDim+1:ndims(X)];
m = numel(X) ./ size(X, channelDim); % total number of elements in the batch per activation channel
Xnorm = (X - batchMean) .* invSqrtVarPlusEps;
% Get the gradient of the function w.r.t. the parameters beta and gamma.
dBeta = sum(dZ, reduceDims);
dGamma = sum(dZ .* Xnorm, reduceDims);
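% dBeta and dGamma keep singleton dimensions along reduceDims, so they
% broadcast correctly against X and dZ in the expressions below.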
% Now get the gradient of the function w.r.t. input (x)
% See Ioffe & Szegedy, "Batch Normalization: Accelerating Deep Network
% Training by Reducing Internal Covariate Shift" for details.
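% Combining the terms below, dX reduces to
%   gamma .* invSqrtVarPlusEps .* (dZ - mean(dZ) - Xnorm .* mean(dZ .* Xnorm)),
% where the means are taken over reduceDims.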
factor = gamma .* invSqrtVarPlusEps;
factorScaled = factor ./ m;
dMean = dBeta .* factorScaled;
dVar = dGamma .* factorScaled;
dX = dZ .* factor - Xnorm .* dVar - dMean;
end
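
% Example usage (an illustrative sketch; the sizes, values, and the use of
% biased variance below are assumptions, not taken from any particular network):
%
%   X          = randn(7, 7, 16, 8, 'single');   % H x W x C x N activations
%   dZ         = randn(size(X), 'single');       % gradient from the next layer
%   gamma      = ones(1, 1, 16, 'single');       % per-channel scale
%   epsilon    = 1e-5;
%   channelDim = 3;
%   batchMean  = mean(X, [1 2 4]);
%   invSqrtVarPlusEps = 1 ./ sqrt(var(X, 1, [1 2 4]) + epsilon);
%   [dX, dBeta, dGamma] = batchNormalizationBackward(dZ, X, gamma, ...
%       epsilon, batchMean, invSqrtVarPlusEps, channelDim);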