Added python implementation for GeLU (#1154)
* average callback and optimizers

* fixed links

* added SWA

* removed update, changed alias, 4 spaces

* added _gelu_py, test

* Delete average_optimizers_callback.ipynb
abhichou4 authored Feb 26, 2020
1 parent d9858a3 commit e926ecb
Showing 2 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow_addons/activations/gelu.py
@@ -14,6 +14,7 @@
# ==============================================================================

import tensorflow as tf
import math

from tensorflow_addons.utils import types
from tensorflow_addons.utils.resource_loader import LazySO
@@ -49,3 +50,13 @@ def _gelu_grad(op, grad):
    return _activation_so.ops.addons_gelu_grad(
        grad, op.inputs[0], op.get_attr("approximate")
    )


def _gelu_py(x: types.TensorLike, approximate: bool = True) -> tf.Tensor:
    x = tf.convert_to_tensor(x)
    if approximate:
        pi = tf.cast(math.pi, x.dtype)
        coeff = tf.cast(0.044715, x.dtype)
        return 0.5 * x * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
    else:
        return 0.5 * x * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
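For reference, the exact GELU is gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and the tanh branch above is the standard approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). A minimal sketch of calling the pure-Python fallback directly and comparing the two forms (assumes TF 2.x eager execution; _gelu_py is a private helper, so this is for illustration only, not part of the commit):

import numpy as np
import tensorflow as tf
from tensorflow_addons.activations.gelu import _gelu_py

x = tf.constant([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=tf.float32)
y_approx = _gelu_py(x, approximate=True)   # tanh approximation
y_exact = _gelu_py(x, approximate=False)   # erf-based exact form
# The two forms agree closely for moderate inputs (roughly 1e-3 here).
print(np.max(np.abs(y_approx.numpy() - y_exact.numpy())))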
20 changes: 20 additions & 0 deletions tensorflow_addons/activations/gelu_test.py
@@ -18,6 +18,7 @@
import numpy as np
import tensorflow as tf
from tensorflow_addons.activations import gelu
from tensorflow_addons.activations.gelu import _gelu_py
from tensorflow_addons.utils import test_utils


@@ -51,6 +52,25 @@ def test_theoretical_gradients(self, dtype):
        )
        self.assertAllCloseAccordingToType(theoretical, numerical, atol=1e-4)

    @parameterized.named_parameters(("float32", np.float32), ("float64", np.float64))
    def test_same_as_py_func(self, dtype):
        np.random.seed(100)
        for _ in range(20):
            self.verify_funcs_are_equivalent(dtype)

    def verify_funcs_are_equivalent(self, dtype):
        x_np = np.random.uniform(-10, 10, size=(4, 4)).astype(dtype)
        x = tf.convert_to_tensor(x_np)
        for approximate in [True, False]:
            with tf.GradientTape(persistent=True) as t:
                t.watch(x)
                y_native = gelu(x, approximate=approximate)
                y_py = _gelu_py(x, approximate=approximate)
            self.assertAllCloseAccordingToType(y_native, y_py, atol=1e-4)
            grad_native = t.gradient(y_native, x)
            grad_py = t.gradient(y_py, x)
            self.assertAllCloseAccordingToType(grad_native, grad_py, atol=1e-4)


if __name__ == "__main__":
tf.test.main()
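For intuition on what the gradient checks verify: with the exact form, gelu(x) = x * Phi(x) for the standard normal CDF Phi, so its derivative is Phi(x) + x * phi(x), where phi is the normal density. A minimal standalone sketch confirming this against tf.GradientTape (assumes TF 2.x eager execution; illustrative only, not part of the commit):

import math
import numpy as np
import tensorflow as tf
from tensorflow_addons.activations.gelu import _gelu_py

x = tf.constant(np.linspace(-3.0, 3.0, 7), dtype=tf.float64)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = _gelu_py(x, approximate=False)
grad = tape.gradient(y, x)

# Closed-form derivative of the exact GELU: cdf(x) + x * pdf(x).
pdf = tf.exp(-0.5 * x ** 2) / math.sqrt(2.0 * math.pi)
cdf = 0.5 * (1.0 + tf.math.erf(x / math.sqrt(2.0)))
expected = cdf + x * pdf
print(np.max(np.abs(grad.numpy() - expected.numpy())))  # ~1e-7 in float64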
