Skip to content

Commit

Permalink
#18332: Fix BN hang for FPU Kernel
Browse files: browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Mar 6, 2025
1 parent e4dc25d commit 4e2a1b4
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 48 deletions.
8 changes: 5 additions & 3 deletions tests/ttnn/unit_tests/operations/test_batch_norm.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -243,9 +243,11 @@ def test_batch_norm_fp32(
@pytest.mark.parametrize(
"input_shapes",
[
*(torch.Size([n, c, 32, 32]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 23, 23]) for n, c in product([3, 4], [3, 4])),
*(torch.Size([n, c, 64, 120]) for n, c in product([2, 3], [3, 4])),
*(torch.Size([n, c, 32, 32]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 23, 23]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 64, 120]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
*(torch.Size([n, c, 1024, 1024]) for n, c in product([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6, 7, 8])),
torch.Size([3, 6, 4096, 4096]),
],
)
@pytest.mark.parametrize(
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,28 +31,9 @@ ALWI void batchnorm_bcast_tiles(
auto cb_affine_or_out = (weight_has_value || bias_has_value) ? cb_tmp_1 : cb_output_0;
auto cb_scaled_output = (bias_has_value) ? cb_tmp_1 : cb_output_0;

cb_wait_front(cb_bcast, onetile);

for (uint32_t j = tile_start; j < freq; ++j) {
cb_wait_front(cb_other, onetile);
cb_reserve_back(cb_num, onetile);

tile_regs_acquire();
sub_tiles(cb_other, cb_bcast, 0, 0, 0);
tile_regs_commit();

tile_regs_wait();
pack_tile(0, cb_num);
tile_regs_release();

cb_push_back(cb_num, onetile);
cb_pop_front(cb_other, onetile);
}
cb_pop_front(cb_bcast, onetile);

// 1/(sqrt(batch_var + eps))
cb_reserve_back(cb_den, onetile);
cb_wait_front(cb_batch_var, 1);
cb_wait_front(cb_batch_var, onetile);

tile_regs_acquire();
add_tiles_init_with_dt(cb_batch_var, cb_eps);
Expand All @@ -65,32 +46,40 @@ ALWI void batchnorm_bcast_tiles(
pack_tile_with_dt(dst0, cb_den);
tile_regs_release();

cb_pop_front(cb_batch_var, 1);
cb_pop_front(cb_batch_var, onetile);
cb_push_back(cb_den, onetile);

// (input - batch_mean)/(sqrt(batch_var + eps)) = result
cb_wait_front(cb_den, 1);
cb_wait_front(cb_bcast, onetile);
cb_wait_front(cb_den, onetile);
if (weight_has_value) {
cb_wait_front(cb_weight, onetile);
}
if (bias_has_value) {
cb_wait_front(cb_bias, onetile);
}
for (uint32_t j = tile_start; j < freq; ++j) {
cb_wait_front(cb_num, 1);
// input - batch_mean
cb_wait_front(cb_other, onetile);
cb_reserve_back(cb_affine_or_out, onetile);

tile_regs_acquire();
mul_tiles_init_with_dt(cb_num, cb_den);
mul_tiles(cb_num, cb_den, 0, 0, dst0);
sub_tiles_init(cb_other, cb_bcast);
sub_tiles(cb_other, cb_bcast, 0, 0, 0);

// (input - batch_mean)/(sqrt(batch_var + eps)) = result
binary_dest_reuse_tiles_init<EltwiseBinaryType::ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCA>(cb_den);
binary_dest_reuse_tiles<EltwiseBinaryType::ELWMUL, EltwiseBinaryReuseDestType::DEST_TO_SRCA>(cb_den, 0, 0);
tile_regs_commit();

tile_regs_wait();
pack_tile_with_dt(dst0, cb_affine_or_out);
pack_tile_with_dt(0, cb_affine_or_out);
tile_regs_release();

cb_pop_front(cb_num, 1);
cb_push_back(cb_affine_or_out, onetile);
}
cb_pop_front(cb_den, 1);
cb_pop_front(cb_other, onetile);

if (weight_has_value) { // result = result * weight
cb_wait_front(cb_weight, 1);
for (uint32_t j = tile_start; j < freq; ++j) {
// result = result * weight
if (weight_has_value) {
cb_reserve_back(cb_scaled_output, onetile);
cb_wait_front(cb_affine_or_out, 1);

Expand All @@ -104,16 +93,13 @@ ALWI void batchnorm_bcast_tiles(
tile_regs_release();

cb_pop_front(cb_affine_or_out, 1);

cb_push_back(cb_scaled_output, onetile);
}
cb_pop_front(cb_weight, 1);
}
if (bias_has_value) { // result = result + bias
cb_wait_front(cb_bias, 1);
for (uint32_t j = tile_start; j < freq; ++j) {
cb_reserve_back(cb_output_0, 1);
cb_wait_front(cb_tmp_1, 1);

// result = result + bias
if (bias_has_value) {
cb_reserve_back(cb_output_0, onetile);
cb_wait_front(cb_tmp_1, onetile);

tile_regs_acquire();
add_tiles_init_with_dt(cb_tmp_1, cb_bias);
Expand All @@ -124,10 +110,17 @@ ALWI void batchnorm_bcast_tiles(
pack_tile_with_dt(dst0, cb_output_0);
tile_regs_release();

cb_pop_front(cb_tmp_1, 1);
cb_push_back(cb_output_0, 1);
cb_pop_front(cb_tmp_1, onetile);
cb_push_back(cb_output_0, onetile);
}
cb_pop_front(cb_bias, 1);
}
cb_pop_front(cb_bcast, onetile);
cb_pop_front(cb_den, onetile);
if (weight_has_value) {
cb_pop_front(cb_weight, onetile);
}
if (bias_has_value) {
cb_pop_front(cb_bias, onetile);
}
}

Expand Down Expand Up @@ -159,7 +152,6 @@ void MAIN {

binary_op_init_common(cb_other, cb_bcast, cb_output_0);

sub_tiles_init(cb_other, cb_bcast);
uint32_t complete_iterations = (num_tiles + tile_start) / tile_freq;
uint32_t remaining_iterations = (num_tiles + tile_start) % tile_freq;

Expand Down

0 comments on commit 4e2a1b4

Please sign in to comment.