Skip to content

Commit

Permalink
#18198: Added padding for non-tile multiple channels
Browse files Browse the repository at this point in the history
  • Loading branch information
skrsticTT committed Mar 4, 2025
1 parent 3b6e165 commit 15ab5ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
11 changes: 9 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,19 @@ MemoryConfig create_sharded_memory_config_from_parallel_config(

uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2];
uint32_t nhw_padded = nhw_shape;
uint32_t c_padded = channels;
if (shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) {
nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
}
// A non-tile-multiple channel shard is allowed only when the per-core shard width is 16;
// otherwise pad the channel dimension so each core's shard is a multiple of TILE_WIDTH (32).
if (shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED &&
(channels / num_cores_channels) % tt::constants::TILE_WIDTH != 0 &&
(channels / num_cores_channels) > tt::constants::TILE_WIDTH) {
c_padded = round_up(channels, num_cores_channels * tt::constants::TILE_WIDTH);
}
uint32_t nhw_shard = nhw_padded / num_cores_nhw;
TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels);
uint32_t channel_shard = channels / num_cores_channels;
TT_ASSERT(c_padded % num_cores_channels == 0, "Channels: {}, num core channels: {}", c_padded, num_cores_channels);
uint32_t channel_shard = c_padded / num_cores_channels;
auto shard_spec = tt::tt_metal::ShardSpec{parallel_config.grid, {nhw_shard, channel_shard}, shard_orientation};
log_debug("Calculated Shard Spec = {}", shard_spec);
return MemoryConfig{shard_scheme, BufferType::L1, shard_spec};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
uint32_t in_nbytes = datum_size(in_df);
uint32_t out_nbytes = datum_size(out_df);

uint32_t in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes; // row of input (channels)
uint32_t in_nbytes_c;
// A non-tile-multiple channel shard is allowed only when the per-core shard width is 16;
// otherwise pad the channel dimension so each core's shard is a multiple of TILE_WIDTH (32).
if ((input_shape[3] / num_shards_c) % tt::constants::TILE_WIDTH != 0 &&
(input_shape[3] / num_shards_c) > tt::constants::TILE_WIDTH) {
in_nbytes_c = tt::round_up(input_shape[3], num_shards_c * tt::constants::TILE_WIDTH) / num_shards_c * in_nbytes;
} else {
in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes;
}
uint32_t out_nbytes_c = output_shape[3] / num_shards_c * out_nbytes; // row of output (channels)

tt::DataFormat indices_df =
Expand Down

0 comments on commit 15ab5ef

Please sign in to comment.