fix: prevent zero-probability events from being sampled in multinomial function in Jax-backend #177
…l function in jax-backend
Codecov Report
All modified and coverable lines are covered by tests ✅

@@            Coverage Diff             @@
##             main     #177      +/-   ##
==========================================
+ Coverage   88.71%   88.88%   +0.17%
==========================================
  Files          61       61
  Lines       15979    15968      -11
==========================================
+ Hits        14175    14193      +18
+ Misses       1804     1775      -29
probs = probs / jnp.sum(probs, axis=-1, keepdims=True)

# Mask zero probabilities to avoid log(0) without adding small constants
logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)
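The masking step in this hunk can be tried in isolation. The following is a minimal sketch, not the project's actual multinomial function: the helper name `zero_safe_logits`, the example probabilities, and the use of `jax.random.categorical` as the sampler are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def zero_safe_logits(probs):
    # Hypothetical helper mirroring the two lines of the diff:
    # normalize, then map zero probabilities to -inf logits.
    probs = probs / jnp.sum(probs, axis=-1, keepdims=True)
    return jnp.where(probs > 0, jnp.log(probs), -jnp.inf)


# Index 1 has zero probability and should never be drawn.
probs = jnp.array([0.5, 0.0, 0.5])
logits = zero_safe_logits(probs)

key = jax.random.PRNGKey(0)
samples = jax.random.categorical(key, logits, shape=(1000,))

# Count how many draws landed on the zero-probability index.
zero_hits = int(jnp.sum(samples == 1))
```

Because a `-inf` logit can never win against a finite one, `zero_hits` stays at zero regardless of the seed, as long as at least one entry of `probs` is positive.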
Instead of using jnp.where to avoid numerical instabilities, I tried this:
logits = jax.numpy.log(probs)
I tested the edge cases in simple ways and did not encounter any problems, as far as I could observe. I think we could also write it this way, since it is simpler.
Seems like the log function automatically returns -inf when the input is zero.
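The observation in this thread can be checked directly. This is a small sketch with made-up input values; it only compares the forward values of the two formulations and says nothing about how they behave under differentiation.

```python
import jax.numpy as jnp

# Example probabilities with one exact zero (values are illustrative).
probs = jnp.array([0.25, 0.0, 0.75])

# Plain log: log(0) evaluates to -inf under IEEE floating-point rules.
plain = jnp.log(probs)

# Explicit masking, as in the reviewed line of the diff.
masked = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)

# The two results agree element-wise, including the -inf entry.
same = bool(jnp.array_equal(plain, masked))
```

For non-negative inputs the two forms give the same logits, which is why the simpler `jnp.log(probs)` was proposed as an alternative.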
Force-pushed from 2f33507 to d51dc93.
Seems like we have a formatting problem, but it otherwise looks great. If you make sure the pre-commit procedure passes, we can proceed with the merge.
Thank you @SinanGncgl 🚀
…l function in Jax-backend (synnada-ai#177) Co-authored-by: Mehmet Ozan Kabak <ozankabak@gmail.com>
Description
This PR fixes issue #7, where zero-probability events could still be sampled because a small constant (1e-37) was used in the logits calculation. The fix replaces this constant with -inf, so that such events are never selected.
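To illustrate the difference described here, the sketch below contrasts the two kinds of logits. It assumes the old code added 1e-37 to the probabilities before taking the log; the PR text does not show the previous implementation, so that detail and the example values are assumptions.

```python
import jax
import jax.numpy as jnp

probs = jnp.array([0.5, 0.0, 0.5])

# Assumed shape of the old approach: a tiny constant keeps the log finite,
# so the zero-probability event keeps a finite logit (about -85).
eps_logits = jnp.log(probs + 1e-37)

# Approach described in this PR: zero probability maps to a -inf logit.
inf_logits = jnp.where(probs > 0, jnp.log(probs), -jnp.inf)

# Probabilities implied by the masked logits: exactly zero at index 1.
implied = jax.nn.softmax(inf_logits)
```

A finite logit leaves the event reachable in principle, however unlikely, while a `-inf` logit removes it from the distribution entirely.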
What is Changed
Include the changes introduced in this PR.
Checklist: