Skip to content

Commit

Permalink
add batch_norm op with test and benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghl committed Feb 7, 2025
1 parent bcde83c commit 32507fb
Show file tree
Hide file tree
Showing 6 changed files with 538 additions and 0 deletions.
30 changes: 30 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -769,3 +769,33 @@ distill_jsd_loss,torch,full,memory,MB,BT,B x T,1024,16174.0390625,16174.0390625,
distill_jsd_loss,torch,full,memory,MB,BT,B x T,2048,23713.05078125,23713.05078125,23713.05078125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,4096,38791.07421875,38791.07421875,38791.07421875,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
distill_jsd_loss,torch,full,memory,MB,BT,B x T,8192,68947.1015625,68947.1015625,68947.1015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": false, ""weight_hard_loss"": 0.5, ""weight_soft_loss"": 0.5, ""ignore_index"": -100}",NVIDIA H100 80GB HBM3,2024-12-03 08:01:32,0.4.2
batch_norm,liger,forward,speed,ms,N,hidden size,1024,0.13689599931240082,0.13616639375686646,0.13795199990272522,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2
batch_norm,liger,forward,speed,ms,N,hidden size,2048,0.26447999477386475,0.26284798979759216,0.2656959891319275,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2
batch_norm,liger,forward,speed,ms,N,hidden size,4096,0.525056004524231,0.5232831835746765,0.5266559720039368,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2
batch_norm,liger,forward,speed,ms,N,hidden size,8192,1.05131196975708,1.0489856004714966,1.0533759593963623,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2
batch_norm,liger,forward,speed,ms,N,hidden size,16384,2.13972806930542,2.1362624168395996,2.143014430999756,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:43,0.5.2
batch_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.041471999138593674,0.0398080013692379,0.042688000947237015,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2
batch_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.06825599819421768,0.06672000139951706,0.0695360004901886,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2
batch_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.1191679984331131,0.11868800222873688,0.11961600184440613,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2
batch_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.21347199380397797,0.21296000480651855,0.21398399770259857,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2
batch_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.4029119908809662,0.4023999869823456,0.40348801016807556,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:46,0.5.2
batch_norm,liger,full,speed,ms,N,hidden size,1024,0.3394879996776581,0.3375680148601532,0.3413119912147522,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2
batch_norm,liger,full,speed,ms,N,hidden size,2048,0.6499840021133423,0.6464319825172424,0.6534016132354736,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2
batch_norm,liger,full,speed,ms,N,hidden size,4096,1.2944639921188354,1.291468858718872,1.297875165939331,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2
batch_norm,liger,full,speed,ms,N,hidden size,8192,2.5837440490722656,2.579263925552368,2.5880000591278076,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2
batch_norm,liger,full,speed,ms,N,hidden size,16384,5.309120178222656,5.301023960113525,5.314540863037109,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:50,0.5.2
batch_norm,huggingface,full,speed,ms,N,hidden size,1024,0.08718399703502655,0.08614400029182434,0.08816000074148178,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,speed,ms,N,hidden size,2048,0.14828799664974213,0.14732800424098969,0.14927999675273895,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,speed,ms,N,hidden size,4096,0.25726401805877686,0.25622400641441345,0.2583935856819153,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,speed,ms,N,hidden size,8192,0.4660159945487976,0.46483200788497925,0.4671808183193207,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,speed,ms,N,hidden size,16384,0.880128026008606,0.8787840008735657,0.8814719915390015,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,liger,full,memory,MB,N,hidden size,1024,80.04736328125,80.04736328125,80.04736328125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,liger,full,memory,MB,N,hidden size,2048,160.09423828125,160.09423828125,160.09423828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,liger,full,memory,MB,N,hidden size,4096,320.18798828125,320.18798828125,320.18798828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,liger,full,memory,MB,N,hidden size,8192,640.37548828125,640.37548828125,640.37548828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,liger,full,memory,MB,N,hidden size,16384,1280.75048828125,1280.75048828125,1280.75048828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,memory,MB,N,hidden size,1024,80.05517578125,80.05517578125,80.05517578125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,memory,MB,N,hidden size,2048,160.10986328125,160.10986328125,160.10986328125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,memory,MB,N,hidden size,4096,320.21923828125,320.21923828125,320.21923828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,memory,MB,N,hidden size,8192,640.43798828125,640.43798828125,640.43798828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
batch_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.87548828125,1280.87548828125,1280.87548828125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA H100 80GB HBM3,2025-02-07 19:40:52,0.5.2
125 changes: 125 additions & 0 deletions benchmark/scripts/benchmark_batch_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.batch_norm import LigerBatchNorm
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_batch_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Time one batch-norm provider on a random ``(M, N)`` input.

    Reads from ``input``:
        x: the hidden size ``N`` (second dimension of the input tensor).
        kernel_provider: ``"liger"`` (``LigerBatchNorm``) or ``"huggingface"``
            (``torch.nn.BatchNorm1d``).
        kernel_operation_mode: ``"forward"``, ``"backward"`` or ``"full"``
            (forward followed by backward).
        extra_benchmark_config: dict with keys ``"M"`` (first dimension of
            the input tensor), ``"eps"`` and ``"dtype"``.

    Returns:
        SingleBenchmarkRunOutput whose ``y_20`` / ``y_50`` / ``y_80`` hold the
        three values returned by ``triton.testing.do_bench`` for ``QUANTILES``.

    Raises:
        ValueError: if the provider or the operation mode is not recognized.
            Without this check an unknown provider would make the forward
            closure return ``None`` (silently timing a no-op in "forward"
            mode), and an unknown mode would surface as an
            ``UnboundLocalError`` when building the result.
    """
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    # Reject bad configuration before allocating anything on the device.
    if mode not in ("forward", "backward", "full"):
        raise ValueError(f"Unsupported kernel_operation_mode for batch_norm benchmark: {mode!r}")

    # Only build the module that is actually being timed.
    if provider == "liger":
        # NOTE(review): eps is only forwarded to the torch module; confirm that
        # LigerBatchNorm's default epsilon matches the configured value.
        bn = LigerBatchNorm(hidden_size=N).to(device)
    elif provider == "huggingface":
        bn = torch.nn.BatchNorm1d(N, eps=eps).to(device)
    else:
        raise ValueError(f"Unsupported kernel_provider for batch_norm benchmark: {provider!r}")

    x_shape = (M, N)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)  # upstream gradient fed to backward()
    x.requires_grad_(True)

    def y_fwd():
        return bn(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    if mode == "forward":
        bench_fn = y_fwd
    elif mode == "backward":
        # Run the forward pass once outside the timed region; retain_graph
        # lets the same graph be back-propagated on every repetition.
        y = y_fwd()

        def bench_fn():
            y.backward(dy, retain_graph=True)

    else:  # mode == "full"
        bench_fn = full

    # The unpacking order follows the order of QUANTILES
    # (presumably [0.5, 0.2, 0.8] -- confirm in utils).
    ms_50, ms_20, ms_80 = triton.testing.do_bench(bench_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_batch_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    """Measure memory of a forward+backward batch-norm pass on an ``(M, N)`` input.

    Reads from ``input``:
        x: the hidden size ``N`` (second dimension of the input tensor).
        kernel_provider: ``"liger"`` (``LigerBatchNorm``) or ``"huggingface"``
            (``torch.nn.BatchNorm1d``).
        extra_benchmark_config: dict with keys ``"M"`` (first dimension of
            the input tensor), ``"eps"`` and ``"dtype"``.

    Returns:
        SingleBenchmarkRunOutput whose ``y_20`` / ``y_50`` / ``y_80`` hold the
        three values returned by ``_test_memory`` for ``QUANTILES``.

    Raises:
        ValueError: if the provider is not recognized. Without this check the
            forward closure would return ``None`` and the run would fail with
            an unrelated ``AttributeError`` on ``.backward``.
    """
    N = input.x
    provider = input.kernel_provider
    dtype = input.extra_benchmark_config["dtype"]
    M = input.extra_benchmark_config["M"]
    eps = input.extra_benchmark_config["eps"]

    # Build only the module under test so the other provider's parameters and
    # buffers are not resident on the device during the measurement
    # (presumably _test_memory samples device memory usage -- confirm in utils).
    if provider == "liger":
        # NOTE(review): eps is only forwarded to the torch module; confirm that
        # LigerBatchNorm's default epsilon matches the configured value.
        bn = LigerBatchNorm(hidden_size=N).to(device)
    elif provider == "huggingface":
        bn = torch.nn.BatchNorm1d(N, eps=eps).to(device)
    else:
        raise ValueError(f"Unsupported kernel_provider for batch_norm benchmark: {provider!r}")

    x_shape = (M, N)
    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)  # upstream gradient fed to backward()
    x.requires_grad_(True)

    def y_fwd():
        return bn(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    # The unpacking order follows the order of QUANTILES
    # (presumably [0.5, 0.2, 0.8] -- confirm in utils).
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    # Script entry point: run the speed benchmarks first, then the memory one.
    args = parse_benchmark_script_args()

    # Settings shared by every run_benchmarks call below.
    common_configs = dict(
        kernel_name="batch_norm",
        x_name="N",
        x_label="hidden size",
        x_values=[1 << exponent for exponent in range(10, 15)],  # Range of hidden size values
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
        overwrite=args.overwrite,
    )

    # (benchmark function, operation modes, metric name, metric unit)
    benchmark_specs = (
        (bench_speed_batch_norm, ["forward", "full"], "speed", "ms"),
        (bench_memory_batch_norm, ["full"], "memory", "MB"),
    )

    for bench_fn, operation_modes, metric, unit in benchmark_specs:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=operation_modes,
            metric_name=metric,
            metric_unit=unit,
            **common_configs,
        )
Loading

0 comments on commit 32507fb

Please sign in to comment.