Skip to content

Commit

Permalink
#18198: If shard width is tt::constants::TILE_WIDTH, try to do tt::constants::TILE_WIDTH / 2 if possible

Browse files Browse the repository at this point in the history
  • Loading branch information
skrsticTT committed Mar 5, 2025
1 parent 1cee76b commit 48d1b02
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
9 changes: 8 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ ParallelConfig determine_parallel_config(
ShardOrientation block_shard_orientation,
bool enable_channels_padding,
bool is_out_tiled,
uint32_t act_block_h_override) {
uint32_t act_block_h_override,
bool width_shard_half_tile) {
uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1;
uint32_t effective_tile_width = tt::constants::TILE_WIDTH;
uint32_t out_nhw_ntiles =
Expand Down Expand Up @@ -123,6 +124,9 @@ ParallelConfig determine_parallel_config(
? find_closest_largest_divisor_with_num_padding(
out_channels_ntiles, input_channles_ntiles, start_divisor_c)
: find_closest_largest_divisor(out_channels_ntiles, input_channles_ntiles, start_divisor_c);
if (width_shard_half_tile && input_channles_ntiles / num_cores_c == 1 && 2 * num_cores_c <= start_divisor_c) {
num_cores_c *= 2;
}
uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c;
uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw;
CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1}));
Expand All @@ -131,6 +135,9 @@ ParallelConfig determine_parallel_config(
uint32_t num_cores_c = enable_channels_padding
? find_closest_largest_divisor_with_num_padding(input_channles_ntiles, max_num_cores)
: find_closest_largest_divisor(input_channles_ntiles, max_num_cores);
if (width_shard_half_tile && input_channles_ntiles / num_cores_c == 1 && 2 * num_cores_c <= max_num_cores) {
num_cores_c *= 2;
}
grid = tt::tt_metal::num_cores_to_corerangeset(num_cores_c, compute_grid_size, true);
} else {
TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout);
Expand Down
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ sliding_window::ParallelConfig determine_parallel_config(
tt::tt_metal::ShardOrientation block_shard_orientation,
bool enable_channels_padding,
bool is_out_tiled = true,
uint32_t act_block_h_override = 0);
uint32_t act_block_h_override = 0,
bool width_shard_half_tile = false);

sliding_window::ParallelConfig determine_output_parallel_config(
const sliding_window::ParallelConfig& input_parallel_config,
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/pool/generic/generic_pools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ Tensor Pool2DOp<pool_type>::invoke(
input_tensor.device()->compute_with_storage_grid_size(),
ShardOrientation::ROW_MAJOR,
false,
false);
false,
0,
true);
num_cores_nhw = conv::get_num_cores_nhw_from_parallel_config(parallel_config);
num_cores_c = conv::get_num_cores_channels_from_parallel_config(parallel_config);
auto sharded_mem_config = conv::create_sharded_memory_config_from_parallel_config(input_tensor_sharded.get_padded_shape(), parallel_config, is_in_tiled ? tt::constants::TILE_HEIGHT : 1);
Expand Down

0 comments on commit 48d1b02

Please sign in to comment.