Skip to content

Commit

Permalink
fix: prevent zero-probability events from being sampled in multinomial function in Jax-backend (#177)
Browse files Browse the repository at this point in the history

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
  • Loading branch information
SinanGncgl and ozankabak authored Feb 4, 2025
1 parent d8dd269 commit e34d29a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
14 changes: 6 additions & 8 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,29 +513,27 @@ def multinomial(
replacement: whether to sample with replacement
"""
prng_key = self._get_prng_key(key)
input = jax.numpy.asarray(probs)
input = input / jax.numpy.sum(input, axis=-1, keepdims=True)
batch_size = input.shape[:-1]
logits = jax.numpy.log(jax.numpy.maximum(input, 1e-37))
probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
logits = jnp.log(probs)

if replacement:
# Use categorical directly - much faster than choice
samples = jax.random.categorical(
prng_key,
logits, # avoid log(0)
shape=batch_size + (num_samples,),
logits,
shape=probs.shape[:-1] + (num_samples,),
)
else:
# TODO: This algorithm is not efficient for small num_samples
# consider more efficient algorithm

# For without replacement, use Gumbel-max trick
# This is much faster than using choice
z = jax.random.gumbel(prng_key, shape=input.shape + (num_samples,))
z = jax.random.gumbel(prng_key, shape=probs.shape + (num_samples,))
# Add log probabilities for Gumbel-max trick,
z = z + logits[..., None]
# Get top k indices
samples = jax.numpy.argsort(-z, axis=input.ndim - 1)[..., :num_samples, 0]
samples = jax.numpy.argsort(-z, axis=probs.ndim - 1)[..., :num_samples, 0]

return samples

Expand Down
39 changes: 39 additions & 0 deletions tests/scripts/test_backend_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,3 +2061,42 @@ def test_rand_uniform(self, backendcls, device, dtype):
assert not backend.any(output < 0)
assert not backend.any(output > 10)
assert list(output.shape) == fn_args[2:]


@pytest.mark.parametrize(
    "backendcls, device, dtype", backends_with_device_dtype, ids=names
)
class TestMultinomial:
    """Tests for ``backend.multinomial``.

    Each test runs once per (backendcls, device, dtype) combination supplied
    by the class-level parametrization.
    """

    def test_multinomial_with_replacement(self, backendcls, device, dtype):
        """Sampling with replacement returns in-range indices of the right count."""
        backend = backendcls(device=device, dtype=dtype)
        probs = backend.array([0.1, 0.3, 0.6])
        num_samples = 5

        output = backend.multinomial(probs, num_samples, replacement=True)

        # Every sampled index must be a valid position in ``probs``.
        assert not backend.any(output < 0)
        assert not backend.any(output >= len(probs))
        assert list(output.shape) == [num_samples]

    def test_multinomial_without_replacement(self, backendcls, device, dtype):
        """Sampling without replacement returns in-range indices of the right count."""
        backend = backendcls(device=device, dtype=dtype)
        probs = backend.array([0.2, 0.3, 0.5])
        num_samples = 3

        output = backend.multinomial(probs, num_samples, replacement=False)

        # Every sampled index must be a valid position in ``probs``.
        assert not backend.any(output < 0)
        assert not backend.any(output >= len(probs))
        assert list(output.shape) == [num_samples]

    def test_multinomial_invalid_probabilities(self, backendcls, device, dtype):
        """Zero-probability entries are never drawn; only the single 1.0 entry is."""
        backend = backendcls(device=device, dtype=dtype)

        array_size = 1000
        hot_index = 500
        num_samples = 100

        # All probability mass sits on one position; everything else is 0.0.
        weights = [0.0] * array_size
        weights[hot_index] = 1.0
        probs = backend.array(weights)

        output = backend.multinomial(probs, num_samples, replacement=True)

        assert list(output.shape) == [num_samples]
        assert backend.all(output == hot_index)

0 comments on commit e34d29a

Please sign in to comment.