Grpo loss #553

Open · wants to merge 15 commits into base: main
1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -1,5 +1,6 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
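With this re-export in place, the new loss can be imported from the package root like the other chunked losses. A minimal import sketch (the constructor arguments are the ones defined by LigerFusedLinearGRPOLoss later in this PR):

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

# num_generations should match the number of completions sampled per prompt (one GRPO group)
grpo_loss = LigerFusedLinearGRPOLoss(beta=0.1, use_ref_model=True, num_generations=4)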
217 changes: 217 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_rlhf.py
@@ -0,0 +1,217 @@
from functools import partial

import torch
import torch.nn.functional as F


class LigerFusedLinearRLHFBase(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        attention_mask,
        rewards,
        bias=None,
        loss_fn=None,
        num_generations=1,
        beta=0.1,
        compiled=True,
        use_ref_model=False,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        """Chunked forward pass for RLHF loss computation."""
        # Save for backward
        ctx.beta = beta
        ctx.rewards = rewards

        # Initialize accumulators
        loss_acc = torch.zeros((), device=_input.device)
        grad_weight = torch.zeros_like(weight)  # [V, H]
        grad_inputs = []
        grad_bias = torch.zeros_like(bias) if bias is not None else None  # [V]
        aggregated_metrics = []

        # Create a partial function with fixed arguments
        compute_loss = partial(
            LigerFusedLinearRLHFBase._compute_chunk_loss,
            beta=beta,
            use_ref_model=use_ref_model,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            rlhf_loss_fn=loss_fn,
        )

        def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk):
            """Fused forward and backward for a chunk."""
            if bias is not None:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)(
                    input_chunk,  # arg 0
                    weight,  # arg 1
                    attention_mask_chunk,  # arg 2
                    rewards_chunk,  # arg 3
                    ref_input_chunk,  # arg 4
                    bias,  # arg 5
                )
            else:
                return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
                    input_chunk,  # arg 0
                    weight,  # arg 1
                    attention_mask_chunk,  # arg 2
                    rewards_chunk,  # arg 3
                    ref_input_chunk,  # arg 4
                )

        def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None):
            if bias is not None:
                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                )
                grad_bias.add_(chunk_grad_bias)
            else:
                (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
                    input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk
                )

            # Accumulate gradients and loss
            grad_weight.add_(chunk_grad_weight)
            grad_inputs.append(chunk_grad_input)
            loss_acc.add_(chunk_loss)

            # Initialize storage for metrics on first chunk
            if len(aggregated_metrics) == 0:
                for metric in chunk_metrics:
                    if metric.ndim == 0:
                        aggregated_metrics.append(torch.zeros((), device=metric.device))
                    else:
                        aggregated_metrics.append([])

            # Accumulate metrics
            for i, metric in enumerate(chunk_metrics):
                if metric.ndim == 0:
                    aggregated_metrics[i].add_(metric)
                else:
                    aggregated_metrics[i].append(metric)

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        # Process input in chunks of num_generations rows, so each chunk holds one group of generations
        chunks = max(1, _input.shape[0] // num_generations)
        _input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
        _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0)
        _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0)
        _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks

        for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip(
            _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks
        ):
            # Mark dynamic dimensions
            torch._dynamo.mark_dynamic(input_chunk, 1)
            torch._dynamo.mark_dynamic(attention_mask_chunk, 1)
            if ref_input_chunk is not None:
                torch._dynamo.mark_dynamic(ref_input_chunk, 1)

            accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk)

        # Scale accumulated loss by number of chunks since we're averaging
        loss_acc = loss_acc / chunks

        # Combine gradients
        grad_input = torch.cat(grad_inputs, dim=0)

        # Save for backward
        ctx.save_for_backward(grad_input, grad_weight, grad_bias)

        # Finalize metrics
        final_metrics = []
        for metric in aggregated_metrics:
            if isinstance(metric, list):
                final_metrics.append(torch.cat(metric, dim=0))
            else:
                final_metrics.append(metric / chunks)

        return loss_acc, tuple(final_metrics)

    @staticmethod
    def _compute_chunk_loss(
        input_chunk,
        weight,
        attention_mask_chunk,
        rewards_chunk,
        ref_input_chunk=None,
        bias=None,
        beta=0.1,
        use_ref_model=False,
        ref_weight=None,
        ref_bias=None,
        rlhf_loss_fn=None,
    ):
        """Compute loss for a single chunk."""
        # Get policy log probabilities using chunk_forward
        log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias)

        # Get reference log probabilities if needed
        ref_log_probs = None
        if use_ref_model and ref_input_chunk is not None:
            with torch.no_grad():
                ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias)

        # Compute chunk loss and metrics using the provided loss function
        chunk_loss, chunk_metrics = rlhf_loss_fn(
            log_probs=log_probs,
            attention_mask=attention_mask_chunk,
            rewards=rewards_chunk,
            ref_log_probs=ref_log_probs,
            beta=beta,
        )

        return chunk_loss, (logits_mean, *chunk_metrics)

    @staticmethod
    def chunk_forward(input_chunk, weight, bias=None):
        """Forward pass computation for a single chunk."""
        batch_size, seq_len, hidden_size = input_chunk.shape
        input_reshaped = input_chunk.view(-1, hidden_size)  # [B*T, H]

        # Linear layer: [B*T, H] @ [H, V] -> [B*T, V]
        logits = F.linear(input_reshaped, weight)  # weight shape is [V, H]
        if bias is not None:
            logits = logits + bias.view(1, -1)

        # Reshape to [B, T, V] and compute log_probs
        logits = logits.view(batch_size, seq_len, -1)
        log_probs = F.log_softmax(logits.float(), dim=-1)

        # Calculate mean logits for monitoring
        logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0])

        return log_probs, logits, logits_mean

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for RLHF loss."""
        grad_input, grad_weight, grad_bias = ctx.saved_tensors
        if grad_output != 1.0:
            grad_input = grad_input * grad_output
            grad_weight = grad_weight * grad_output
            if grad_bias is not None:
                grad_bias = grad_bias * grad_output

        return (
            grad_input,
            grad_weight,
            None,  # grad_attention_mask
            None,  # grad_rewards
            grad_bias,
            None,  # grad_loss_fn
            None,  # grad_num_generations
            None,  # grad_beta
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
        )
159 changes: 159 additions & 0 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -0,0 +1,159 @@
import torch

from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase


class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase):
    @staticmethod
    def rlhf_loss_fn(
        log_probs,
        attention_mask,
        rewards,
        ref_log_probs=None,
        beta=0.1,
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation."""
        # Get log probabilities of the chosen (per-position argmax) tokens
        chosen_tokens = log_probs.argmax(dim=-1)  # (batch_size, seq_len)
        chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)

        # Get reference model probabilities
        if ref_log_probs is not None:
            with torch.no_grad():
                ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1)
        else:
            ref_token_logprobs = chosen_token_logprobs.detach()

        # Compute group statistics over this chunk (each chunk holds one group of generations)
        mean_grouped_rewards = rewards.mean()  # scalar
        std_grouped_rewards = rewards.std()  # scalar

        # Calculate advantages using the same epsilon as in GRPOTrainer
        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)

        # Compute policy gradient loss with importance sampling ratio
        # (the ratio evaluates to 1.0 but carries the gradient of the current policy log probs)
        ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach())
        policy_loss = -ratio * advantages.unsqueeze(1)

        # Compute KL penalty
        kl_div = (
            torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0
        )

        # Combine losses
        per_token_loss = policy_loss + beta * kl_div

        # Apply masking and normalize
        masked_loss = per_token_loss * attention_mask
        seq_lengths = attention_mask.sum(dim=1, keepdim=True)
        seq_lengths = torch.clamp(seq_lengths, min=1.0)
        loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean()

        # Calculate metrics
        metrics = (
            chosen_token_logprobs.mean(),  # mean log prob
            chosen_token_logprobs.std(),  # std log prob
            log_probs.mean(),  # mean all log probs
            ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(),  # mean KL div
        )

        return loss, metrics

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        attention_mask,
        rewards,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
        beta=0.1,
        compiled=True,
        use_ref_model=True,
        num_generations=1,
    ):
        return LigerFusedLinearRLHFBase.forward(
            ctx=ctx,
            _input=_input,
            weight=weight,
            attention_mask=attention_mask,
            loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn,
            rewards=rewards,
            bias=bias,
            ref_input=ref_input,
            ref_weight=ref_weight,
            ref_bias=ref_bias,
            beta=beta,
            compiled=compiled,
            use_ref_model=use_ref_model,
            num_generations=num_generations,
        )

    @staticmethod
    def backward(ctx, grad_output, *grad_metrics):
        """Backward pass for GRPO loss.

        Args:
            grad_output: Gradient of the loss (scalar)
            grad_metrics: Gradients of the metrics (not used in backward computation)
        """
        grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output)
        return (
            *grads[:5],  # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias
            None,  # grad_ref_input
            None,  # grad_ref_weight
            None,  # grad_ref_bias
            None,  # grad_beta
            None,  # grad_compiled
            None,  # grad_use_ref_model
            None,  # grad_num_generations
        )


class LigerFusedLinearGRPOLoss(torch.nn.Module):
    """Fused linear layer with GRPO loss."""

    def __init__(
        self,
        beta: float = 0.1,
        compiled: bool = True,
        use_ref_model: bool = True,
        num_generations: int = 1,
    ):
        super().__init__()
        self.beta = beta
        self.compiled = compiled
        self.use_ref_model = use_ref_model
        self.num_generations = num_generations

    def forward(
        self,
        lin_weight,
        _input,
        attention_mask,
        rewards,
        bias=None,
        ref_input=None,
        ref_weight=None,
        ref_bias=None,
    ):
        return LigerFusedLinearGRPOFunction.apply(
            _input,
            lin_weight,
            attention_mask,
            rewards,
            bias,
            ref_input,
            ref_weight,
            ref_bias,
            self.beta,
            self.compiled,
            self.use_ref_model,
            self.num_generations,
        )
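For context, a smoke-test style usage sketch of the module (all shapes, dtypes, and the compiled=False flag are assumptions made for the example; the argument order follows the forward signature above):

import torch

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

B, T, H, V, G = 8, 16, 64, 128, 4  # batch, seq len, hidden size, vocab size, generations per prompt

hidden = torch.randn(B, T, H, requires_grad=True)  # policy hidden states before lm_head
ref_hidden = torch.randn(B, T, H)  # reference model hidden states
lm_head_weight = torch.randn(V, H, requires_grad=True)  # lm_head weight, shape [V, H]
ref_lm_head_weight = torch.randn(V, H)
attention_mask = torch.ones(B, T)
rewards = torch.randn(B)  # one scalar reward per generated sequence

grpo = LigerFusedLinearGRPOLoss(beta=0.1, use_ref_model=True, num_generations=G, compiled=False)
loss, metrics = grpo(
    lin_weight=lm_head_weight,
    _input=hidden,
    attention_mask=attention_mask,
    rewards=rewards,
    ref_input=ref_hidden,
    ref_weight=ref_lm_head_weight,
)
loss.backward()  # populates hidden.grad and lm_head_weight.grad from the cached chunk gradients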