Skip to content

Commit

Permalink
fix typo input->probs
Browse files Browse the repository at this point in the history
  • Loading branch information
SinanGncgl committed Feb 1, 2025
1 parent b3dae6a commit b375b40
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,11 @@ def multinomial(

# For without replacement, use Gumbel-max trick
# This is much faster than using choice
z = jax.random.gumbel(prng_key, shape=input.shape + (num_samples,))
z = jax.random.gumbel(prng_key, shape=probs.shape + (num_samples,))

Check warning on line 535 in mithril/backends/with_autograd/jax_backend/backend.py

View check run for this annotation

Codecov / codecov/patch

mithril/backends/with_autograd/jax_backend/backend.py#L535

Added line #L535 was not covered by tests
# Add log probabilities for Gumbel-max trick,
z = z + logits[..., None]
# Get top k indices
samples = jax.numpy.argsort(-z, axis=input.ndim - 1)[..., :num_samples, 0]
samples = jax.numpy.argsort(-z, axis=probs.ndim - 1)[..., :num_samples, 0]

Check warning on line 539 in mithril/backends/with_autograd/jax_backend/backend.py

View check run for this annotation

Codecov / codecov/patch

mithril/backends/with_autograd/jax_backend/backend.py#L539

Added line #L539 was not covered by tests

return samples

Expand Down

0 comments on commit b375b40

Please sign in to comment.