Added padding for non-tile-multiple channels #18595

Closed · wants to merge 1 commit
11 changes: 9 additions & 2 deletions ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -232,12 +232,19 @@ MemoryConfig create_sharded_memory_config_from_parallel_config(

 uint32_t nhw_shape = tensor_shape[0] * tensor_shape[1] * tensor_shape[2];
 uint32_t nhw_padded = nhw_shape;
+uint32_t c_padded = channels;
 if (shard_scheme != TensorMemoryLayout::WIDTH_SHARDED) {
     nhw_padded = round_up(nhw_shape, num_cores_nhw * tile_size);
 }
+// A non-tile-multiple shard width is only possible when C = 16; otherwise we pad shards up to a multiple of 32.
+if (shard_scheme != TensorMemoryLayout::HEIGHT_SHARDED &&
+    (channels / num_cores_channels) % tt::constants::TILE_WIDTH != 0 &&
+    (channels / num_cores_channels) > tt::constants::TILE_WIDTH) {
+    c_padded = round_up(channels, num_cores_channels * tt::constants::TILE_WIDTH);
+}
 uint32_t nhw_shard = nhw_padded / num_cores_nhw;
-TT_ASSERT(channels % num_cores_channels == 0, "Channels: {}, num core channels: {}", channels, num_cores_channels);
-uint32_t channel_shard = channels / num_cores_channels;
+TT_ASSERT(c_padded % num_cores_channels == 0, "Channels: {}, num core channels: {}", c_padded, num_cores_channels);
+uint32_t channel_shard = c_padded / num_cores_channels;
 auto shard_spec = tt::tt_metal::ShardSpec{parallel_config.grid, {nhw_shard, channel_shard}, shard_orientation};
 log_debug("Calculated Shard Spec = {}", shard_spec);
 return MemoryConfig{shard_scheme, BufferType::L1, shard_spec};
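
For illustration, a minimal standalone sketch of the padding arithmetic above. `TILE_WIDTH` and `round_up` here are local stand-ins for `tt::constants::TILE_WIDTH` (32) and `tt::round_up`, and the sample values are made up:

```cpp
#include <cstdint>
#include <iostream>

// Stand-ins for tt::constants::TILE_WIDTH and tt::round_up (assumptions, not the real headers).
constexpr uint32_t TILE_WIDTH = 32;
uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

int main() {
    // Example: 96 channels split across 2 cores gives a 48-wide shard,
    // which is over one tile and not a tile multiple, so it gets padded.
    uint32_t channels = 96;
    uint32_t num_cores_channels = 2;
    uint32_t shard_width = channels / num_cores_channels;  // 48

    uint32_t c_padded = channels;
    if (shard_width % TILE_WIDTH != 0 && shard_width > TILE_WIDTH) {
        c_padded = round_up(channels, num_cores_channels * TILE_WIDTH);  // round_up(96, 64) = 128
    }
    std::cout << "channel_shard = " << c_padded / num_cores_channels << "\n";  // 64
}
```

Note that a 16-wide shard (e.g. C = 32 on 2 cores) fails the `> TILE_WIDTH` test and is deliberately left unpadded, which matches the comment about C = 16 being the one supported non-tile-multiple width.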
@@ -53,7 +53,14 @@ Pool2D::MultiCore::cached_program_t pool2d_multi_core_sharded_with_halo_v2_impl_
 uint32_t in_nbytes = datum_size(in_df);
 uint32_t out_nbytes = datum_size(out_df);

-uint32_t in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes; // row of input (channels)
+uint32_t in_nbytes_c;
+// A non-tile-multiple shard width is only possible when C = 16; otherwise we pad shards up to a multiple of 32.
+if ((input_shape[3] / num_shards_c) % tt::constants::TILE_WIDTH != 0 &&
+    (input_shape[3] / num_shards_c) > tt::constants::TILE_WIDTH) {
+    in_nbytes_c = tt::round_up(input_shape[3], num_shards_c * tt::constants::TILE_WIDTH) / num_shards_c * in_nbytes; // padded row of input (channels)
+} else {
+    in_nbytes_c = input_shape[3] / num_shards_c * in_nbytes; // row of input (channels)
+}
 uint32_t out_nbytes_c = output_shape[3] / num_shards_c * out_nbytes; // row of output (channels)

tt::DataFormat indices_df =
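
Likewise, a sketch of how the padded `in_nbytes_c` in the pool2d hunk differs from the unpadded row size, under the same stand-in assumptions (`TILE_WIDTH` = 32, hypothetical `round_up`, and a bfloat16 input so `in_nbytes` = 2):

```cpp
#include <cstdint>
#include <iostream>

// Stand-ins for tt::constants::TILE_WIDTH and tt::round_up (assumptions, not the real headers).
constexpr uint32_t TILE_WIDTH = 32;
uint32_t round_up(uint32_t value, uint32_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
}

// Bytes per row of one input channel shard, mirroring the pool2d change.
uint32_t shard_row_bytes(uint32_t channels, uint32_t num_shards_c, uint32_t in_nbytes) {
    uint32_t shard_width = channels / num_shards_c;
    if (shard_width % TILE_WIDTH != 0 && shard_width > TILE_WIDTH) {
        // Over-tile, non-multiple shard: pad up to the next tile boundary per shard.
        return round_up(channels, num_shards_c * TILE_WIDTH) / num_shards_c * in_nbytes;
    }
    // Tile-multiple shards, and the special 16-wide case, are used as-is.
    return shard_width * in_nbytes;
}

int main() {
    std::cout << shard_row_bytes(80, 2, 2) << "\n";  // 40-wide shard padded to 64: 128 bytes
    std::cout << shard_row_bytes(32, 2, 2) << "\n";  // 16-wide shard left unpadded: 32 bytes
}
```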