fix: prevent zero-probability events from being sampled in multinomial function in Jax-backend #177

Merged
Changes from 2 commits
15 changes: 8 additions & 7 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -513,29 +513,30 @@
replacement: whether to sample with replacement
"""
prng_key = self._get_prng_key(key)
- input = jax.numpy.asarray(probs)
- input = input / jax.numpy.sum(input, axis=-1, keepdims=True)
- batch_size = input.shape[:-1]
- logits = jax.numpy.log(jax.numpy.maximum(input, 1e-37))
+ probs = jnp.asarray(probs)
+ probs = probs / jnp.sum(probs, axis=-1, keepdims=True)

Codecov / codecov/patch warning: added lines #L516-L517 were not covered by tests.

+ # Mask zero probabilities to avoid log(0) without adding small constants
+ logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)

Codecov / codecov/patch warning: added line #L520 was not covered by tests.

Collaborator comment:

Instead of using jnp.where to avoid numerical instabilities, I tried this:

logits = jax.numpy.log(probs)

I ran some simple edge-case checks and did not observe any problems. I think we could write it this way instead, since it is simpler.

Contributor reply:

It seems the log function already returns -inf when the input is zero.
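The point made in this thread can be checked with a small sketch. The input vector below is a hypothetical example, not taken from the PR or its tests; it compares the old clamped logits, the masked `jnp.where` version, and the plain `jnp.log` suggested in review.

```python
import jax.numpy as jnp

# Hypothetical input (not from the PR): index 1 is a zero-probability event.
probs = jnp.array([0.5, 0.0, 0.5])

# Old approach from the diff: clamping gives the zero entry a finite logit.
clamped = jnp.log(jnp.maximum(probs, 1e-37))

# Masked approach from the diff: zero entries are mapped to -inf explicitly.
masked = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)

# Simplification suggested in review: log(0) already evaluates to -inf.
plain = jnp.log(probs)

print(clamped[1], masked[1], plain[1])
```

For non-negative inputs the masked and plain versions agree. They differ only for negative entries, where `jnp.log` yields NaN while the mask yields -inf; negative values are not valid probabilities in the first place.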


if replacement:
# Use categorical directly - much faster than choice
samples = jax.random.categorical(
prng_key,
logits, # avoid log(0)
- shape=batch_size + (num_samples,),
+ shape=probs.shape[:-1] + (num_samples,),
)
else:
# TODO: This algorithm is not efficient for small num_samples
# consider more efficient algorithm

# For without replacement, use Gumbel-max trick
# This is much faster than using choice
- z = jax.random.gumbel(prng_key, shape=input.shape + (num_samples,))
+ z = jax.random.gumbel(prng_key, shape=probs.shape + (num_samples,))

Codecov / codecov/patch warning: added line #L535 was not covered by tests.
# Add log probabilities for Gumbel-max trick,
z = z + logits[..., None]
# Get top k indices
- samples = jax.numpy.argsort(-z, axis=input.ndim - 1)[..., :num_samples, 0]
+ samples = jax.numpy.argsort(-z, axis=probs.ndim - 1)[..., :num_samples, 0]

Codecov / codecov/patch warning: added line #L539 was not covered by tests.

return samples
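
Since Codecov flags the changed lines as untested, a small standalone check of both sampling paths may be useful. This is a simplified sketch, not the backend method itself: the distribution, the sample counts, and the variable names are assumptions chosen for illustration, and the without-replacement branch uses a single Gumbel noise vector rather than the `(..., num_samples)` noise array in the diff. It uses a 1-D input only; the documentation of `jax.random.categorical` states that `shape` must be broadcast-compatible with the batch shape of `logits`, so batched inputs are not exercised here.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1-D distribution with one impossible event (index 1).
probs = jnp.array([0.2, 0.0, 0.8])
probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
logits = jnp.log(probs)  # the zero probability maps to -inf

key_cat, key_gumbel = jax.random.split(jax.random.PRNGKey(0))

# With replacement: a -inf logit is never selected while a finite one exists.
with_replacement = jax.random.categorical(key_cat, logits, shape=(1000,))

# Without replacement (Gumbel top-k): perturb each logit with Gumbel noise
# and keep the indices of the k largest perturbed values.
k = 2
z = logits + jax.random.gumbel(key_gumbel, shape=logits.shape)
without_replacement = jnp.argsort(-z)[:k]

print(jnp.unique(with_replacement))
print(without_replacement)
```

Note that if `k` exceeds the number of events with non-zero probability, the zero-probability indices necessarily fill the remaining slots, because there is nothing else left to draw.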
