From 5ec1d473616770cf04543c39f68d0b371843d962 Mon Sep 17 00:00:00 2001 From: jkshin94 <112931895+jkshin94@users.noreply.github.com> Date: Thu, 30 May 2024 02:03:28 +0900 Subject: [PATCH] Update elemwise_ops.py Zero out NaN values produced by the scaled (exp is not None) branch of _safe_rshift, and use torch.isnan to propagate NaN inputs in _quantize_elemwise_core, since comparing with == float("NaN") never matches. --- mx/elemwise_ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mx/elemwise_ops.py b/mx/elemwise_ops.py index 32d9e9b..fb9468b 100644 --- a/mx/elemwise_ops.py +++ b/mx/elemwise_ops.py @@ -41,7 +41,9 @@ def _safe_rshift(x, bits, exp): if exp is None: return x / (2**bits) else: - return x / (2**bits) * (2 ** exp) + out = x / (2**bits) * (2 ** exp) + out[torch.isnan(out)] = 0. + return out def _round_mantissa(A, bits, round, clamp=False): @@ -162,7 +164,7 @@ def _quantize_elemwise_core(A, bits, exp_bits, max_norm, round='nearest', if not custom_cuda: out[A == float("Inf")] = float("Inf") out[A == -float("Inf")] = -float("Inf") - out[A == float("NaN")] = float("NaN") + out[torch.isnan(A)] = float("NaN") if A_is_sparse: output = torch.sparse_coo_tensor(sparse_A.indices(), output,