diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py
index 80c9cde2..59b44ef0 100644
--- a/mithril/backends/with_autograd/jax_backend/backend.py
+++ b/mithril/backends/with_autograd/jax_backend/backend.py
@@ -513,17 +513,17 @@ def multinomial(
             replacement: whether to sample with replacement
         """
         prng_key = self._get_prng_key(key)
-        input = jax.numpy.asarray(probs)
-        input = input / jax.numpy.sum(input, axis=-1, keepdims=True)
-        batch_size = input.shape[:-1]
-        logits = jax.numpy.log(jax.numpy.maximum(input, 1e-37))
+        probs = jnp.asarray(probs)
+        probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
+        # Zero probabilities intentionally map to -inf logits (never sampled).
+        logits = jnp.log(probs)
 
         if replacement:
             # Use categorical directly - much faster than choice
             samples = jax.random.categorical(
                 prng_key,
-                logits,  # avoid log(0)
-                shape=batch_size + (num_samples,),
+                logits,
+                shape=probs.shape[:-1] + (num_samples,),
             )
         else:
             # TODO: This algorithm is not efficient for small num_samples
@@ -531,11 +531,11 @@ def multinomial(
 
             # For without replacement, use Gumbel-max trick
             # This is much faster than using choice
-            z = jax.random.gumbel(prng_key, shape=input.shape + (num_samples,))
+            z = jax.random.gumbel(prng_key, shape=probs.shape + (num_samples,))
             # Add log probabilities for Gumbel-max trick,
             z = z + logits[..., None]
             # Get top k indices
-            samples = jax.numpy.argsort(-z, axis=input.ndim - 1)[..., :num_samples, 0]
+            samples = jnp.argsort(-z, axis=probs.ndim - 1)[..., :num_samples, 0]
 
         return samples
 
diff --git a/tests/scripts/test_backend_fns.py b/tests/scripts/test_backend_fns.py
index c2c55de7..56b668d8 100644
--- a/tests/scripts/test_backend_fns.py
+++ b/tests/scripts/test_backend_fns.py
@@ -2061,3 +2061,47 @@ def test_rand_uniform(self, backendcls, device, dtype):
         assert not backend.any(output < 0)
         assert not backend.any(output > 10)
         assert list(output.shape) == fn_args[2:]
+
+
+@pytest.mark.parametrize(
+    "backendcls, device, dtype", backends_with_device_dtype, ids=names
+)
+class TestMultinomial:
+    def test_multinomial_with_replacement(self, backendcls, device, dtype):
+        backend = backendcls(device=device, dtype=dtype)
+        fn = backend.multinomial
+        fn_args = [backend.array([0.1, 0.3, 0.6]), 5]
+
+        output = fn(*fn_args, replacement=True)
+
+        assert not backend.any(output < 0)
+        assert not backend.any(output >= len(fn_args[0]))
+        assert list(output.shape) == [fn_args[1]]
+
+    def test_multinomial_without_replacement(self, backendcls, device, dtype):
+        backend = backendcls(device=device, dtype=dtype)
+        fn = backend.multinomial
+        fn_args = [backend.array([0.2, 0.3, 0.5]), 3]
+
+        output = fn(*fn_args, replacement=False)
+
+        assert not backend.any(output < 0)
+        assert not backend.any(output >= len(fn_args[0]))
+        assert list(output.shape) == [fn_args[1]]
+        # Drawing every category without replacement must give a permutation.
+        assert sorted(output.tolist()) == [0, 1, 2]
+
+    def test_multinomial_one_hot_probabilities(self, backendcls, device, dtype):
+        backend = backendcls(device=device, dtype=dtype)
+        fn = backend.multinomial
+
+        num_classes, hot_index, num_samples = 1000, 500, 100
+        probs = [0.0] * num_classes
+        probs[hot_index] = 1.0
+        fn_args = [backend.array(probs), num_samples]
+
+        output = fn(*fn_args, replacement=True)
+
+        # Zero-probability categories must never be drawn.
+        assert list(output.shape) == [num_samples]
+        assert backend.all(output == hot_index)