
Change to a new way of calculating entropy
Signed-off-by: Hongpeng Guo <hpguo@anyscale.com>
hongpeng-guo committed Feb 6, 2025
1 parent ced5709 commit c1d36e6
Showing 1 changed file with 6 additions and 7 deletions.
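This commit replaces the direct accumulation of -Σ p_i log p_i with an algebraically equivalent form built from quantities the kernel already maintains in its online-softmax pass: the running max m, the softmax denominator d, and lse (the log-sum-exp, m + log d). With p_i = exp(x_i - m) / d we have log p_i = x_i - lse, and since the p_i sum to 1:

$$H = -\sum_i p_i \log p_i = -\sum_i p_i \,(x_i - \mathrm{lse}) = \mathrm{lse} - \sum_i p_i x_i$$

The forward loop therefore only needs to accumulate Σ_i p_i x_i and never has to take the log of the softmax.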
13 changes: 6 additions & 7 deletions src/liger_kernel/ops/cross_entropy.py
@@ -151,6 +151,7 @@ def liger_cross_entropy_kernel(
 
     # 3.5 Calculate the entropy loss
     if RETURN_ENTROPY_LOSS:
+        sum_p_x = 0.0  # sum of softmax(x_i) * x_i
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             X_block = tl.load(
@@ -164,10 +165,10 @@
                 X_block = softcap * intermediate
 
             softmax_X = tl.exp(X_block - m) / d
-            # Mask for valid columns and non-zero softmax
-            valid_mask = X_offsets < n_cols
-            entropy_term = tl.where(valid_mask, -softmax_X * tl.log(softmax_X), 0.0)
-            entropy_loss += tl.sum(entropy_term)
+            # Accumulate the sum of softmax(x_i) * x_i
+            sum_p_x += tl.sum(tl.where(X_offsets < n_cols, softmax_X * X_block, 0.0))
+
+        entropy_loss = lse - sum_p_x
 
     # 4. [Online Softmax] Second pass: compute gradients
     # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
@@ -203,11 +204,9 @@
                 mask=X_offsets < n_cols,
                 other=0.0,
             )
-            # valid mask for the entropy loss
-            valid_mask = X_offsets < n_cols
-
+            # Calculate the softmax of the input
             softmax_X = tl.exp(X_block - m) / d
 
             if RETURN_ENTROPY_LOSS:
                 # derivatives of the entropy loss term
                 dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss)
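The removed and added formulations should agree to floating-point precision. A minimal NumPy sketch (hypothetical, not part of the repository) that checks this on a random row of logits:

import numpy as np

# Hypothetical check, not from the repo: both entropy formulations
# should agree to floating-point precision.
rng = np.random.default_rng(0)
x = rng.normal(size=4096)              # one row of logits

m = x.max()                            # running max (kernel's `m`)
d = np.exp(x - m).sum()                # softmax denominator (kernel's `d`)
lse = m + np.log(d)                    # log-sum-exp (kernel's `lse`)
p = np.exp(x - m) / d                  # softmax

entropy_old = -(p * np.log(p)).sum()   # removed: -sum(p * log(p))
entropy_new = lse - (p * x).sum()      # added:   lse - sum(p * x)

assert np.isclose(entropy_old, entropy_new)

Besides saving a tl.log per element in the forward pass, the new accumulation sidesteps a numerical hazard in the removed lines: if softmax_X underflows to exactly zero, -softmax_X * tl.log(softmax_X) evaluates to 0 * -inf = NaN, whereas softmax_X * X_block stays finite.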

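The second pass still applies the closed-form derivative dX_entropy_block = softmax_X * (-tl.log(softmax_X) - entropy_loss), which corresponds to dH/dx_i = p_i * (-log p_i - H). A short PyTorch autograd sketch (again hypothetical, not from the repository) verifying that formula:

import torch

# Hypothetical verification, not from the repo: the autograd gradient of
# the entropy should match p * (-log(p) - H), the closed form used in
# the kernel's second pass.
x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
p = torch.softmax(x, dim=0)
H = -(p * p.log()).sum()
H.backward()

p = p.detach()
closed_form = p * (-p.log() - H.detach())
assert torch.allclose(x.grad, closed_form)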