diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py
index eff32fdee1c..447f367e095 100644
--- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py
+++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_model_utils.py
@@ -63,7 +63,6 @@ def get_conv_input_memory_config(
         compute_grid_size=compute_grid,
         block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
         enable_channels_padding=True,
-        is_out_tiled=True,
     )
 
     if override_num_cores:
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
index 0072f0ee88c..6c94a358e7b 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py
@@ -97,7 +97,6 @@ def __init__(
                 compute_grid_size=self.device.compute_with_storage_grid_size(),
                 block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
                 enable_channels_padding=False,
-                is_out_tiled=True,
             ),
             tile_size=32,
         )
diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
index 58f3ab618b0..ff0c4383fb9 100644
--- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
+++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py
@@ -129,7 +129,6 @@ def __init__(
                 compute_grid_size=self.device.compute_with_storage_grid_size(),
                 block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
                 enable_channels_padding=False,
-                is_out_tiled=True,
             ),
             tile_size=32,
         )
@@ -207,7 +206,6 @@ def __init__(
                 compute_grid_size=self.device.compute_with_storage_grid_size(),
                 block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
                 enable_channels_padding=False,
-                is_out_tiled=True,
             ),
             tile_size=32,
         )
diff --git a/tests/ttnn/nightly/unit_tests/operations/max_pool2d/test_maxpool2d.py b/tests/ttnn/nightly/unit_tests/operations/max_pool2d/test_maxpool2d.py
index f9f9f01a7b9..001fdda9338 100644
--- a/tests/ttnn/nightly/unit_tests/operations/max_pool2d/test_maxpool2d.py
+++ b/tests/ttnn/nightly/unit_tests/operations/max_pool2d/test_maxpool2d.py
@@ -191,7 +191,8 @@ def run_max_pool(
         compute_grid_size=device.compute_with_storage_grid_size(),
         block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
         enable_channels_padding=False,
-        is_out_tiled=False,
+        is_shard_height_tile_multiple=False,
+        is_shard_width_tile_multiple=False,
     )
     sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
         tensor_shape=ttact_device.shape,
@@ -834,7 +835,8 @@ def test_pool_core_nondivis(
         compute_grid_size=device.compute_with_storage_grid_size(),
         block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
         enable_channels_padding=False,
-        is_out_tiled=True,
+        is_shard_height_tile_multiple=True,
+        is_shard_width_tile_multiple=True,
     )
     sharded_memory_config = ttnn._ttnn.operations.conv.create_sharded_memory_config_from_parallel_config(
         tensor_shape=ttact_device.shape,
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
index d12ee00da8f..587e6002107 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp
@@ -284,7 +284,8 @@ void py_bind_conv2d(py::module& module) {
            const CoreCoord& compute_grid_size,
            tt::tt_metal::ShardOrientation block_shard_orientation,
            bool enable_channels_padding,
-           bool is_out_tiled) -> ttnn::operations::sliding_window::ParallelConfig {
+           bool is_shard_height_tile_multiple,
+           bool is_shard_width_tile_multiple) -> ttnn::operations::sliding_window::ParallelConfig {
             return determine_parallel_config(
                 shard_layout,
                 batch_size,
@@ -295,7 +296,8 @@ void py_bind_conv2d(py::module& module) {
                 compute_grid_size,
                 block_shard_orientation,
                 enable_channels_padding,
-                is_out_tiled);
+                is_shard_height_tile_multiple,
+                is_shard_width_tile_multiple);
         },
         py::arg("shard_layout"),
         py::arg("batch_size"),
@@ -306,7 +308,8 @@ void py_bind_conv2d(py::module& module) {
         py::arg("compute_grid_size"),
         py::arg("block_shard_orientation"),
         py::arg("enable_channels_padding"),
-        py::arg("is_out_tiled") = true);
+        py::arg("is_shard_height_tile_multiple") = true,
+        py::arg("is_shard_width_tile_multiple") = true);
 
     module.def(
         "create_sharded_memory_config_from_parallel_config",
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
index 0c90e608f59..c190ebd3ca9 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.cpp
@@ -81,6 +81,16 @@ uint32_t find_closest_largest_divisor_with_num_padding(uint32_t num1, uint32_t n
     return divisor;
 }
 
+// If shard width is tile width, and it is allowed to have half tile shard width, and we have enough cores to do it,
+// double number of cores
+static void set_shard_width_to_half_tile_if_possible(
+    uint32_t& num_cores, uint32_t channels_ntiles, uint32_t max_num_cores, bool width_shard_half_tile_possible) {
+    if (width_shard_half_tile_possible && (div_up(channels_ntiles, num_cores) == 1) &&
+        (2 * num_cores <= max_num_cores)) {
+        num_cores *= 2;
+    }
+}
+
 ParallelConfig determine_parallel_config(
     const TensorMemoryLayout shard_layout,
     uint32_t batch_size,
@@ -91,12 +101,15 @@ ParallelConfig determine_parallel_config(
     const CoreCoord& compute_grid_size,
     ShardOrientation block_shard_orientation,
     bool enable_channels_padding,
-    bool is_out_tiled,
+    bool is_shard_height_tile_multiple,
+    bool is_shard_width_tile_multiple,
     uint32_t act_block_h_override) {
-    uint32_t effective_tile_height = is_out_tiled ? tt::constants::TILE_HEIGHT : 1;
-    uint32_t effective_tile_width = is_out_tiled ? tt::constants::TILE_WIDTH : 1;
-    uint32_t out_nhw_ntiles =
-        tt::round_up(batch_size * output_height * output_width, tt::constants::TILE_HEIGHT) / effective_tile_height;
+    // Currently, convolution requires multiples of the tile size for both shard height and width,
+    // while pooling can accept any height and either a tile multiple or half a tile for width.
+    // This approach needs to be modified when other shard dimensions are supported.
+    uint32_t effective_tile_height = is_shard_height_tile_multiple ? tt::constants::TILE_HEIGHT : 1;
+    uint32_t effective_tile_width = tt::constants::TILE_WIDTH;
+    uint32_t out_nhw_ntiles = tt::div_up(batch_size * output_height * output_width, effective_tile_height);
     uint32_t input_channles_ntiles = tt::div_up(input_channels, effective_tile_width);
     uint32_t out_channels_ntiles = tt::div_up(output_channels, effective_tile_width);
     // In case non native activation block height is used, we need to ensure that the amount
@@ -123,6 +136,8 @@ ParallelConfig determine_parallel_config(
                                    ? find_closest_largest_divisor_with_num_padding(
                                          out_channels_ntiles, input_channles_ntiles, start_divisor_c)
                                    : find_closest_largest_divisor(out_channels_ntiles, input_channles_ntiles, start_divisor_c);
+        set_shard_width_to_half_tile_if_possible(
+            num_cores_c, input_channles_ntiles, start_divisor_c, !is_shard_width_tile_multiple);
         uint32_t cores_x = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_nhw : num_cores_c;
         uint32_t cores_y = block_shard_orientation == ShardOrientation::COL_MAJOR ? num_cores_c : num_cores_nhw;
         CoreRange core_range = CoreRange(CoreCoord({0, 0}), CoreCoord({cores_x - 1, cores_y - 1}));
@@ -131,6 +146,8 @@ ParallelConfig determine_parallel_config(
         uint32_t num_cores_c = enable_channels_padding
             ? find_closest_largest_divisor_with_num_padding(input_channles_ntiles, max_num_cores)
             : find_closest_largest_divisor(input_channles_ntiles, max_num_cores);
+        set_shard_width_to_half_tile_if_possible(
+            num_cores_c, input_channles_ntiles, max_num_cores, !is_shard_width_tile_multiple);
         grid = tt::tt_metal::num_cores_to_corerangeset(num_cores_c, compute_grid_size, true);
     } else {
         TT_THROW("Conv2d supports Height, Block or Width Sharded Layouts but got {}", shard_layout);
@@ -475,6 +492,7 @@ static std::tuple get_conv_padded_input_s
             block_shard_orientation,
             !is_mm_conv,
             true,
+            true,
             conv_config.act_block_h_override);
 
         if (conv_config.override_sharding_config) {
@@ -703,7 +721,6 @@ Conv2dConfig determine_conv_config_for_auto_shard(
 
     ShardOrientation shard_orientation =
         conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR;
-    const bool is_out_tiled = conv_config.output_layout == Layout::TILE;
 
     struct core_count_and_size {
         uint32_t core_count;
@@ -753,7 +770,8 @@ Conv2dConfig determine_conv_config_for_auto_shard(
         compute_grid_size,
         shard_orientation,
         !is_mm_conv,
-        is_out_tiled,
+        true,
+        true,
         conv_config.act_block_h_override);
 
     const ParallelConfig output_parallel_config = determine_output_parallel_config(
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
index c28026849fc..498aa2b4e40 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_utils.hpp
@@ -47,7 +47,8 @@ sliding_window::ParallelConfig determine_parallel_config(
     const CoreCoord& compute_grid_size,
     tt::tt_metal::ShardOrientation block_shard_orientation,
    bool enable_channels_padding,
-    bool is_out_tiled = true,
+    bool is_shard_height_tile_multiple = true,
+    bool is_shard_width_tile_multiple = true,
     uint32_t act_block_h_override = 0);
 
 sliding_window::ParallelConfig determine_output_parallel_config(
diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
index 2a6ce8a9281..a15b794fc61 100644
--- a/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
+++ b/ttnn/cpp/ttnn/operations/conv/conv2d/prepare_conv2d_weights.cpp
@@ -599,6 +599,7 @@ static OptimizedConvBlockConfig get_opt_block_config(
             shard_orientation,
             !mm_conv,
             true,
+            true,
             conv_config.act_block_h_override);
     }
     auto output_parallel_config = parallel_config;
@@ -839,6 +840,7 @@ ttnn::Tensor prepare_conv_weights(
             shard_orientation,
             !mm_conv,
             true,
+            true,
             conv_config.act_block_h_override);
     }
 
@@ -941,6 +943,7 @@ ttnn::Tensor prepare_conv_bias(
             shard_orientation,
             !mm_conv,
             true,
+            true,
             conv_config.act_block_h_override);
     }
 
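Usage note (sketch only, not part of the patch): the snippet below illustrates how a pooling-style caller would pass the two new flags that replace is_out_tiled, mirroring the updated call in test_maxpool2d.py. It assumes the binding is exposed as ttnn._ttnn.operations.conv.determine_parallel_config with keyword names matching the C++ signature; the module.def name and the middle py::arg names are not visible in the hunks above, so treat those exact names as illustrative.

# Illustrative sketch: keyword names assumed from the C++ signature of determine_parallel_config.
import ttnn

def pool_parallel_config(device, batch_size, channels, out_h, out_w):
    # Pooling does not need tile-aligned shards: shard height may be any value and
    # shard width may be half a tile, so both new flags are set to False here.
    # A convolution-style caller would leave both at their default of True.
    return ttnn._ttnn.operations.conv.determine_parallel_config(
        shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
        batch_size=batch_size,
        input_channels=channels,
        output_height=out_h,
        output_width=out_w,
        output_channels=channels,
        compute_grid_size=device.compute_with_storage_grid_size(),
        block_shard_orientation=ttnn.ShardOrientation.ROW_MAJOR,
        enable_channels_padding=False,
        is_shard_height_tile_multiple=False,
        is_shard_width_tile_multiple=False,
    )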