Skip to content

Commit

Permalink
simplify log calculation and add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SinanGncgl committed Feb 3, 2025
1 parent b375b40 commit d51dc93
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
4 changes: 1 addition & 3 deletions mithril/backends/with_autograd/jax_backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,7 @@ def multinomial(
prng_key = self._get_prng_key(key)
probs = jnp.asarray(probs)
probs = probs / jnp.sum(probs, axis=-1, keepdims=True)

# Mask zero probabilities to avoid log(0) without adding small constants
logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)
logits = jnp.log(probs)

if replacement:
# Use categorical directly - much faster than choice
Expand Down
39 changes: 39 additions & 0 deletions tests/scripts/test_backend_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,3 +2060,42 @@ def test_rand_uniform(self, backendcls, device, dtype):
assert not backend.any(output < 0)
assert not backend.any(output > 10)
assert list(output.shape) == fn_args[2:]


@pytest.mark.parametrize(
"backendcls, device, dtype", backends_with_device_dtype, ids=names
)
class TestMultinomial:
def test_multinomial_with_replacement(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial
fn_args = [backend.array([0.1, 0.3, 0.6]), 5]

output = fn(*fn_args, replacement=True)

assert not backend.any(output < 0)
assert not backend.any(output >= len(fn_args[0]))
assert list(output.shape) == [fn_args[1]]

def test_multinomial_without_replacement(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial
fn_args = [backend.array([0.2, 0.3, 0.5]), 3]

output = fn(*fn_args, replacement=False)

assert not backend.any(output < 0)
assert not backend.any(output >= len(fn_args[0]))
assert list(output.shape) == [fn_args[1]]

def test_multinomial_invalid_probabilities(self, backendcls, device, dtype):
backend = backendcls(device=device, dtype=dtype)
fn = backend.multinomial

array_size = 1000
fn_args = [backend.array([0.0] * 500 + [1.0] + [0.0] * (array_size - 501)), 100] # Sample 100 times

output = fn(*fn_args, replacement=True)

assert list(output.shape) == [fn_args[1]]
assert backend.all(output == 500)

0 comments on commit d51dc93

Please sign in to comment.