Skip to content

Commit

Permalink
try to make it numerically stable
Browse files Browse the repository at this point in the history
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
  • Loading branch information
hongpeng-guo committed Feb 6, 2025
1 parent 6162e88 commit 02fd778
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/liger_kernel/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def liger_cross_entropy_kernel(

softmax_X = tl.exp(X_block - m) / d
# Cumulate the sum of softmax(x_i) * x_i
sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0))
sum_p_x += tl.sum(tl.where(X_offsets < n_cols, tl.math.fma(softmax_X, X_block, 0.0), 0.0))

entropy_loss = lse - sum_p_x

Expand Down Expand Up @@ -209,7 +209,8 @@ def liger_cross_entropy_kernel(
softmax_X = tl.exp(X_block - m) / d
if RETURN_ENTROPY_LOSS:
# derivatives of the entropy loss term
dX_entropy_block = tl.where(X_offsets < n_cols, softmax_X * (m - X_block + tl.log(d) - entropy_loss), 0.0)
log_softmax_X_plus_entropy = X_block - m - tl.log(d) + entropy_loss
dX_entropy_block = tl.math.fma(softmax_X, -log_softmax_X_plus_entropy, 0.0)
# Note that the weight is only applied to ce loss, not for entropy loss.
if reduction == "mean":
dX_entropy_block = dX_entropy_block / n_non_ignore
Expand Down

0 comments on commit 02fd778

Please sign in to comment.