BatchNorm1dFlat
jph00 committed Dec 28, 2018
1 parent cc6673b commit fca888d
Showing 2 changed files with 13 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -22,6 +22,7 @@ of that change.
 - `Learner.to_fp32()` to go back to FP32 precision mode
 - `cont_cat_split` function to automatically get categorical/continuous variables (thanks to RealLankinen)
 - Lots of new metrics thanks to Sven Becker: `mse/mean_squared_error`, `mae/mean_absolute_error`, `rmse/root_mean_squared_error`, `msle/mean_squared_logarithmic_error`, `explained_variance`, `r2_score`, `top_k_accuracy`, `KappaScore`, `MatthewsCorreff`, `Precision`, `Recall`, `FBeta`
+- `BatchNorm1dFlat` for using batchnorm in sequence models (e.g. RNNs, and their inputs and outputs)

### Changed:

15 changes: 12 additions & 3 deletions fastai/layers.py
@@ -3,9 +3,9 @@

 __all__ = ['AdaptiveConcatPool2d', 'BCEWithLogitsFlat', 'BCEFlat', 'MSELossFlat', 'CrossEntropyFlat', 'Debugger',
            'Flatten', 'Lambda', 'PoolFlatten', 'ResizeBatch', 'bn_drop_lin', 'conv2d', 'conv2d_trans', 'conv_layer',
-           'embedding', 'simple_cnn', 'NormType', 'relu', 'batchnorm_2d', 'trunc_normal_',
-           'PixelShuffle_ICNR', 'icnr', 'NoopLoss', 'WassersteinLoss', 'SelfAttention',
-           'SequentialEx', 'MergeLayer', 'res_block', 'sigmoid_range', 'SigmoidRange', 'PartialLayer', 'FlattenedLoss']
+           'embedding', 'simple_cnn', 'NormType', 'relu', 'batchnorm_2d', 'trunc_normal_', 'PixelShuffle_ICNR', 'icnr',
+           'NoopLoss', 'WassersteinLoss', 'SelfAttention', 'SequentialEx', 'MergeLayer', 'res_block', 'sigmoid_range',
+           'SigmoidRange', 'PartialLayer', 'FlattenedLoss', 'BatchNorm1dFlat']

 class Lambda(nn.Module):
     "An easy way to create a pytorch layer for a simple `func`."
@@ -269,3 +269,12 @@ def embedding(ni:int,nf:int) -> nn.Module:
     # See https://arxiv.org/abs/1711.09160
     with torch.no_grad(): trunc_normal_(emb.weight, std=0.01)
     return emb
+
+class BatchNorm1dFlat(nn.BatchNorm1d):
+    "`nn.BatchNorm1d`, but first flattens leading dimensions"
+    def forward(self, x):
+        if x.dim()==2: return super().forward(x)    # 2D input: behave exactly like `nn.BatchNorm1d`
+        *f,l = x.shape                              # `f`: leading dims (e.g. batch, seq); `l`: feature dim
+        x = x.contiguous().view(-1,l)               # collapse leading dims into one batch dim
+        return super().forward(x).view(*f,l)        # normalize, then restore the original shape
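
For context (not part of the commit): a minimal usage sketch of `BatchNorm1dFlat` on a sequence tensor. Plain `nn.BatchNorm1d` treats 3D input as `(batch, channels, seq_len)`, while RNN activations typically come as `(batch, seq_len, features)`; flattening the leading dimensions sidesteps that mismatch. The shapes below are illustrative assumptions.

import torch
import torch.nn as nn
from fastai.layers import BatchNorm1dFlat  # as exported by this commit

bn = BatchNorm1dFlat(64)
rnn_out = torch.randn(32, 10, 64)   # (batch, seq_len, features), e.g. an RNN output
y = bn(rnn_out)                     # internally viewed as (320, 64), then reshaped back
print(y.shape)                      # torch.Size([32, 10, 64])

# Plain nn.BatchNorm1d would take dim 1 (here seq_len=10) as the channel
# dim for 3D input, so the same call raises a RuntimeError:
# nn.BatchNorm1d(64)(rnn_out)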
