Skip to content

Commit

Permalink
Fix average pooling in Torch backend (keras-team#372)
Browse files Browse the repository at this point in the history
* fix pooling

* fix failed test
  • Loading branch information
chenmoneygithub authored Jun 17, 2023
1 parent d953688 commit 76c2983
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 18 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
backend: [tensorflow, jax]
backend: [tensorflow, jax, torch]
name: Run tests
runs-on: ubuntu-latest
env:
Expand Down
49 changes: 40 additions & 9 deletions keras_core/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,20 +250,51 @@ def average_pool(
data_format = standardize_data_format(data_format)
if data_format == "channels_last":
inputs = _transpose_spatial_inputs(inputs)

padding_value = 0
if padding == "same":
# Torch does not natively support `"same"` padding, we need to manually
# apply the right amount of padding to `inputs`.
inputs = _apply_same_padding(
inputs, pool_size, strides, operation_type="pooling"
)
spatial_shape = inputs.shape[2:]
num_spatial_dims = len(spatial_shape)
padding_value = []
uneven_padding = []

for i in range(num_spatial_dims):
padding_size = _compute_padding_length(
spatial_shape[i], pool_size[i], strides[i]
)
# Torch only supports even padding on each dim, to replicate the
# behavior of "same" padding of `tf.keras` as much as possible,
# we need to pad evenly using the shorter padding.
padding_value.append(padding_size[0])
if padding_size[0] != padding_size[1]:
# Handle unequal padding.
# `torch.nn.pad` sets padding value in the reverse order.
uneven_padding = [0, 1] + uneven_padding
inputs = tnn.pad(inputs, uneven_padding)

if num_spatial_dims == 1:
outputs = tnn.avg_pool1d(inputs, kernel_size=pool_size, stride=strides)
outputs = tnn.avg_pool1d(
inputs,
kernel_size=pool_size,
stride=strides,
padding=padding_value,
count_include_pad=False,
)
elif num_spatial_dims == 2:
outputs = tnn.avg_pool2d(inputs, kernel_size=pool_size, stride=strides)
outputs = tnn.avg_pool2d(
inputs,
kernel_size=pool_size,
stride=strides,
padding=padding_value,
count_include_pad=False,
)
elif num_spatial_dims == 3:
outputs = tnn.avg_pool3d(inputs, kernel_size=pool_size, stride=strides)
outputs = tnn.avg_pool3d(
inputs,
kernel_size=pool_size,
stride=strides,
padding=padding_value,
count_include_pad=False,
)
else:
raise ValueError(
"Inputs to pooling op must have ndim=3, 4 or 5, "
Expand Down
5 changes: 5 additions & 0 deletions keras_core/layers/layer_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest

from keras_core import backend
from keras_core import layers
Expand Down Expand Up @@ -356,6 +357,10 @@ def call(self, x, training=False):
y = layer(x)
self.assertEqual(ops.min(y), 1)

@pytest.mark.skipIf(
backend.backend() == "torch",
reason="Torch backend has unimplemtned ops for mixed precision on CPU.",
)
def test_mixed_precision(self):
x = np.ones((4, 4))

Expand Down
95 changes: 92 additions & 3 deletions keras_core/layers/pooling/average_pooling_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_core import backend
from keras_core import layers
from keras_core import testing

Expand Down Expand Up @@ -108,7 +110,6 @@ def test_average_pooling3d(
class AveragePoolingCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
(2, 1, "valid", "channels_last"),
(2, 1, "same", "channels_first"),
((2,), (2,), "valid", "channels_last"),
)
def test_average_pooling1d(self, pool_size, strides, padding, data_format):
Expand All @@ -131,9 +132,39 @@ def test_average_pooling1d(self, pool_size, strides, padding, data_format):
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@parameterized.parameters(
(2, 1, "same", "channels_first"),
((2,), (2,), "same", "channels_last"),
)
@pytest.mark.skipif(
backend.backend() == "torch",
reason="Same padding in Torch backend produces different results.",
)
def test_average_pooling1d_same_padding(
self, pool_size, strides, padding, data_format
):
inputs = np.arange(24, dtype="float32").reshape((2, 3, 4))

layer = layers.AveragePooling1D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)
tf_keras_layer = tf.keras.layers.AveragePooling1D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)

outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@parameterized.parameters(
(2, 1, "valid", "channels_last"),
((2, 3), (2, 2), "same", "channels_last"),
((2, 3), (2, 2), "valid", "channels_last"),
)
def test_average_pooling2d(self, pool_size, strides, padding, data_format):
inputs = np.arange(16, dtype="float32").reshape((1, 4, 4, 1))
Expand All @@ -154,9 +185,37 @@ def test_average_pooling2d(self, pool_size, strides, padding, data_format):
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@parameterized.parameters(
(2, (2, 1), "same", "channels_last"),
((2, 3), (2, 2), "same", "channels_last"),
)
@pytest.mark.skipif(
backend.backend() == "torch",
reason="Same padding in Torch backend produces different results.",
)
def test_average_pooling2d_same_padding(
self, pool_size, strides, padding, data_format
):
inputs = np.arange(16, dtype="float32").reshape((1, 4, 4, 1))
layer = layers.AveragePooling2D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)
tf_keras_layer = tf.keras.layers.AveragePooling2D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)

outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@parameterized.parameters(
(2, 1, "valid", "channels_last"),
(2, 1, "same", "channels_first"),
((2, 3, 2), (2, 2, 1), "valid", "channels_last"),
)
def test_average_pooling3d(self, pool_size, strides, padding, data_format):
Expand All @@ -178,3 +237,33 @@ def test_average_pooling3d(self, pool_size, strides, padding, data_format):
outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)

@parameterized.parameters(
(2, 1, "same", "channels_first"),
((2, 3, 2), (2, 2, 1), "same", "channels_last"),
)
@pytest.mark.skipif(
backend.backend() == "torch",
reason="Same padding in Torch backend produces different results.",
)
def test_average_pooling3d_same_padding(
self, pool_size, strides, padding, data_format
):
inputs = np.arange(240, dtype="float32").reshape((2, 3, 4, 5, 2))

layer = layers.AveragePooling3D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)
tf_keras_layer = tf.keras.layers.AveragePooling3D(
pool_size=pool_size,
strides=strides,
padding=padding,
data_format=data_format,
)

outputs = layer(inputs)
expected = tf_keras_layer(inputs)
self.assertAllClose(outputs, expected)
21 changes: 16 additions & 5 deletions keras_core/operations/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,24 +725,35 @@ def test_max_pool(self):
tf.nn.max_pool2d(x, 2, (2, 1), padding="SAME"),
)

def test_average_pool(self):
def test_average_pool_valid_padding(self):
# Test 1D max pooling.
x = np.arange(120, dtype=float).reshape([2, 20, 3])
self.assertAllClose(
knn.average_pool(x, 2, 1, padding="valid"),
tf.nn.avg_pool1d(x, 2, 1, padding="VALID"),
)
self.assertAllClose(
knn.average_pool(x, 2, 2, padding="same"),
tf.nn.avg_pool1d(x, 2, 2, padding="SAME"),
)

# Test 2D max pooling.
x = np.arange(540, dtype=float).reshape([2, 10, 9, 3])
self.assertAllClose(
knn.average_pool(x, 2, 1, padding="valid"),
tf.nn.avg_pool2d(x, 2, 1, padding="VALID"),
)

@pytest.mark.skipif(
backend.backend() == "torch",
reason="Torch outputs differently from TF when using `same` padding.",
)
def test_average_pool_same_padding(self):
# Test 1D max pooling.
x = np.arange(120, dtype=float).reshape([2, 20, 3])
self.assertAllClose(
knn.average_pool(x, 2, 2, padding="same"),
tf.nn.avg_pool1d(x, 2, 2, padding="SAME"),
)

# Test 2D max pooling.
x = np.arange(540, dtype=float).reshape([2, 10, 9, 3])
self.assertAllClose(
knn.average_pool(x, 2, (2, 1), padding="same"),
tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"),
Expand Down

0 comments on commit 76c2983

Please sign in to comment.