Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent zero-probability events from being sampled in multinomial function in Jax-backend #177

Conversation

SinanGncgl
Copy link
Contributor

@SinanGncgl SinanGncgl commented Feb 1, 2025

Description

This PR fixes the issue #7 where zero-probability events could still be sampled due to the use of a small constant (1e-37) in the logits calculation. The fix replaces this constant with -inf to ensure such events are never selected and eliminate zero prob event.

What is Changed

Include the changes introduced in this PR.

  • Updated the multinomial implementation to handle zero-probability outcomes correctly.
  • Replaced 1e-37 with -inf using jnp.where for stable log calculations.

Checklist:

  • Tests that cover the code added.
  • Corresponding changes documented.
  • All tests passed.
  • The code linted and styled (pre-commit run --all-files has passed).

Copy link

codecov bot commented Feb 1, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 88.88%. Comparing base (5a07669) to head (3b14af7).
Report is 2 commits behind head on main.

@@            Coverage Diff             @@
##             main     #177      +/-   ##
==========================================
+ Coverage   88.71%   88.88%   +0.17%     
==========================================
  Files          61       61              
  Lines       15979    15968      -11     
==========================================
+ Hits        14175    14193      +18     
+ Misses       1804     1775      -29     
Files with missing lines Coverage Δ
...hril/backends/with_autograd/jax_backend/backend.py 88.33% <100.00%> (+3.56%) ⬆️

... and 7 files with indirect coverage changes

probs = probs / jnp.sum(probs, axis=-1, keepdims=True)

# Mask zero probabilities to avoid log(0) without adding small constants
logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using jnp.where in order to avoid numerical instabilities, I tried this

logits = jax.numpy.log(probs)

I tested the edge cases in simple ways and did not encounter any problems as far as I observed. I think we could also write it this way, as it is simpler.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the log function automatically returns -inf when the input is zero.

@github-actions github-actions bot added the tests label Feb 3, 2025
@SinanGncgl SinanGncgl force-pushed the fix-jax-backend-zero-probability-sampling branch from 2f33507 to d51dc93 Compare February 3, 2025 18:15
@ozankabak
Copy link
Contributor

Seems like we have a formatting problem, but it otherwise looks great. If you make sure the pre-commit procedure passes, we can proceed with the merge.

@ozankabak ozankabak merged commit e34d29a into synnada-ai:main Feb 4, 2025
9 checks passed
@ozankabak
Copy link
Contributor

Thank you @SinanGncgl 🚀

@SinanGncgl SinanGncgl deleted the fix-jax-backend-zero-probability-sampling branch February 4, 2025 19:09
mehmetozsoy-synnada pushed a commit to mehmetozsoy-synnada/mithril that referenced this pull request Feb 5, 2025
…l function in Jax-backend (synnada-ai#177)

Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants