diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index 4f76ab79d..06c49d6b3 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -1,5 +1,6 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/fused_linear_rlhf.py b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py new file mode 100644 index 000000000..8976d9dbd --- /dev/null +++ b/src/liger_kernel/chunked_loss/fused_linear_rlhf.py @@ -0,0 +1,217 @@ +from functools import partial + +import torch +import torch.nn.functional as F + + +class LigerFusedLinearRLHFBase(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + attention_mask, + rewards, + bias=None, + loss_fn=None, + num_generations=1, + beta=0.1, + compiled=True, + use_ref_model=False, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + """Chunked forward pass for RLHF loss computation.""" + # Save for backward + ctx.beta = beta + ctx.rewards = rewards + + # Initialize accumulators + loss_acc = torch.zeros((), device=_input.device) + grad_weight = torch.zeros_like(weight) # [V, H] + grad_inputs = [] + grad_bias = torch.zeros_like(bias) if bias is not None else None # [V] + aggregated_metrics = [] + + # Create a partial function with fixed arguments + compute_loss = partial( + LigerFusedLinearRLHFBase._compute_chunk_loss, + beta=beta, + use_ref_model=use_ref_model, + ref_weight=ref_weight, + ref_bias=ref_bias, + rlhf_loss_fn=loss_fn, + ) + + def fused_fwd_bwd(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk): + """Fused forward and backward for a chunk.""" + if bias is not None: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1, 5), has_aux=True)( + input_chunk, # arg 0 + weight, # arg 1 + attention_mask_chunk, # arg 2 + rewards_chunk, # arg 3 + ref_input_chunk, # arg 4 + bias, # arg 5 + ) + else: + return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( + input_chunk, # arg 0 + weight, # arg 1 + attention_mask_chunk, # arg 2 + rewards_chunk, # arg 3 + ref_input_chunk, # arg 4 + ) + + def accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk=None): + if bias is not None: + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( + input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk + ) + grad_bias.add_(chunk_grad_bias) + else: + (chunk_grad_input, chunk_grad_weight), (chunk_loss, chunk_metrics) = fused_fwd_bwd( + input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk + ) + + # Accumulate gradients and loss + grad_weight.add_(chunk_grad_weight) + grad_inputs.append(chunk_grad_input) + loss_acc.add_(chunk_loss) + + # Initialize storage for metrics on first chunk + if len(aggregated_metrics) == 0: + for metric in chunk_metrics: + if metric.ndim == 0: + aggregated_metrics.append(torch.zeros((), device=metric.device)) + else: + aggregated_metrics.append([]) + + # Accumulate metrics + for i, metric in 
enumerate(chunk_metrics): + if metric.ndim == 0: + aggregated_metrics[i].add_(metric) + else: + aggregated_metrics[i].append(metric) + + if compiled: + accumulate_chunk = torch.compile(accumulate_chunk) + + # Process input in chunks + chunks = max(1, _input.shape[0] // num_generations) + _input_chunks = torch.chunk(_input, chunks=chunks, dim=0) + _attention_mask_chunks = torch.chunk(attention_mask, chunks=chunks, dim=0) + _rewards_chunks = torch.chunk(rewards, chunks=chunks, dim=0) + _ref_input_chunks = torch.chunk(ref_input, chunks=chunks, dim=0) if use_ref_model else [None] * chunks + + for input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk in zip( + _input_chunks, _attention_mask_chunks, _rewards_chunks, _ref_input_chunks + ): + # Mark dynamic dimensions + torch._dynamo.mark_dynamic(input_chunk, 1) + torch._dynamo.mark_dynamic(attention_mask_chunk, 1) + if ref_input_chunk is not None: + torch._dynamo.mark_dynamic(ref_input_chunk, 1) + + accumulate_chunk(input_chunk, attention_mask_chunk, rewards_chunk, ref_input_chunk) + + # Scale accumulated loss by number of chunks since we're averaging + loss_acc = loss_acc / chunks + + # Combine gradients + grad_input = torch.cat(grad_inputs, dim=0) + + # Save for backward + ctx.save_for_backward(grad_input, grad_weight, grad_bias) + + # Finalize metrics + final_metrics = [] + for metric in aggregated_metrics: + if isinstance(metric, list): + final_metrics.append(torch.cat(metric, dim=0)) + else: + final_metrics.append(metric / chunks) + + return loss_acc, tuple(final_metrics) + + @staticmethod + def _compute_chunk_loss( + input_chunk, + weight, + attention_mask_chunk, + rewards_chunk, + ref_input_chunk=None, + bias=None, + beta=0.1, + use_ref_model=False, + ref_weight=None, + ref_bias=None, + rlhf_loss_fn=None, + ): + """Compute loss for a single chunk.""" + # Get policy log probabilities using chunk_forward + log_probs, _, logits_mean = LigerFusedLinearRLHFBase.chunk_forward(input_chunk, weight, bias=bias) + + # Get reference log probabilities if needed + ref_log_probs = None + if use_ref_model and ref_input_chunk is not None: + with torch.no_grad(): + ref_log_probs, _, _ = LigerFusedLinearRLHFBase.chunk_forward(ref_input_chunk, ref_weight, bias=ref_bias) + + # Compute chunk loss and metrics using the provided loss function + chunk_loss, chunk_metrics = rlhf_loss_fn( + log_probs=log_probs, + attention_mask=attention_mask_chunk, + rewards=rewards_chunk, + ref_log_probs=ref_log_probs, + beta=beta, + ) + + return chunk_loss, (logits_mean, *chunk_metrics) + + @staticmethod + def chunk_forward(input_chunk, weight, bias=None): + """Forward pass computation for a single chunk.""" + batch_size, seq_len, hidden_size = input_chunk.shape + input_reshaped = input_chunk.view(-1, hidden_size) # [B*T, H] + + # Linear layer: [B*T, H] @ [H, V] -> [B*T, V] + logits = F.linear(input_reshaped, weight) # weight shape is [V, H] + if bias is not None: + logits = logits + bias.view(1, -1) + + # Reshape to [B, T, V] and compute log_probs + logits = logits.view(batch_size, seq_len, -1) + log_probs = F.log_softmax(logits.float(), dim=-1) + + # Calculate mean logits for monitoring + logits_mean = logits.sum() / (batch_size * seq_len * weight.shape[0]) + + return log_probs, logits, logits_mean + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for RLHF loss.""" + grad_input, grad_weight, grad_bias = ctx.saved_tensors + if grad_output != 1.0: + grad_input = grad_input * grad_output + grad_weight = grad_weight * grad_output + if 
grad_bias is not None: + grad_bias = grad_bias * grad_output + + return ( + grad_input, + grad_weight, + None, # grad_attention_mask + None, # grad_rewards + grad_bias, + None, # grad_loss_fn + None, # grad_chunk_size + None, # grad_beta + None, # grad_compiled + None, # grad_use_ref_model + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + ) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py new file mode 100644 index 000000000..593baf613 --- /dev/null +++ b/src/liger_kernel/chunked_loss/grpo_loss.py @@ -0,0 +1,159 @@ +import torch + +from liger_kernel.chunked_loss.fused_linear_rlhf import LigerFusedLinearRLHFBase + + +class LigerFusedLinearGRPOFunction(LigerFusedLinearRLHFBase): + @staticmethod + def rlhf_loss_fn( + log_probs, + attention_mask, + rewards, + ref_log_probs=None, + beta=0.1, + **kwargs, + ): + """GRPO Loss Function matching GRPOTrainer implementation.""" + # Get chosen token probabilities + chosen_tokens = log_probs.argmax(dim=-1) # (batch_size, seq_len) + chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze( + -1 + ) # (batch_size, seq_len) + + # Get reference model probabilities + if ref_log_probs is not None: + with torch.no_grad(): + ref_token_logprobs = ref_log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1) + else: + ref_token_logprobs = chosen_token_logprobs.detach() + + # Compute advantages per batch entry in a grouped fashion + mean_grouped_rewards = rewards.mean() # [batch_size,] + std_grouped_rewards = rewards.std() # [batch_size,] + + # Calculate advantages using the same epsilon as in GRPOTrainer + advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # Compute policy gradient loss with importance sampling ratio + ratio = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) + policy_loss = -ratio * advantages.unsqueeze(1) + + # Compute KL penalty + kl_div = ( + torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0 + ) + + # Combine losses + per_token_loss = policy_loss + beta * kl_div + + # Apply masking and normalize + masked_loss = per_token_loss * attention_mask + seq_lengths = attention_mask.sum(dim=1, keepdim=True) + seq_lengths = torch.clamp(seq_lengths, min=1.0) + loss = (masked_loss.sum(dim=1) / seq_lengths.squeeze(-1)).mean() + + # Calculate metrics + metrics = ( + chosen_token_logprobs.mean(), # mean log prob + chosen_token_logprobs.std(), # std log prob + log_probs.mean(), # mean all log probs + ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), # mean KL div + ) + + return loss, metrics + + @staticmethod + def forward( + ctx, + _input, + weight, + attention_mask, + rewards, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + beta=0.1, + compiled=True, + use_ref_model=True, + num_generations=1, + ): + return LigerFusedLinearRLHFBase.forward( + ctx=ctx, + _input=_input, + weight=weight, + attention_mask=attention_mask, + loss_fn=LigerFusedLinearGRPOFunction.rlhf_loss_fn, + rewards=rewards, + bias=bias, + ref_input=ref_input, + ref_weight=ref_weight, + ref_bias=ref_bias, + beta=beta, + compiled=compiled, + use_ref_model=use_ref_model, + num_generations=num_generations, + ) + + @staticmethod + def backward(ctx, grad_output, *grad_metrics): + """Backward pass for GRPO loss. 
+ + Args: + grad_output: Gradient of the loss (scalar) + grad_metrics: Gradients of the metrics (not used in backward computation) + """ + grads = LigerFusedLinearRLHFBase.backward(ctx, grad_output) + return ( + *grads[:5], # grad_input, grad_weight, grad_attention_mask, grad_rewards, grad_bias + None, # grad_ref_input + None, # grad_ref_weight + None, # grad_ref_bias + None, # grad_beta + None, # grad_compiled + None, # grad_use_ref_model + None, # grad_num_generations + ) + + +class LigerFusedLinearGRPOLoss(torch.nn.Module): + """Fused linear layer with GRPO loss.""" + + def __init__( + self, + beta: float = 0.1, + compiled: bool = True, + use_ref_model: bool = True, + num_generations: int = 1, + ): + super().__init__() + self.beta = beta + self.compiled = compiled + self.use_ref_model = use_ref_model + self.num_generations = num_generations + + def forward( + self, + lin_weight, + _input, + attention_mask, + rewards, + bias=None, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + return LigerFusedLinearGRPOFunction.apply( + _input, + lin_weight, + attention_mask, + rewards, + bias, + ref_input, + ref_weight, + ref_bias, + self.beta, + self.compiled, + self.use_ref_model, + self.num_generations, + ) diff --git a/test/chunked_loss/test_grpo_loss.py b/test/chunked_loss/test_grpo_loss.py new file mode 100644 index 000000000..afa4b495b --- /dev/null +++ b/test/chunked_loss/test_grpo_loss.py @@ -0,0 +1,274 @@ +import pytest +import torch +import torch.nn.functional as F + +from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOFunction +from liger_kernel.utils import infer_device +from test.utils import assert_verbose_allclose +from test.utils import set_seed + +device = infer_device() + +# set random seed globally +set_seed() + + +class TorchLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + num_generations: int = 4, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.beta = beta + self.num_generations = num_generations + + def forward( + self, + x, # Shape: [batch_size*num_generations, seq_len, hidden_size] + attention_mask, # Shape: [batch_size*num_generations, seq_len] + rewards, # Shape: [batch_size*num_generations,] + ref_input=None, # Shape: [batch_size*num_generations, seq_len, hidden_size] + ref_weight=None, + ref_bias=None, + ): + # Forward pass through linear layer + batch_size = x.shape[0] // self.num_generations # Get true batch size + seq_len = x.shape[1] + hidden_size = x.shape[2] + + input_reshaped = x.view(-1, hidden_size) + logits = (input_reshaped @ self.lin.weight.t()).view(batch_size * self.num_generations, seq_len, -1) + if self.lin.bias is not None: + logits = logits + self.lin.bias + + # Get log probabilities + log_probs = F.log_softmax(logits, dim=-1) + + # Get chosen token probabilities + chosen_tokens = log_probs.argmax(dim=-1) + chosen_token_logprobs = log_probs.gather(dim=-1, index=chosen_tokens.unsqueeze(-1)).squeeze(-1) + + # Get reference model probabilities + if ref_input is not None and ref_weight is not None: + with torch.no_grad(): + ref_input_reshaped = ref_input.view(-1, ref_input.size(-1)) + ref_logits = (ref_input_reshaped @ ref_weight.t()).view(batch_size * self.num_generations, seq_len, -1) + if ref_bias is not None: + ref_logits = ref_logits + ref_bias + ref_log_probs = F.log_softmax(ref_logits, dim=-1) + ref_token_logprobs = ref_log_probs.gather(dim=-1, 
index=chosen_tokens.unsqueeze(-1)).squeeze(-1) + else: + ref_token_logprobs = chosen_token_logprobs.detach() + + # Compute KL divergence between model and reference model + kl_div = ( + torch.exp(ref_token_logprobs - chosen_token_logprobs) - (ref_token_logprobs - chosen_token_logprobs) - 1.0 + ) + + # Compute advantages per batch entry in a grouped fashion + # rewards shape: [batch_size, num_generations] + mean_grouped_rewards = rewards.view(batch_size, self.num_generations).mean( + dim=1, keepdim=True + ) # [batch_size, 1] + std_grouped_rewards = rewards.view(batch_size, self.num_generations).std(dim=1, keepdim=True) # [batch_size, 1] + + # Expand means and stds to match the number of generations + mean_grouped_rewards = mean_grouped_rewards.expand(-1, self.num_generations).reshape( + -1 + ) # [batch_size * num_generations] + std_grouped_rewards = std_grouped_rewards.expand(-1, self.num_generations).reshape( + -1 + ) # [batch_size * num_generations] + + # Calculate advantages using the same epsilon as in GRPOTrainer + rewards_flat = rewards.view(-1) # [batch_size * num_generations] + advantages = (rewards_flat - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + # Compute policy gradient loss with importance sampling ratio + per_token_loss = torch.exp(chosen_token_logprobs - chosen_token_logprobs.detach()) * advantages.unsqueeze(1) + per_token_loss = -(per_token_loss - self.beta * kl_div) + + # Apply masking and normalize + loss = ((per_token_loss * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean() + + # Compute metrics + metrics = ( + logits.mean(), + chosen_token_logprobs.mean(), + chosen_token_logprobs.std(), + log_probs.mean(), + ((kl_div * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)).mean(), + ) + + return loss, metrics + + +class LigerLMHeadGRPO(torch.nn.Module): + def __init__( + self, + H: int, + V: int, + dtype: torch.dtype, + bias: bool = False, + beta: float = 0.1, + num_generations: int = 4, + ): + super().__init__() + self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) + self.grpo_loss = LigerFusedLinearGRPOFunction.apply + self.beta = beta + self.num_generations = num_generations + + def forward( + self, + x, + attention_mask, + rewards, + ref_input=None, + ref_weight=None, + ref_bias=None, + ): + # Pass only the arguments defined in LigerFusedLinearGRPOFunction.forward() + return self.grpo_loss( + x, # _input + self.lin.weight, # weight + attention_mask, # attention_mask + rewards, # rewards + self.lin.bias, # bias + ref_input, # ref_input + ref_weight, # ref_weight + ref_bias, # ref_bias + self.beta, # beta + True, # compiled + ref_input is not None, # use_ref_model + self.num_generations, # num_generations + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (8, 128, 1024, 4096), + (3, 47, 31, 123), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, torch.bfloat16, 5e-2, 5e-2), + (1.0, torch.float32, 5e-2, 5e-2), + ], +) +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("ref_bias", [True, False]) +@pytest.mark.parametrize("beta", [0.1, 0.9]) +def test_correctness( + B, + T, + H, + V, + scalar, + dtype, + atol, + rtol, + bias, + ref_bias, + beta, +): + num_generations = 4 # Fixed number of generations for testing + torch_lm_head_grpo = TorchLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + num_generations=num_generations, + ) + liger_lm_head_grpo = LigerLMHeadGRPO( + H=H, + V=V, + dtype=dtype, + bias=bias, + beta=beta, + 
num_generations=num_generations, + ) + + # Initialize weights + torch_lm_head_grpo.lin.weight.data = liger_lm_head_grpo.lin.weight.data = torch.randn( + V, H, device=device, dtype=dtype + ) + if bias: + torch_lm_head_grpo.lin.bias.data = liger_lm_head_grpo.lin.bias.data = torch.randn(V, device=device, dtype=dtype) + + # Create inputs with shape [B*num_generations, T, H] + _input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar + input1 = _input.detach().clone().requires_grad_(True) + input2 = _input.detach().clone().requires_grad_(True) + + # Create attention mask with random padding [B*num_generations, T] + attention_mask = torch.ones(B * num_generations, T, device=device) + num_elements_to_mask = torch.randint(1, B * num_generations * T // 2, (1,)).item() + mask_indices = torch.randperm(B * num_generations * T)[:num_elements_to_mask] + attention_mask.view(-1)[mask_indices] = 0 + + # Create rewards with shape [B, num_generations] + rewards = torch.rand(B * num_generations, device=device, dtype=dtype) + + # Create reference inputs (optional) with shape [B*num_generations, T, H] + ref_input = torch.randn(B * num_generations, T, H, device=device, dtype=dtype) * scalar + ref_weight = torch.randn(V, H, device=device, dtype=dtype) + ref_bias_weight = torch.randn(V, device=device, dtype=dtype) if ref_bias else None + + # Forward pass with reference model + loss1, aux1 = torch_lm_head_grpo( + input1, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias_weight + ) + loss2, aux2 = liger_lm_head_grpo( + input2, attention_mask, rewards, ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias_weight + ) + + # Check losses match + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match + assert len(aux1) == len(aux2) + for metric1, metric2 in zip(aux1, aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol) + + # Backward pass + loss1.backward() + loss2.backward() + + # Check gradients match + assert_verbose_allclose(input1.grad, input2.grad, atol=atol, rtol=rtol) + assert_verbose_allclose( + torch_lm_head_grpo.lin.weight.grad, + liger_lm_head_grpo.lin.weight.grad, + atol=atol, + rtol=rtol, + ) + if bias: + assert_verbose_allclose( + torch_lm_head_grpo.lin.bias.grad, + liger_lm_head_grpo.lin.bias.grad, + atol=atol, + rtol=rtol, + ) + + # Test without reference model + loss1, aux1 = torch_lm_head_grpo(input1, attention_mask, rewards) + loss2, aux2 = liger_lm_head_grpo(input2, attention_mask, rewards) + + # Check losses match (without reference model) + assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) + + # Check metrics match (without reference model) + assert len(aux1) == len(aux2) + for metric1, metric2 in zip(aux1, aux2): + assert_verbose_allclose(metric1, metric2, atol=atol, rtol=rtol)
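
Usage sketch (illustrative only, not part of the diff): a minimal example of driving the new LigerFusedLinearGRPOLoss module end to end, assuming the positional argument order of LigerFusedLinearGRPOLoss.forward() and the shape conventions exercised in test_grpo_loss.py ([B*num_generations, T, H] hidden states, one scalar reward per generation). All tensor names and sizes below are made up for illustration.

import torch

from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss

B, G, T, H, V = 2, 4, 16, 64, 128  # prompts, generations per prompt, seq len, hidden, vocab

# Policy LM-head weight and the last hidden states it would project.
weight = torch.randn(V, H, requires_grad=True)
hidden = torch.randn(B * G, T, H, requires_grad=True)

# Frozen reference-model counterparts, used only for the KL penalty.
ref_weight = torch.randn(V, H)
ref_hidden = torch.randn(B * G, T, H)

attention_mask = torch.ones(B * G, T)
rewards = torch.rand(B * G)  # one scalar reward per generation

loss_fn = LigerFusedLinearGRPOLoss(
    beta=0.1,
    compiled=False,      # True wraps the per-chunk step in torch.compile
    use_ref_model=True,
    num_generations=G,   # rows are chunked in groups of this size
)

# Note the argument order: the LM-head weight comes before the hidden states.
loss, metrics = loss_fn(
    weight,          # lin_weight
    hidden,          # _input
    attention_mask,
    rewards,
    None,            # bias
    ref_hidden,      # ref_input
    ref_weight,      # ref_weight
    None,            # ref_bias
)
loss.backward()      # populates hidden.grad and weight.grad

Because the module takes the LM-head weight and the pre-projection hidden states separately, the full [B*G, T, V] logits are only materialized one chunk at a time inside the fused forward/backward loop, which is the point of the chunked design.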