From d51dc93255c71f87628c840128d5a638d3b6bfbe Mon Sep 17 00:00:00 2001 From: Sinan Gencoglu Date: Mon, 3 Feb 2025 19:12:37 +0100 Subject: [PATCH] simplify log calculation and add unit tests --- .../with_autograd/jax_backend/backend.py | 4 +- tests/scripts/test_backend_fns.py | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index 116eb54a..ced6d9db 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -515,9 +515,7 @@ def multinomial( prng_key = self._get_prng_key(key) probs = jnp.asarray(probs) probs = probs / jnp.sum(probs, axis=-1, keepdims=True) - - # Mask zero probabilities to avoid log(0) without adding small constants - logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf) + logits = jnp.log(probs) if replacement: # Use categorical directly - much faster than choice diff --git a/tests/scripts/test_backend_fns.py b/tests/scripts/test_backend_fns.py index 2f6997a4..5e248f78 100644 --- a/tests/scripts/test_backend_fns.py +++ b/tests/scripts/test_backend_fns.py @@ -2060,3 +2060,42 @@ def test_rand_uniform(self, backendcls, device, dtype): assert not backend.any(output < 0) assert not backend.any(output > 10) assert list(output.shape) == fn_args[2:] + + +@pytest.mark.parametrize( + "backendcls, device, dtype", backends_with_device_dtype, ids=names +) +class TestMultinomial: + def test_multinomial_with_replacement(self, backendcls, device, dtype): + backend = backendcls(device=device, dtype=dtype) + fn = backend.multinomial + fn_args = [backend.array([0.1, 0.3, 0.6]), 5] + + output = fn(*fn_args, replacement=True) + + assert not backend.any(output < 0) + assert not backend.any(output >= len(fn_args[0])) + assert list(output.shape) == [fn_args[1]] + + def test_multinomial_without_replacement(self, backendcls, device, dtype): + backend = backendcls(device=device, dtype=dtype) + fn = backend.multinomial + fn_args = [backend.array([0.2, 0.3, 0.5]), 3] + + output = fn(*fn_args, replacement=False) + + assert not backend.any(output < 0) + assert not backend.any(output >= len(fn_args[0])) + assert list(output.shape) == [fn_args[1]] + + def test_multinomial_invalid_probabilities(self, backendcls, device, dtype): + backend = backendcls(device=device, dtype=dtype) + fn = backend.multinomial + + array_size = 1000 + fn_args = [backend.array([0.0] * 500 + [1.0] + [0.0] * (array_size - 501)), 100] # Sample 100 times + + output = fn(*fn_args, replacement=True) + + assert list(output.shape) == [fn_args[1]] + assert backend.all(output == 500)