diff --git a/tensorflow_addons/activations/tanhshrink.py b/tensorflow_addons/activations/tanhshrink.py
index be41d0e9ee..0c78b78074 100644
--- a/tensorflow_addons/activations/tanhshrink.py
+++ b/tensorflow_addons/activations/tanhshrink.py
@@ -38,3 +38,7 @@ def tanhshrink(x: types.TensorLike) -> tf.Tensor:
 @tf.RegisterGradient("Addons>Tanhshrink")
 def _tanhshrink_grad(op, grad):
     return _activation_so.ops.addons_tanhshrink_grad(grad, op.inputs[0])
+
+
+def _tanhshrink_py(x):
+    return x - tf.math.tanh(x)
diff --git a/tensorflow_addons/activations/tanhshrink_test.py b/tensorflow_addons/activations/tanhshrink_test.py
index 137b7449bf..37658fe1cf 100644
--- a/tensorflow_addons/activations/tanhshrink_test.py
+++ b/tensorflow_addons/activations/tanhshrink_test.py
@@ -18,6 +18,7 @@
 import numpy as np
 import tensorflow as tf
 from tensorflow_addons.activations import tanhshrink
+from tensorflow_addons.activations.tanhshrink import _tanhshrink_py
 from tensorflow_addons.utils import test_utils
 
 
@@ -26,22 +27,22 @@ class TanhshrinkTest(tf.test.TestCase, parameterized.TestCase):
     @parameterized.named_parameters(
         ("float16", np.float16), ("float32", np.float32), ("float64", np.float64)
     )
-    def test_tanhshrink(self, dtype):
-        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
-        expected_result = tf.constant(
-            [-1.0359724, -0.23840582, 0.0, 0.23840582, 1.0359724], dtype=dtype
-        )
-
-        self.assertAllCloseAccordingToType(tanhshrink(x), expected_result)
-
-    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
-    def test_theoretical_gradients(self, dtype):
-        # Only test theoretical gradients for float32 and float64
-        # because of the instability of float16 while computing jacobian
-        x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype)
-
-        theoretical, numerical = tf.test.compute_gradient(tanhshrink, [x])
-        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)
+    def test_same_as_py_func(self, dtype):
+        np.random.seed(1234)
+        for _ in range(20):
+            self.verify_funcs_are_equivalent(dtype)
+
+    def verify_funcs_are_equivalent(self, dtype):
+        x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
+        x = tf.convert_to_tensor(x_np)
+        with tf.GradientTape(persistent=True) as t:
+            t.watch(x)
+            y_native = tanhshrink(x)
+            y_py = _tanhshrink_py(x)
+        self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
+        grad_native = t.gradient(y_native, x)
+        grad_py = t.gradient(y_py, x)
+        self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)
 
 
 if __name__ == "__main__":
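
As a sanity check on the math behind `_tanhshrink_py` (a sketch outside the diff; `tanhshrink_grad_closed_form` is a hypothetical helper, not part of the PR): since tanhshrink(x) = x - tanh(x), its derivative is 1 - sech²(x) = tanh²(x), so autodiff through the pure-Python reference should match this closed form, which is also why testing gradients against `_tanhshrink_py` is equivalent to testing against the theoretical gradient.

```python
import numpy as np
import tensorflow as tf

def tanhshrink_grad_closed_form(x):
    # d/dx (x - tanh x) = 1 - sech^2(x) = tanh^2(x)
    return tf.math.tanh(x) ** 2

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x - tf.math.tanh(x)  # same body as _tanhshrink_py
dy_dx = tape.gradient(y, x)

# Autodiff gradient agrees with the closed form.
np.testing.assert_allclose(
    dy_dx.numpy(), tanhshrink_grad_closed_form(x).numpy(), atol=1e-6
)
```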