Revert Fix #2 for Minimal All Reduce PR #18757

Merged (20 commits) on Mar 7, 2025
120 changes: 120 additions & 0 deletions models/demos/llama3/tests/test_ccl_async_perf_TG_llama.py
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn

from models.perf.benchmarking_utils import BenchmarkData, BenchmarkProfiler
from models.perf.device_perf_utils import run_device_perf_detailed


@pytest.mark.parametrize(
"ag_type, warmup_iters, perf_target_us",
[
("sdpa", 10, 11),
("binary_mult", 10, 12),
("layernorm", 10, 8),
],
)
@pytest.mark.models_device_performance_bare_metal
def test_ag_tg_llama_perf(
ag_type,
warmup_iters,
perf_target_us,
):
profiler = BenchmarkProfiler()
benchmark_data = BenchmarkData()
step_name = f"all_gather_{ag_type}"

subdir = "llama_ccl_perf"
command = (
f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_gather_tg_llama -k {ag_type}"
)
cols = ["DEVICE KERNEL"]
op_name = "AllGatherAsync"
warmup_iters = warmup_iters * 32  # warmup_iters is specified per device; scale by the 32 devices in the TG mesh

profiler.start("run")
profiler.start(step_name)
results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters)
profiler.end(step_name)
profiler.end("run")

# Get the measured performance
measured_min_us = results[cols[0]]["MIN"] / 1000
measured_max_us = results[cols[0]]["MAX"] / 1000
measured_avg_us = results[cols[0]]["AVG"] / 1000
measured_std_us = results[cols[0]]["STD"] / 1000

logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. target: {perf_target_us} us")

# Save the measurement
benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-min-us", measured_min_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-max-us", measured_max_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-avg-us", measured_avg_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_gather-{ag_type}-std-us", measured_std_us)
benchmark_data.save_partial_run_json(
profiler,
run_type=f"all_gather",
ml_model_name="llama70b-tg-ccl",
)

assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us"


@pytest.mark.parametrize(
"ar_type, warmup_iters, perf_target_us",
[
("ff2", 10, 29),
("qkv", 10, 25),
("ff1", 10, 30),
("lm_head", 10, 70),
],
)
@pytest.mark.models_device_performance_bare_metal
def test_ar_tg_llama_perf(
ar_type,
warmup_iters,
perf_target_us,
):
profiler = BenchmarkProfiler()
benchmark_data = BenchmarkData()
step_name = f"all_reduce_{ar_type}"

subdir = "llama_ccl_perf"
command = (
f"pytest tests/ttnn/unit_tests/operations/ccl/test_ccl_async_TG_llama.py::test_all_reduce_tg_llama -k {ar_type}"
)
cols = ["DEVICE KERNEL"]
op_name = "AllReduceAsync"
warmup_iters = warmup_iters * 32  # warmup_iters is specified per device; scale by the 32 devices in the TG mesh

profiler.start("run")
profiler.start(step_name)
results = run_device_perf_detailed(command, subdir, cols, op_name, has_signposts=True, warmup_iters=warmup_iters)
profiler.end(step_name)
profiler.end("run")

# Get the measured performance
measured_min_us = results[cols[0]]["MIN"] / 1000
measured_max_us = results[cols[0]]["MAX"] / 1000
measured_avg_us = results[cols[0]]["AVG"] / 1000
measured_std_us = results[cols[0]]["STD"] / 1000

logger.info(f"Measured performance: {measured_avg_us:.3f} us vs. target: {perf_target_us} us")

# Save the measurement
benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-min-us", measured_min_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-max-us", measured_max_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-avg-us", measured_avg_us)
benchmark_data.add_measurement(profiler, 0, step_name, f"all_reduce-{ar_type}-std-us", measured_std_us)
benchmark_data.save_partial_run_json(
profiler,
run_type=f"all_reduce",
ml_model_name="llama70b-tg-ccl",
)

assert measured_avg_us < perf_target_us, f"Performance target not met: {measured_avg_us} us > {perf_target_us} us"
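
For reference, both tests above consume the nested dictionary returned by run_device_perf_detailed (defined in models/perf/device_perf_utils.py, shown next) and convert the nanosecond kernel durations to microseconds. A minimal sketch of that shape, with invented numbers purely for illustration:

# Illustrative only: the real dictionary is built by run_device_perf_detailed()
# from the device profiler's ops log; these numbers are made up.
results = {
    "DEVICE KERNEL": {"AVG": 27500.0, "MIN": 26100.0, "MAX": 29300.0, "STD": 800.0}  # ns
}
cols = ["DEVICE KERNEL"]
measured_avg_us = results[cols[0]]["AVG"] / 1000  # 27.5 us, compared against perf_target_us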
77 changes: 77 additions & 0 deletions models/perf/device_perf_utils.py
@@ -4,10 +4,14 @@

import json
import time
import pandas as pd

from loguru import logger
from collections import defaultdict

from tt_metal.tools.profiler.common import clear_profiler_runtime_artifacts
from tt_metal.tools.profiler.process_model_log import (
get_latest_ops_log_filename,
post_process_ops_log,
run_device_profiler,
get_samples_per_s,
@@ -49,6 +53,79 @@ def run_device_perf(command, subdir, num_iterations, cols, batch_size, has_signp
return post_processed_results


# TODO: Move into process_model_log.py (#18698)
def post_process_ops_log_detailed(
output_logs_subdir, columns, sum_vals=True, op_name="", has_signposts=False, detailed=False, warmup_iters=0
):
filename = get_latest_ops_log_filename(output_logs_subdir)
df = pd.read_csv(filename)

if has_signposts:
# there are explicit start and stop points in the model we want to measure between
markers = df[df["OP TYPE"] == "signpost"]["OP CODE"]
start = markers[markers == "start"].index[0]
stop = markers[markers == "stop"].index[0]
df = df.iloc[start + 1 : stop]
if op_name != "":
df = df[df["OP CODE"] == op_name]

if warmup_iters > 0:
df = df.iloc[warmup_iters:]

results = {}
for col in columns:
df_filtered = df[df[col] != "-"]
if sum_vals:
results[col] = df_filtered[col].astype(float).sum()
else:
results[col] = df_filtered[col].astype(float).to_numpy()

if detailed:
results[f"AVG {col}"] = df_filtered[col].astype(float).mean()
results[f"MIN {col}"] = df_filtered[col].astype(float).min()
results[f"MAX {col}"] = df_filtered[col].astype(float).max()
results[f"STD {col}"] = df_filtered[col].astype(float).std()

return results


def run_device_perf_detailed(command, subdir, cols, op_name="", has_signposts=False, warmup_iters=0):
duration_cols = [col + " DURATION [ns]" for col in cols]

clear_profiler_runtime_artifacts()

results = {}
for d_col in duration_cols:
results[f"AVG {d_col}"] = 0
results[f"MIN {d_col}"] = float("inf")
results[f"MAX {d_col}"] = -float("inf")
results[f"STD {d_col}"] = 0

run_device_profiler(command, subdir)
r = post_process_ops_log_detailed(
subdir, duration_cols, op_name=op_name, has_signposts=has_signposts, detailed=True, warmup_iters=warmup_iters
)
for d_col in duration_cols:
results[f"AVG {d_col}"] = r[f"AVG {d_col}"]
results[f"MIN {d_col}"] = r[f"MIN {d_col}"]
results[f"MAX {d_col}"] = r[f"MAX {d_col}"]
results[f"STD {d_col}"] = r[f"STD {d_col}"]

post_processed_results = defaultdict(dict)
for col, d_col in zip(cols, duration_cols):
post_processed_results[col]["AVG"] = results[f"AVG {d_col}"]
post_processed_results[col]["MIN"] = results[f"MIN {d_col}"]
post_processed_results[col]["MAX"] = results[f"MAX {d_col}"]
post_processed_results[col]["STD"] = results[f"STD {d_col}"]

logger.info(
f"\nTest: {command}"
f"\nPerformance statistics for op: {op_name}"
f"\n{json.dumps(post_processed_results, indent=4)}"
)
return post_processed_results


def check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=False):
expected_results = {}
failed = False
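
To make the filtering in post_process_ops_log_detailed above concrete, here is a small self-contained sketch that reproduces the same signpost slicing, op filtering, warmup trimming, and statistics on a synthetic DataFrame (the column values are invented and the real ops-log reader is not called):

import pandas as pd

# Synthetic ops log: only rows between the "start" and "stop" signposts are measured.
df = pd.DataFrame(
    {
        "OP TYPE": ["signpost", "tt_dnn_device", "tt_dnn_device", "tt_dnn_device", "signpost"],
        "OP CODE": ["start", "AllReduceAsync", "AllReduceAsync", "AllReduceAsync", "stop"],
        "DEVICE KERNEL DURATION [ns]": ["-", 27000.0, 28000.0, 29000.0, "-"],
    }
)

markers = df[df["OP TYPE"] == "signpost"]["OP CODE"]
start = markers[markers == "start"].index[0]
stop = markers[markers == "stop"].index[0]
df = df.iloc[start + 1 : stop]              # keep only the signposted region
df = df[df["OP CODE"] == "AllReduceAsync"]  # keep only the op under test
df = df.iloc[1:]                            # drop warmup_iters leading rows (1 here)

col = "DEVICE KERNEL DURATION [ns]"
vals = df[df[col] != "-"][col].astype(float)
print(vals.mean(), vals.min(), vals.max(), vals.std())  # AVG / MIN / MAX / STD in ns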
1 change: 1 addition & 0 deletions tests/nightly/tg/ccl/test_new_all_reduce.py
@@ -15,6 +15,9 @@
create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost

NUM_BUFFERS = 8


def report_mismatches(golden, actual, max_printable=None):
@@ -64,6 +67,7 @@ def run_with_trace(
n_worker=None,
n_buffer=None,
num_iter=20,
warmup_iters=0,
use_all_gather_async=False,
profiler=BenchmarkProfiler(),
):
@@ -98,47 +102,66 @@

# Capture trace
logger.info("Capturing trace")
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(num_iter):
if use_all_gather_async:
logger.info("Running all-gather async")
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)

def capture_trace(n_iters):
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(n_iters):
if use_all_gather_async:
tt_out_tensor = ttnn.experimental.all_gather_async(
input_tensor,
dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS]
if type(ccl_semaphore_handles) == list
else ccl_semaphore_handles,
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
enable_persistent_fabric_mode=enable_persistent_fabric,
)
else:
tt_out_tensor = ttnn.all_gather(
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
num_links=num_links,
memory_config=output_mem_config,
topology=all_gather_topology,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
ttnn.synchronize_device(mesh_device)
return trace_id

if warmup_iters > 0:
trace_id_warmup = capture_trace(warmup_iters)
trace_id = capture_trace(num_iter)

# Run the op
logger.info("Starting Trace perf test...")
profiler.start("all-gather-async-trace-warmup")
if warmup_iters > 0:
ttnn.execute_trace(mesh_device, trace_id_warmup, blocking=False)
ttnn.release_trace(mesh_device, trace_id_warmup)
ttnn.synchronize_device(mesh_device)
profiler.end("all-gather-async-trace-warmup")

profiler.start("all-gather-async-trace")
signpost("start")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
ttnn.synchronize_device(mesh_device)
signpost("stop")
profiler.end("all-gather-async-trace")
logger.info(f"Time taken: {profiler.get_duration('all-gather-async-trace')} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter} s")
logger.info(f"Time per iter: {(profiler.get_duration('all-gather-async-trace')) / num_iter * 1e6} us")
time_taken = profiler.get_duration("all-gather-async-trace") - profiler.get_duration(
"all-gather-async-trace-warmup"
)
effective_iter = num_iter - warmup_iters
logger.info(f"Time taken e2e: {time_taken} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter} s")
logger.info(f"Time per iter e2e: {time_taken / effective_iter * 1e6} us")

return tt_out_tensor
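
The key change in the trace path above is that semaphore handles are no longer allocated one per iteration; instead a fixed pool of NUM_BUFFERS handles is reused round-robin. A stripped-down sketch of that indexing pattern (the list entries here are placeholder strings, not real ttnn global semaphores):

NUM_BUFFERS = 8

# Placeholder pool; the test builds it with create_global_semaphore_with_same_address().
ccl_semaphore_handles = [f"semaphore_{b}" for b in range(NUM_BUFFERS)]

num_iter = 20
for i in range(num_iter):
    handle = ccl_semaphore_handles[i % NUM_BUFFERS]  # iteration 8 reuses the handle from iteration 0
    # ... launch the CCL op with `handle` ...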

@@ -160,6 +183,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
output_shard_spec: ttnn.ShardSpec = None,
num_all_gather_instances: int = 1,
num_iters: int = 1,
warmup_iters: int = 0,
cluster_axis: int = 0,
tile=(32, 32),
trace_mode=False,
@@ -257,7 +281,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(

# create global semaphore handles
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
]
try:
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
@@ -274,11 +298,13 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
enable_persistent_fabric=enable_persistent_fabric,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
warmup_iters=warmup_iters,
use_all_gather_async=use_all_gather_async,
profiler=profiler,
)

else:
signpost("start")
for i in range(num_iters):
if use_all_gather_async:
logger.info("Running all-gather async")
@@ -288,7 +314,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
cluster_axis=cluster_axis,
mesh_device=mesh_device,
topology=ttnn.Topology.Linear,
multi_device_global_semaphore=ccl_semaphore_handles[i],
multi_device_global_semaphore=ccl_semaphore_handles[i % NUM_BUFFERS],
num_links=num_links,
memory_config=output_mem_config,
subdevice_id=worker_sub_device_id,
@@ -305,7 +331,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
topology=ttnn.Topology.Linear,
)
ttnn.synchronize_device(mesh_device, sub_device_ids=sub_device_stall_group)

signpost("stop")
except Exception as e:
logger.error(f"Exception: {e}")
raise e
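
The signpost("start") / signpost("stop") calls added above are what has_signposts=True in post_process_ops_log_detailed keys on: only ops recorded between the two markers enter the statistics. The general pattern, sketched with a hypothetical launch_op standing in for the all-gather or all-reduce call under test:

import ttnn
from tracy import signpost

def run_measured_region(mesh_device, num_iters, launch_op):
    signpost("start")                        # rows after this marker are measured
    for _ in range(num_iters):
        launch_op()                          # placeholder for the CCL op under test
    ttnn.synchronize_device(mesh_device)     # make sure all iterations are recorded
    signpost("stop")                         # rows after this marker are ignored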