Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#0: Temporary workaround for TG ResNet trace+2cq hang #18750

Merged
merged 1 commit into from
Mar 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_unet_trace_perf(
"batch, groups, iterations, expected_compile_time, expected_throughput, use_async_mode",
(
(1, 2, 128, 25.0, 1450.0, True),
(1, 2, 128, 25.0, 1660.0, False),
(1, 2, 128, 25.0, 1650.0, False),
),
)
def test_unet_trace_perf_multi_device(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ void kernel_main() {
uint32_t out_tensor_start_tile_id = get_arg_val<uint32_t>(rt_args_idx++);

// padding args (WRITER)
const uint32_t last_num_blocks_h_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_subblock_h = get_arg_val<uint32_t>(rt_args_idx++);
Expand All @@ -33,6 +31,11 @@ void kernel_main() {
const uint32_t padded_subblock_tiles_addr_skip = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t padded_block_tiles_w_skip = get_arg_val<uint32_t>(rt_args_idx++);

#ifndef OUT_SHARDED
const uint32_t last_num_blocks_h_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
#endif

// COMPILE TIME ARGS
// interleaved accessor args
constexpr bool out_is_dram = get_compile_time_arg_val(0) == 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ void kernel_main() {
// padding args (READER)
const uint32_t last_block_w = get_arg_val<uint32_t>(rt_args_idx++);
// padding args (WRITER)
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_num_nonzero_subblocks_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t out_last_subblock_h = get_arg_val<uint32_t>(rt_args_idx++);
const uint32_t padded_block_tiles_h_skip = get_arg_val<uint32_t>(rt_args_idx++);
Expand Down Expand Up @@ -108,6 +107,9 @@ void kernel_main() {
#else
rt_args_idx += 2; // Skip over placeholders
#endif
#ifndef OUT_SHARDED
const uint32_t last_num_blocks_w_dim = get_arg_val<uint32_t>(rt_args_idx++);
#endif

constexpr bool fuse_op = (bool)get_compile_time_arg_val(31);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(last_out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -852,7 +851,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -871,6 +869,13 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
mm_in1_sender_writer_args.push_back(0);
mm_in1_sender_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (output_idx_x == num_blocks_x - 1) {
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}
}

if (fuse_op) {
fused_op_signaler->push_matmul_fused_op_rt_args(mm_in1_sender_writer_args, true);
Expand Down Expand Up @@ -943,7 +948,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0(
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[18] = (*bias_buffer)->address();
writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down Expand Up @@ -1531,7 +1536,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
// padding args (READER)
(std::uint32_t)out_block_w, // last_block_w
// padding args (WRITER)
(std::uint32_t)out_num_blocks_x,
(std::uint32_t)out_block_h / out_subblock_h,
(std::uint32_t)out_subblock_h,
(std::uint32_t)0,
Expand All @@ -1545,6 +1549,12 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_sender_writer_args.push_back((std::uint32_t)bias_buffer->address());
mm_in1_sender_writer_args.push_back(
(std::uint32_t)per_core_N * output_idx_x); // in3_tensor_start_tile_id
} else {
mm_in1_sender_writer_args.push_back(0);
mm_in1_sender_writer_args.push_back(0);
}
if (!output_is_sharded) {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}

tt_metal::SetRuntimeArgs(
Expand All @@ -1566,8 +1576,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(

if (output_idx_y == num_blocks_y - 1) {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1579,8 +1587,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_receiver_writer_args.push_back(0);
} else {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1591,6 +1597,15 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
mm_in1_receiver_writer_args.push_back(0);
mm_in1_receiver_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (output_idx_y == num_blocks_y - 1) {
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
} else {
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
}
}

tt_metal::SetRuntimeArgs(
program,
Expand Down Expand Up @@ -1659,7 +1674,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in1(
sender_writer_runtime_args[0] = src_buffer_b->address();
sender_writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
sender_writer_runtime_args[18] = (*bias_buffer)->address();
sender_writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(last_out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -1043,7 +1042,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(out_block_w);

// padding args (WRITER)
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
mm_in1_sender_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_sender_writer_args.push_back(out_subblock_h);
mm_in1_sender_writer_args.push_back(0);
Expand All @@ -1062,6 +1060,13 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_sender_writer_args.push_back(0); // Placeholder; not used
mm_in1_sender_writer_args.push_back(0); // Placeholder; not used
}
if (!output_is_sharded) {
if (in1_idx == in1_end_idx) { // right cores when no transpose_mcast
mm_in1_sender_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_sender_writer_args.push_back(out_num_blocks_x);
}
}

if (in1_is_sharded and in1_is_dram) { // in1 is dram sharded
uint32_t num_iter_index = mm_in1_sender_writer_args.size() + 1;
Expand Down Expand Up @@ -1147,8 +1152,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(

if (in1_idx == in1_end_idx and in0_idx == in0_end_idx) { // bottom-right core when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1160,8 +1163,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_w_skip);
} else if (in0_idx == in0_end_idx) { // bottom cores except bottom-right when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(last_block_num_nonzero_subblocks_h);
mm_in1_receiver_writer_args.push_back(last_subblock_of_last_block_h);
Expand All @@ -1173,8 +1174,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(0);
} else if (in1_idx == in1_end_idx) { // right cores except bottom when no transpose_mcast
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1186,8 +1185,6 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(last_block_padded_block_tiles_w_skip);
} else {
// padding args (WRITER)
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_block_h / out_subblock_h);
mm_in1_receiver_writer_args.push_back(out_subblock_h);
Expand All @@ -1198,6 +1195,22 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in1_receiver_writer_args.push_back(0);
mm_in1_receiver_writer_args.push_back(0);
}
if (!output_is_sharded) {
if (in1_idx == in1_end_idx and
in0_idx == in0_end_idx) { // bottom-right core when no transpose_mcast
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
} else if (in0_idx == in0_end_idx) { // bottom cores except bottom-right when no transpose_mcast
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_h);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
} else if (in1_idx == in1_end_idx) { // right cores except bottom when no transpose_mcast
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(last_out_num_blocks_w);
} else {
mm_in1_receiver_writer_args.push_back(out_num_blocks_y);
mm_in1_receiver_writer_args.push_back(out_num_blocks_x);
}
}

// left half
if (core.x <= half_core || (transpose_mcast and core.y == start_core_y)) {
Expand Down Expand Up @@ -1270,7 +1283,7 @@ tt::tt_metal::operation::ProgramWithCallbacks create_program_mcast_in0_in1(
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[18] = (*bias_buffer)->address();
writer_runtime_args[17] = (*bias_buffer)->address();
}
}

Expand Down
Loading