Skip to content

Commit

Permalink
make deriv stable
Browse files Browse the repository at this point in the history
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
  • Loading branch information
hongpeng-guo committed Feb 6, 2025
1 parent c1d36e6 commit b1053a3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/liger_kernel/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def liger_cross_entropy_kernel(
softmax_X = tl.exp(X_block - m) / d
if RETURN_ENTROPY_LOSS:
# derivatives of the entropy loss term
dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss)
dX_entropy_block = softmax_X * (m - X_block + tl.log(d) - entropy_loss)
# Note that the weight is only applied to ce loss, not for entropy loss.
if reduction == "mean":
dX_entropy_block = dX_entropy_block / n_non_ignore
Expand Down

0 comments on commit b1053a3

Please sign in to comment.