Skip to content

Commit

Permalink
Merge branch 'main' into feature/deepseekv2
Browse files Browse the repository at this point in the history
  • Loading branch information
lancerts authored Jan 21, 2025
2 parents b6287f1 + 2ea3cfb commit 8e71b13
Show file tree
Hide file tree
Showing 17 changed files with 282 additions and 110 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@

<details>
<summary>Latest News 🔥</summary>
- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!

- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
Expand Down Expand Up @@ -253,7 +253,7 @@ loss.backward()
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| DeepSeekv2 | `liger_kernel.transformers.apply_liger_kernel_to_deepseek_v2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

Expand Down Expand Up @@ -309,11 +309,11 @@ loss.backward()
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/License.md)

## Sponsorship and Collaboration

- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
Expand Down
23 changes: 8 additions & 15 deletions benchmark/scripts/benchmark_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,12 +36,8 @@ def bench_memory_fused_linear_cpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down Expand Up @@ -72,7 +68,8 @@ def full():
def bench_speed_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,12 +79,8 @@ def bench_speed_fused_linear_cpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down
8 changes: 4 additions & 4 deletions benchmark/scripts/benchmark_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import torch
import triton

from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -21,7 +19,8 @@


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand Down Expand Up @@ -70,7 +69,8 @@ def full():


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand Down
33 changes: 14 additions & 19 deletions benchmark/scripts/benchmark_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,21 +36,18 @@ def bench_memory_fused_linear_orpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

def full():
y = fwd()
Expand All @@ -72,7 +69,8 @@ def full():
def bench_speed_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,21 +80,18 @@ def bench_speed_fused_linear_orpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
Expand Down
18 changes: 16 additions & 2 deletions benchmark/scripts/benchmark_rope.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
import triton

from test.utils import transformers_version_dispatch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from utils import QUANTILES
Expand Down Expand Up @@ -30,7 +32,13 @@ def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

head_dim = hidden_size // num_q_heads
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
rotary_emb = transformers_version_dispatch(
"4.48.0",
LlamaRotaryEmbedding,
LlamaRotaryEmbedding,
before_kwargs={"dim": head_dim, "device": device},
after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
Expand Down Expand Up @@ -105,7 +113,13 @@ def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutpu
seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

head_dim = hidden_size // num_q_heads
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
rotary_emb = transformers_version_dispatch(
"4.48.0",
LlamaRotaryEmbedding,
LlamaRotaryEmbedding,
before_kwargs={"dim": head_dim, "device": device},
after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
)
q = torch.randn(
(1, seq_len, num_q_heads, head_dim),
device=device,
Expand Down
23 changes: 8 additions & 15 deletions benchmark/scripts/benchmark_simpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,12 +36,8 @@ def bench_memory_fused_linear_simpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down Expand Up @@ -72,7 +68,8 @@ def full():
def bench_speed_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,12 +79,8 @@ def bench_speed_fused_linear_simpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down
Loading

0 comments on commit 8e71b13

Please sign in to comment.