Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent zero-probability events from being sampled in multinomial function in Jax-backend #177

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,29 +513,28 @@ def multinomial(
replacement: whether to sample with replacement
"""
prng_key = self._get_prng_key(key)
input = jax.numpy.asarray(probs)
input = input / jax.numpy.sum(input, axis=-1, keepdims=True)
batch_size = input.shape[:-1]
logits = jax.numpy.log(jax.numpy.maximum(input, 1e-37))
probs = jnp.asarray(probs)
probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
logits = jnp.log(probs)

if replacement:
# Use categorical directly - much faster than choice
samples = jax.random.categorical(
prng_key,
logits, # avoid log(0)
shape=batch_size + (num_samples,),
shape=probs.shape[:-1] + (num_samples,),
)
else:
# TODO: This algorithm is not efficient for small num_samples
# consider more efficient algorithm

# For without replacement, use Gumbel-max trick
# This is much faster than using choice
z = jax.random.gumbel(prng_key, shape=input.shape + (num_samples,))
z = jax.random.gumbel(prng_key, shape=probs.shape + (num_samples,))
# Add log probabilities for Gumbel-max trick,
z = z + logits[..., None]
# Get top k indices
samples = jax.numpy.argsort(-z, axis=input.ndim - 1)[..., :num_samples, 0]
samples = jax.numpy.argsort(-z, axis=probs.ndim - 1)[..., :num_samples, 0]

return samples

Expand Down
39 changes: 39 additions & 0 deletions tests/scripts/test_backend_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,3 +2060,42 @@ def test_rand_uniform(self, backendcls, device, dtype):
assert not backend.any(output < 0)
assert not backend.any(output > 10)
assert list(output.shape) == fn_args[2:]


@pytest.mark.parametrize(
"backendcls, device, dtype", backends_with_device_dtype, ids=names
)
class TestMultinomial:
def test_multinomial_with_replacement(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial
fn_args = [backend.array([0.1, 0.3, 0.6]), 5]

output = fn(*fn_args, replacement=True)

assert not backend.any(output < 0)
assert not backend.any(output >= len(fn_args[0]))
assert list(output.shape) == [fn_args[1]]

def test_multinomial_without_replacement(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial
fn_args = [backend.array([0.2, 0.3, 0.5]), 3]

output = fn(*fn_args, replacement=False)

assert not backend.any(output < 0)
assert not backend.any(output >= len(fn_args[0]))
assert list(output.shape) == [fn_args[1]]

def test_multinomial_invalid_probabilities(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial

array_size = 1000
fn_args = [backend.array([0.0] * 500 + [1.0] + [0.0] * (array_size - 501)), 100]

output = fn(*fn_args, replacement=True)

assert list(output.shape) == [fn_args[1]]
assert backend.all(output == 500)