From 8f65743d13dc3b01efa9d9eedfc0a9dc35b5672f Mon Sep 17 00:00:00 2001
From: asaigal
Date: Sun, 9 Mar 2025 03:17:32 +0000
Subject: [PATCH] Remove MultiDeviceGlobalSemaphore

- TT-Metal already supports global semaphores backed by a MeshBuffer
- Expose this to TTNN and remove the MultiDeviceGlobalSemaphore concept
- create_global_semaphore(mesh_device, ...) now returns a GlobalSemaphore
- Remove create_global_semaphore_with_same_address, since global semaphore
  buffer allocation now goes through the MeshAllocator and yields the same
  address on every device by construction
---
 .../gtests/ccl/test_fabric_edm_common.hpp     |  30 ++---
 ...erisc_data_mover_loopback_with_workers.cpp |  30 ++---
 .../ccl/test_all_gather_TG_post_commit.py     |   3 +-
 .../ccl/test_all_gather_async_TG_nightly.py   |   1 -
 .../operations/ccl/test_all_reduce_async.py   |  13 +-
 .../operations/ccl/test_ccl_common.py         |   8 --
 .../operations/ccl/test_new_all_gather.py     |   5 +-
 .../operations/ccl/test_new_all_reduce.py     |   3 +-
 .../ccl/test_reduce_scatter_TG_nightly.py     |   5 +-
 .../ccl/test_reduce_scatter_async.py          |   5 +-
 .../test_reduce_scatter_async_TG_nightly.py   |   1 -
 .../ttnn/unit_tests/test_global_semaphore.py  |  31 -----
 ttnn/cpp/pybind11/global_semaphore.cpp        |  56 +--------
 ttnn/cpp/ttnn/global_semaphore.cpp            | 119 +-----------------
 ttnn/cpp/ttnn/global_semaphore.hpp            |  30 +----
 .../ccl/all_gather_async/all_gather_async.cpp |   4 +-
 .../ccl/all_gather_async/all_gather_async.hpp |   4 +-
 .../all_gather_async_pybind.cpp               |   4 +-
 .../device/all_gather_async_op.cpp            |  21 ++--
 .../device/all_gather_async_op.hpp            |   4 +-
 .../ccl/all_reduce_async/all_reduce_async.cpp |  14 +--
 .../ccl/all_reduce_async/all_reduce_async.hpp |  14 +--
 .../all_reduce_async_pybind.cpp               |  14 +--
 .../device/all_reduce_async_op.cpp            |  13 +-
 .../device/all_reduce_async_op.hpp            |   2 +-
 .../device/reduce_scatter_async_op.cpp        |  43 +++----
 .../device/reduce_scatter_async_op.hpp        |   8 +-
 .../reduce_scatter_async/reduce_scatter.cpp   |   8 +-
 .../reduce_scatter_async/reduce_scatter.hpp   |   8 +-
 .../reduce_scatter_pybind.cpp                 |   8 +-
 ttnn/ttnn/__init__.py                         |   1 -
 31 files changed, 127 insertions(+), 383 deletions(-)

diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_edm_common.hpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_edm_common.hpp
index b61e6cf2972..59185b2c6d2 100644
--- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_edm_common.hpp
+++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_edm_common.hpp
@@ -2029,14 +2029,12 @@ void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_li
         num_links);
 
     log_info(tt::LogTest, "Lauching op");
-    ttnn::global_semaphore::MultiDeviceGlobalSemaphore multi_device_global_semaphore =
-        ttnn::global_semaphore::create_global_semaphore_with_same_address(
-            test_fixture.mesh_device_.get(),
-            devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
-            0,                             // initial value
-            tt::tt_metal::BufferType::L1,  // buffer type
-            10                             // attempts
-        );
+    GlobalSemaphore multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore(
+        test_fixture.mesh_device_.get(),
+        test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
+        0,                            // initial value
+        tt::tt_metal::BufferType::L1  // buffer type
+    );
 
     auto output_tensor = ttnn::operations::experimental::ccl::all_gather_async(
         input_mesh_tensor,
@@ -2161,18 +2159,16 @@ void RunWriteThroughputStabilityTestWithPersistentFabric(
 
     std::vector global_semaphore_addrs;
     global_semaphore_addrs.reserve(line_size + 1);
-    std::vector<ttnn::global_semaphore::MultiDeviceGlobalSemaphore> global_semaphore_handles;
+    std::vector<GlobalSemaphore> global_semaphore_handles;
     for (size_t i = 0;
i < line_size * 4; i++) { - auto global_semaphores = ttnn::global_semaphore::create_global_semaphore_with_same_address( + auto global_semaphore = ttnn::global_semaphore::create_global_semaphore( test_fixture.mesh_device_.get(), - devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), - 0, // initial value - tt::tt_metal::BufferType::L1, // buffer type - 1000 // attempts + test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), + 0, // initial value + tt::tt_metal::BufferType::L1 // buffer type ); - global_semaphore_handles.push_back(global_semaphores); - auto global_semaphore_addr = - ttnn::global_semaphore::get_global_semaphore_address(global_semaphores.global_semaphores.at(0)); + global_semaphore_handles.push_back(global_semaphore); + auto global_semaphore_addr = ttnn::global_semaphore::get_global_semaphore_address(global_semaphore); global_semaphore_addrs.push_back(global_semaphore_addr); } diff --git a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp index 08154d4a04c..a0278884db1 100644 --- a/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp +++ b/tests/ttnn/unit_tests/gtests/ccl/test_fabric_erisc_data_mover_loopback_with_workers.cpp @@ -857,23 +857,19 @@ TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) { enable_persistent_fabric, num_links); - ttnn::global_semaphore::MultiDeviceGlobalSemaphore from_remote_multi_device_global_semaphore = - ttnn::global_semaphore::create_global_semaphore_with_same_address( - test_fixture.mesh_device_.get(), - devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), - 0, // initial value - tt::tt_metal::BufferType::L1, // buffer type - 10 // attempts - ); - - ttnn::global_semaphore::MultiDeviceGlobalSemaphore to_remote_multi_device_global_semaphore = - ttnn::global_semaphore::create_global_semaphore_with_same_address( - test_fixture.mesh_device_.get(), - devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), - 0, // initial value - tt::tt_metal::BufferType::L1, // buffer type - 10 // attempts - ); + GlobalSemaphore from_remote_multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore( + test_fixture.mesh_device_.get(), + test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), + 0, // initial value + tt::tt_metal::BufferType::L1 // buffer type + ); + + GlobalSemaphore to_remote_multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore( + test_fixture.mesh_device_.get(), + test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}), + 0, // initial value + tt::tt_metal::BufferType::L1 // buffer type + ); auto output_tensor = ttnn::operations::experimental::ccl::reduce_scatter( input_mesh_tensor, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py index efac85c97c2..882b5d28d19 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_TG_post_commit.py @@ -12,7 +12,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) from 
models.perf.benchmarking_utils import BenchmarkProfiler
 from tracy import signpost
 
@@ -281,7 +280,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(
 
     # create global semaphore handles
     ccl_semaphore_handles = [
-        create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
+        ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
     ]
     try:
         # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
index b572de93aab..4c3bdbacb2e 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_async_TG_nightly.py
@@ -11,7 +11,6 @@
 from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
     create_and_load_sub_device_manager_with_fabric_interface,
     teardown_fabric_interface,
-    create_global_semaphore_with_same_address,
 )
 
 from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py
index 542a7765a99..4c152e6109f 100644
--- a/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py
+++ b/tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py
@@ -12,7 +12,6 @@
 from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
     create_and_load_sub_device_manager_with_fabric_interface,
     teardown_fabric_interface,
-    create_global_semaphore_with_same_address,
 )
 
 
@@ -63,9 +62,9 @@ def run_all_reduce_test(
     sub_device_stall_group = [worker_sub_device_id]
     mesh_device.set_sub_device_stall_group(sub_device_stall_group)
     # create global semaphore handles
-    from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
-    to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
-    gather_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
+    from_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
+    to_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
+    gather_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
 
     debug = False
 
@@ -341,9 +340,9 @@ def run_all_reduce_with_mesh_tensor_along_row(
     sub_device_stall_group = [worker_sub_device_id]
     mesh_device.set_sub_device_stall_group(sub_device_stall_group)
     # create global semaphore handles
-    from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
-    to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
-    gather_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
+    from_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
+    to_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
+    gather_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
 
     try:
         debug = False
diff --git a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py
index 1029c113749..c01e0a4d517 100644
---
a/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py @@ -34,11 +34,3 @@ def teardown_fabric_interface(mesh_device): logger.debug(f"Tearing down fabric (this may take a while if context switch interval is large)") ttnn.teardown_edm_fabric(mesh_device) ttnn.synchronize_device(mesh_device) - - -def create_global_semaphore_with_same_address(mesh_device, cores, initial_value): - semaphore_handles = ttnn.create_global_semaphore_with_same_address(mesh_device, cores, initial_value) - addrs = ttnn.get_global_semaphore_address(semaphore_handles) - # assert all addresses are the same - assert len(set(addrs)) == 1 - return semaphore_handles diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py index 113701493e2..8ebfcf3d490 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py @@ -11,7 +11,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import ( @@ -177,9 +176,7 @@ def run_all_gather_impl( mesh_device.set_sub_device_stall_group(sub_device_stall_group) # create global semaphore handles - ccl_semaphore_handles = [ - create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) - ] + ccl_semaphore_handles = [ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)] logger.info(f"Output shape: {output_shape}") logger.info(f"dim: {dim}") diff --git a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py index 5cc3df68615..ab1b6312d94 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py @@ -13,7 +13,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) from tests.tt_eager.python_api_testing.unit_testing.misc.test_matmul_1d_gather_in0 import ( @@ -88,7 +87,7 @@ def run_all_reduce_impl( # create global semaphore handles num_buffers = 8 ccl_semaphore_handles = [ - create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_buffers) + ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_buffers) ] logger.info(f"Output shape: {output_shape}") diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py index 14298a6f5fb..494cd5579c5 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py @@ -11,7 +11,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) from ttnn import ShardTensor2dMesh, ConcatMesh2dToTensor @@ -172,8 +171,8 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( sub_device_stall_group = [worker_sub_device_id] 
mesh_device.set_sub_device_stall_group(sub_device_stall_group) # create global semaphore handles - from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) - to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) + from_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) + to_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) else: worker_sub_device_id = None ## diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py index 0623f56057d..64490f5d9c3 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async.py @@ -14,7 +14,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) @@ -168,10 +167,10 @@ def run_reduce_scatter_test( # create global semaphore handles from_remote_semaphore_handles = [ - create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) + ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) ] to_remote_semaphore_handles = [ - create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) + ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters) ] mesh_device.set_sub_device_stall_group([worker_sub_device_id]) debug = False diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py index d7ff05200d0..66d50a3a739 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_async_TG_nightly.py @@ -14,7 +14,6 @@ from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import ( create_and_load_sub_device_manager_with_fabric_interface, teardown_fabric_interface, - create_global_semaphore_with_same_address, ) diff --git a/tests/ttnn/unit_tests/test_global_semaphore.py b/tests/ttnn/unit_tests/test_global_semaphore.py index 045c90ef314..4015b5840bb 100644 --- a/tests/ttnn/unit_tests/test_global_semaphore.py +++ b/tests/ttnn/unit_tests/test_global_semaphore.py @@ -41,34 +41,3 @@ def test_global_semaphore(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) def test_global_semaphore_mesh(mesh_device, enable_async_mode): run_global_semaphore(mesh_device) - - -def run_global_semaphore_same_address(mesh_device): - tensix_cores0 = ttnn.CoreRangeSet( - { - ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(3, 3), - ), - } - ) - global_sem0 = ttnn.create_global_semaphore(mesh_device.get_devices()[0], tensix_cores0, 0) - global_sem1 = ttnn.create_global_semaphore(mesh_device.get_devices()[1], tensix_cores0, 0) - global_sem2 = ttnn.create_global_semaphore(mesh_device.get_devices()[0], tensix_cores0, 0) - - global_sem3 = ttnn.create_global_semaphore_with_same_address( - mesh_device, tensix_cores0, 0, attempts=10, search_max=False - ) - addrs0 = ttnn.get_global_semaphore_address(global_sem0) - addrs1 = ttnn.get_global_semaphore_address(global_sem1) - addrs2 = 
ttnn.get_global_semaphore_address(global_sem2)
-    addrs3 = ttnn.get_global_semaphore_address(global_sem3)
-    logger.debug(f"addrs0: {addrs0}, addrs1: {addrs1}, addrs2: {addrs2}, addrs3: {addrs3}")
-    assert len(set(addrs3)) == 1
-
-
-@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
-def test_global_semaphore_mesh_same_address(mesh_device, enable_async_mode):
-    if len(mesh_device.get_devices()) < 4:
-        pytest.skip("requires at least 4 devices to run")
-    run_global_semaphore_same_address(mesh_device)
diff --git a/ttnn/cpp/pybind11/global_semaphore.cpp b/ttnn/cpp/pybind11/global_semaphore.cpp
index bdc7a2d977b..a99ddafd120 100644
--- a/ttnn/cpp/pybind11/global_semaphore.cpp
+++ b/ttnn/cpp/pybind11/global_semaphore.cpp
@@ -13,11 +13,10 @@ namespace ttnn::global_semaphore {
 
 void py_module_types(py::module& module) {
     py::class_<GlobalSemaphore, std::shared_ptr<GlobalSemaphore>>(module, "global_sempahore");
-    py::class_<MultiDeviceGlobalSemaphore>(module, "multi_device_global_semaphore");
 }
 
 void py_module(py::module& module) {
-    // Single Device APIs
+    // Single Device Creation API
     module.def(
         "create_global_semaphore",
         py::overload_cast<IDevice*, const CoreRangeSet&, uint32_t, BufferType>(
@@ -36,31 +35,7 @@ void py_module(py::module& module) {
             buffer_type (BufferType): The type of buffer to use for the global semaphore.
         )doc");
 
-    module.def(
-        "get_global_semaphore_address",
-        py::overload_cast<const GlobalSemaphore&>(&get_global_semaphore_address),
-        py::arg("global_semaphore"),
-        R"doc(
-            Get the address of the global semaphore.
-
-            Args:
-                global_semaphore (GlobalSemaphore): The global semaphore object.
-            )doc");
-
-    module.def(
-        "reset_global_semaphore_value",
-        py::overload_cast<const GlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
-        py::arg("global_semaphore"),
-        py::arg("reset_value"),
-        R"doc(
-            Reset the value of the global semaphore.
-
-            Args:
-                global_semaphore (GlobalSemaphore): The global semaphore object.
-                reset_value (int): The value to reset the global semaphore to.
-            )doc");
-
-    // Multi Device APIs
+    // MeshDevice Creation API
     module.def(
         "create_global_semaphore",
         py::overload_cast<MeshDevice*, const CoreRangeSet&, uint32_t, BufferType>(
@@ -79,32 +54,9 @@ void py_module(py::module& module) {
             buffer_type (BufferType): The type of buffer to use for the global semaphore.
         )doc");
 
-    module.def(
-        "create_global_semaphore_with_same_address",
-        &ttnn::global_semaphore::create_global_semaphore_with_same_address,
-        py::arg("mesh_device"),
-        py::arg("cores"),
-        py::arg("initial_value"),
-        py::arg("buffer_type") = tt::tt_metal::BufferType::L1,
-        py::arg("attempts") = 1000,
-        py::arg("search_max") = false,
-        R"doc(
-            Create a GlobalSemaphore Object on multiple devices with the same address by iteratively creating global semaphores until alignment is found.
-            Fails if the address is not the same on all devices after the specified number of attempts.
-            Note: Temperary API until mesh allocator is implemented.
-
-            Args:
-                mesh_device (MeshDevice): The mesh device on which to create the global semaphore.
-                cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization.
-                initial_value (int): The initial value of the global semaphore.
-                buffer_type (BufferType): The type of buffer to use for the global semaphore.
-                attempts (int): The number of attempts to create the global semaphore with the same address.
-                search_max (bool): Whether to search for the maximum address. (default: False, which searches for the minimum address)
-            )doc");
-
     module.def(
         "get_global_semaphore_address",
-        py::overload_cast<const MultiDeviceGlobalSemaphore&>(&get_global_semaphore_address),
+        py::overload_cast<const GlobalSemaphore&>(&get_global_semaphore_address),
         py::arg("global_semaphore"),
         R"doc(
             Get the address of the global semaphore.
@@ -115,7 +67,7 @@ void py_module(py::module& module) {
 
     module.def(
         "reset_global_semaphore_value",
-        py::overload_cast<const MultiDeviceGlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
+        py::overload_cast<const GlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
         py::arg("global_semaphore"),
         py::arg("reset_value"),
         R"doc(
diff --git a/ttnn/cpp/ttnn/global_semaphore.cpp b/ttnn/cpp/ttnn/global_semaphore.cpp
index 4b462b60646..0edb3a0d6b4 100644
--- a/ttnn/cpp/ttnn/global_semaphore.cpp
+++ b/ttnn/cpp/ttnn/global_semaphore.cpp
@@ -10,18 +10,16 @@
 
 namespace ttnn::global_semaphore {
 
-MultiDeviceGlobalSemaphore::MultiDeviceGlobalSemaphore(MeshDevice* mesh_device) {
-    TT_ASSERT(
-        mesh_device != nullptr,
-        "Must provide a valid mesh_device when initializing a global semaphore on multiple devices.");
-    this->global_semaphores.reserve(mesh_device->num_devices());
-}
-
 GlobalSemaphore create_global_semaphore(
     IDevice* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) {
     return CreateGlobalSemaphore(device, cores, initial_value, buffer_type);
 }
 
+GlobalSemaphore create_global_semaphore(
+    MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) {
+    return CreateGlobalSemaphore(mesh_device, cores, initial_value, buffer_type);
+}
+
 tt::tt_metal::DeviceAddr get_global_semaphore_address(const GlobalSemaphore& global_semaphore) {
     return global_semaphore.address();
 }
@@ -30,111 +28,4 @@ void reset_global_semaphore_value(const GlobalSemaphore& global_semaphore, uint3
     global_semaphore.reset_semaphore_value(reset_value);
 }
 
-MultiDeviceGlobalSemaphore create_global_semaphore(
-    MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type) {
-    MultiDeviceGlobalSemaphore multi_device_global_semaphore(mesh_device);
-    auto& global_semaphores = multi_device_global_semaphore.global_semaphores;
-    const auto& devices = mesh_device->get_devices();
-    for (uint32_t i = 0; i < devices.size(); ++i) {
-        auto* device = devices[i];
-        global_semaphores.push_back(create_global_semaphore(device, cores, initial_value, buffer_type));
-    }
-    return multi_device_global_semaphore;
-}
-MultiDeviceGlobalSemaphore create_global_semaphore_with_same_address(
-    MeshDevice* mesh_device,
-    const CoreRangeSet& cores,
-    uint32_t initial_value,
-    BufferType buffer_type,
-    uint32_t attempts,
-    bool search_max) {
-    MultiDeviceGlobalSemaphore multi_device_global_semaphore(mesh_device);
-    const auto& devices = mesh_device->get_devices();
-    for (uint32_t i = 0; i < devices.size(); ++i) {
-        auto* device = devices[i];
-        multi_device_global_semaphore.global_semaphores.push_back(
-            create_global_semaphore(device, cores, initial_value, buffer_type));
-    }
-
-    auto global_semaphores = multi_device_global_semaphore.global_semaphores;
-    auto first_addr = get_global_semaphore_address(global_semaphores.front());
-    bool all_same = std::all_of(global_semaphores.begin(), global_semaphores.end(), [first_addr](const auto& sem) {
-        return get_global_semaphore_address(sem) == first_addr;
-    });
-
-    if (!all_same) {
-        tt::log_debug("chkpt 1, attempts: {}", attempts);
-        tt::tt_metal::DeviceAddr target_addr = get_global_semaphore_address(global_semaphores.front());
-        for (auto i = 1; i < global_semaphores.size(); i++) {
-            tt::log_debug(
-                "chkpt 1.1, i: {}, global_semaphores[i]->address(): {}",
-                i,
-                get_global_semaphore_address(global_semaphores[i]));
-            if (search_max) {
-                if (get_global_semaphore_address(global_semaphores[i]) > target_addr) {
-                    target_addr = get_global_semaphore_address(global_semaphores[i]);
-                }
-            }
else { - if (get_global_semaphore_address(global_semaphores[i]) < target_addr) { - target_addr = get_global_semaphore_address(global_semaphores[i]); - } - } - }; - tt::log_debug("chkpt 2, target_addr: {}", target_addr); - for (auto i = 0; i < global_semaphores.size(); i++) { - auto* device = devices[i]; - tt::log_debug("pushed, i: {}", i); - device->push_work([i, - device, - attempts, - target_addr, - &cores, - initial_value, - buffer_type, - global_semaphore = &multi_device_global_semaphore.global_semaphores[i]] { - size_t attempt = 0; - std::vector garbage; - tt::log_debug("global_semaphore->address(): {}", get_global_semaphore_address(*global_semaphore)); - while (get_global_semaphore_address(*global_semaphore) != target_addr) { - auto sem = create_global_semaphore(device, cores, initial_value, buffer_type); - - if (i == 0) { - tt::log_debug("chkpt 3, sem->address(): {}", get_global_semaphore_address(sem)); - } - - if (get_global_semaphore_address(sem) == target_addr) { - *global_semaphore = std::move(sem); - } else { - garbage.push_back(std::move(sem)); - attempt++; - } - - if (attempt > attempts) { - TT_THROW("Failed to create global semaphores with the same address"); - } - } - }); - } - for (auto device : devices) { - device->synchronize(); - } - } - - return multi_device_global_semaphore; -} -std::vector get_global_semaphore_address(const MultiDeviceGlobalSemaphore& global_semaphore) { - std::vector addresses(global_semaphore.global_semaphores.size()); - const auto& global_semaphores = global_semaphore.global_semaphores; - for (uint32_t i = 0; i < global_semaphores.size(); ++i) { - addresses[i] = get_global_semaphore_address(global_semaphores[i]); - } - return addresses; -} - -void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore, uint32_t reset_value) { - for (const auto& global_semaphore : global_semaphore.global_semaphores) { - reset_global_semaphore_value(global_semaphore, reset_value); - } -} - } // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/ttnn/global_semaphore.hpp b/ttnn/cpp/ttnn/global_semaphore.hpp index 87fa50c5529..3d02bba01ae 100644 --- a/ttnn/cpp/ttnn/global_semaphore.hpp +++ b/ttnn/cpp/ttnn/global_semaphore.hpp @@ -9,37 +9,19 @@ namespace ttnn::global_semaphore { -struct MultiDeviceGlobalSemaphore { - MultiDeviceGlobalSemaphore(MeshDevice* mesh_device); - std::vector global_semaphores; - - static constexpr auto attribute_names = std::forward_as_tuple("global_semaphores"); - const auto attribute_values() const { return std::forward_as_tuple(this->global_semaphores); } -}; - -// Single Device APIs +// Single Device Creation API GlobalSemaphore create_global_semaphore( IDevice* device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); -tt::tt_metal::DeviceAddr get_global_semaphore_address(const GlobalSemaphore& global_semaphore); - -void reset_global_semaphore_value(const GlobalSemaphore& global_semaphore, uint32_t reset_value); - -// Multi Device APIs -MultiDeviceGlobalSemaphore create_global_semaphore( +// MeshDevice Creation API +GlobalSemaphore create_global_semaphore( MeshDevice* mesh_device, const CoreRangeSet& cores, uint32_t initial_value, BufferType buffer_type = BufferType::L1); -MultiDeviceGlobalSemaphore create_global_semaphore_with_same_address( - MeshDevice* mesh_device, - const CoreRangeSet& cores, - uint32_t initial_value, - BufferType buffer_type, - uint32_t attempts, - bool search_max = false); -std::vector get_global_semaphore_address(const 
MultiDeviceGlobalSemaphore& global_semaphore); -void reset_global_semaphore_value(const MultiDeviceGlobalSemaphore& global_semaphore, uint32_t reset_value); +tt::tt_metal::DeviceAddr get_global_semaphore_address(const GlobalSemaphore& global_semaphore); + +void reset_global_semaphore_value(const GlobalSemaphore& global_semaphore, uint32_t reset_value); } // namespace ttnn::global_semaphore diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp index 67dbaa9b9f5..8f9d6314dae 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.cpp @@ -13,7 +13,7 @@ namespace ttnn::operations::experimental::ccl { ttnn::Tensor ExecuteAllGatherAsync::invoke( const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const uint32_t num_links, const std::optional& memory_config, const ttnn::ccl::Topology topology, @@ -36,7 +36,7 @@ ttnn::Tensor ExecuteAllGatherAsync::invoke( const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config, const std::optional num_preferred_links, std::optional subdevice_id, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp index b6da224de97..2adc5d0b6f6 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async.hpp @@ -15,7 +15,7 @@ struct ExecuteAllGatherAsync { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring, @@ -28,7 +28,7 @@ struct ExecuteAllGatherAsync { const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config = std::nullopt, const std::optional num_preferred_links = std::nullopt, std::optional subdevice_id = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp index 8e1ab8c48f7..15482f031e1 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/all_gather_async_pybind.cpp @@ -29,7 +29,7 @@ void bind_all_gather_async(pybind11::module& module, const ccl_operation_t& oper [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& 
multi_device_global_semaphore, const uint32_t num_links, const std::optional& memory_config, const ttnn::ccl::Topology topology, @@ -62,7 +62,7 @@ void bind_all_gather_async(pybind11::module& module, const ccl_operation_t& oper const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional num_preferred_links, const std::optional& memory_config, std::optional subdevice_id, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp index 1dc8c1cae46..30e4b7a17c9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.cpp @@ -20,19 +20,17 @@ AllGatherAsync create_all_gather_async_struct( const std::optional& memory_config, const std::vector& devices, const ttnn::ccl::Topology topology, - const std::vector& semaphores, + const GlobalSemaphore& semaphore, std::optional sub_device_id, bool enable_persistent_fabric_mode) { uint32_t num_devices = devices.size(); std::optional forward_device = std::nullopt; std::optional backward_device = std::nullopt; - std::optional semaphore = std::nullopt; uint32_t device_index = 0; // Initialize device index for (uint32_t i = 0; i < num_devices; ++i) { if (devices.at(i) == input_tensor.device()) { device_index = i; - semaphore = semaphores.at(i); // Get raw pointer if (i != 0) { backward_device = devices.at(i - 1); } @@ -51,7 +49,7 @@ AllGatherAsync create_all_gather_async_struct( device_index, memory_config.value_or(input_tensor.memory_config()), topology, - semaphore.value(), + semaphore, sub_device_id, enable_persistent_fabric_mode}; } @@ -305,7 +303,7 @@ namespace ccl { Tensor all_gather_async( const Tensor& input_tensor, const uint32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const uint32_t num_links, const std::optional& memory_config, const ttnn::ccl::Topology topology, @@ -332,8 +330,6 @@ Tensor all_gather_async( CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); - std::vector semaphores = multi_device_global_semaphore.global_semaphores; - tt::tt_metal::operation::launch_op( [dim, num_links, @@ -341,7 +337,7 @@ Tensor all_gather_async( memory_config, devices, ccl_topology, - semaphores, + multi_device_global_semaphore, sub_device_id, enable_persistent_fabric_mode]( const std::vector& input_tensors, @@ -357,7 +353,7 @@ Tensor all_gather_async( memory_config, devices, ccl_topology, - semaphores, + multi_device_global_semaphore, sub_device_id, enable_persistent_fabric_mode), {input_tensor}); @@ -373,7 +369,7 @@ Tensor all_gather_async( const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config, const std::optional num_preferred_links, std::optional sub_device_id, @@ -399,7 +395,6 @@ Tensor all_gather_async( std::vector output_tensors = 
{Tensor(tt::tt_metal::operation::get_workers_for_op_output({input_tensor}))}; CoreCoord grid_size = devices[0]->compute_with_storage_grid_size(); auto core_grid = CoreRange({0, 0}, {grid_size.x - 1, grid_size.y - 1}); - std::vector semaphores = multi_device_global_semaphore.global_semaphores; tt::tt_metal::operation::launch_op( [gather_dim, @@ -409,7 +404,7 @@ Tensor all_gather_async( cluster_axis, num_devices, topology, - semaphores, + multi_device_global_semaphore, sub_device_id, enable_persistent_fabric_mode]( const std::vector& input_tensors, @@ -434,7 +429,7 @@ Tensor all_gather_async( memory_config, devices, topology, - semaphores, + multi_device_global_semaphore, sub_device_id, enable_persistent_fabric_mode), {input_tensor}); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp index bd66bd923aa..3a519d683c3 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_async/device/all_gather_async_op.hpp @@ -162,7 +162,7 @@ namespace ccl { Tensor all_gather_async( const Tensor& input_tensor, const uint32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const ttnn::ccl::Topology topology = ttnn::ccl::Topology::Ring, @@ -175,7 +175,7 @@ Tensor all_gather_async( const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config = std::nullopt, const std::optional num_preferred_links = std::nullopt, std::optional sub_device_id = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp index b7c33163ab0..b476469ff42 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.cpp @@ -34,9 +34,9 @@ uint32_t find_scatter_dim(const ttnn::Shape& input_tensor_padded_shape, size_t n ttnn::Tensor ExecuteAllReduceAsync::invoke( const ttnn::Tensor& input_tensor, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, @@ -69,9 +69,9 @@ ttnn::Tensor ExecuteAllReduceAsync::invoke( const ttnn::Tensor& input_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const 
global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, @@ -112,7 +112,7 @@ ttnn::Tensor ExecuteAllReduceAsync::invoke( ttnn::Tensor& buffer_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config, ttnn::ccl::Topology topology, const std::optional num_preferred_links, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp index a49f2afcda2..8553afc5b6a 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async.hpp @@ -20,9 +20,9 @@ namespace ccl { struct ExecuteAllReduceAsync { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config = std::nullopt, ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, @@ -33,9 +33,9 @@ struct ExecuteAllReduceAsync { const ttnn::Tensor& input_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, @@ -47,7 +47,7 @@ struct ExecuteAllReduceAsync { ttnn::Tensor& buffer_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config, ttnn::ccl::Topology topology, const std::optional num_preferred_links, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp index 886b1a3e8a6..8741bd288e9 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/all_reduce_async_pybind.cpp @@ -27,9 +27,9 @@ void bind_all_reduce_async(pybind11::module& 
module, const ccl_operation_t& oper ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, @@ -62,9 +62,9 @@ void bind_all_reduce_async(pybind11::module& module, const ccl_operation_t& oper const ttnn::Tensor& input_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& gather_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& gather_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, @@ -102,7 +102,7 @@ void bind_all_reduce_async(pybind11::module& module, const ccl_operation_t& oper ttnn::Tensor& buffer_tensor, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, const std::optional num_links, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp index d6d9d5b17a3..13bbe7943e4 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.cpp @@ -20,19 +20,17 @@ AllReduceAsync create_all_reduce_async_struct( const std::optional& memory_config, const std::vector& devices, const ttnn::ccl::Topology topology, - const std::vector& semaphores, + const GlobalSemaphore& semaphore, std::optional& sub_device_id, bool enable_persistent_fabric_mode) { uint32_t num_devices = devices.size(); std::optional forward_device = std::nullopt; std::optional backward_device = std::nullopt; - std::optional semaphore = std::nullopt; uint32_t device_index = 0; // Initialize device index for (uint32_t i = 0; i < num_devices; ++i) { if (devices.at(i) == input_tensor.device()) { device_index = i; - semaphore = semaphores.at(i); // Get raw pointer if (i != 0) { backward_device = devices.at(i - 1); } @@ -50,7 +48,7 @@ AllReduceAsync create_all_reduce_async_struct( device_index, memory_config.value_or(input_tensor.memory_config()), topology, - semaphore.value(), + semaphore, sub_device_id, enable_persistent_fabric_mode}; } @@ -187,7 +185,7 @@ Tensor all_reduce_async( const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& 
multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config, const std::optional num_preferred_links, std::optional subdevice_id, @@ -200,7 +198,6 @@ Tensor all_reduce_async( std::size_t num_devices = (cluster_axis == 0) ? mesh_view.num_rows() : mesh_view.num_cols(); std::vector output_tensors = {Tensor(tt::tt_metal::operation::get_workers_for_op_output({input_tensor}))}; - std::vector semaphores = multi_device_global_semaphore.global_semaphores; tt::tt_metal::operation::launch_op( [num_preferred_links, @@ -209,7 +206,7 @@ Tensor all_reduce_async( cluster_axis, num_devices, topology, - semaphores, + multi_device_global_semaphore, subdevice_id, enable_persistent_fabric_mode]( const std::vector& input_tensors, @@ -234,7 +231,7 @@ Tensor all_reduce_async( memory_config, devices, topology, - semaphores, + multi_device_global_semaphore, subdevice_id, enable_persistent_fabric_mode), {input_tensor, buffer_tensor}); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp index dae83a6bbc4..8297d541a95 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce_async/device/all_reduce_async_op.hpp @@ -128,7 +128,7 @@ Tensor all_reduce_async( const uint32_t cluster_axis, const MeshDevice& mesh_device, const ttnn::ccl::Topology topology, - const global_semaphore::MultiDeviceGlobalSemaphore& multi_device_global_semaphore, + const GlobalSemaphore& multi_device_global_semaphore, const std::optional& memory_config = std::nullopt, const std::optional num_preferred_links = std::nullopt, std::optional sub_device_id = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp index b655c5e8504..f31c2e8fead 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.cpp @@ -27,8 +27,8 @@ ReduceScatterAsync create_reduce_scatter_struct( std::optional> forward_output_tensors, std::optional> backward_output_tensors, std::optional num_links_preferred, - const std::vector& from_remote_sems, - const std::vector& to_remote_sems, + const GlobalSemaphore& from_remote_sem, + const GlobalSemaphore& to_remote_sem, std::optional sub_device_id, std::optional& fabric_handle) { uint32_t num_devices = devices.size(); @@ -56,9 +56,6 @@ ReduceScatterAsync create_reduce_scatter_struct( return *device; }; - GlobalSemaphore from_remote_sem = from_remote_sems.at(device_index); - GlobalSemaphore to_remote_sem = to_remote_sems.at(device_index); - return ttnn::ReduceScatterAsync{ binary_op_type, scatter_dim, @@ -198,8 +195,8 @@ namespace ccl { Tensor reduce_scatter( const Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const MemoryConfig& output_mem_config, 
ttnn::ccl::Topology topology,
@@ -228,12 +225,6 @@ Tensor reduce_scatter(
         rank - 1,
         dim);
 
-    std::vector<GlobalSemaphore> from_remote_inputs_semaphores =
-        from_remote_multi_device_global_semaphore.global_semaphores;
-
-    std::vector<GlobalSemaphore> to_remote_inputs_semaphores =
-        to_remote_multi_device_global_semaphore.global_semaphores;
-
     std::vector<Tensor> output_tensors = {
         Tensor(operation::get_workers_for_op_output({input_tensor})),
         Tensor(operation::get_workers_for_op_output({input_tensor})),
@@ -245,8 +236,8 @@ Tensor reduce_scatter(
         "Reduce scatter requires 5 output tensors. 1 is real and the others are temporaries");
     operation::launch_op(
         [binary_op_type,
-         from_remote_inputs_semaphores,
-         to_remote_inputs_semaphores,
+         from_remote_multi_device_global_semaphore,
+         to_remote_multi_device_global_semaphore,
          scatter_dim,
          output_mem_config,
          ccl_topology,
@@ -271,8 +262,8 @@ Tensor reduce_scatter(
                     std::nullopt,
                     std::nullopt,
                     num_links_preferred,
-                    from_remote_inputs_semaphores,
-                    to_remote_inputs_semaphores,
+                    from_remote_multi_device_global_semaphore,
+                    to_remote_multi_device_global_semaphore,
                     worker_subdevice_id_opt,
                     fabric_handle),
                 {input_tensor});
@@ -287,8 +278,8 @@ Tensor reduce_scatter(
     const int32_t dim,
     const uint32_t cluster_axis,
     const MeshDevice& mesh_device,
-    const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore,
-    const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore,
+    const GlobalSemaphore& from_remote_multi_device_global_semaphore,
+    const GlobalSemaphore& to_remote_multi_device_global_semaphore,
     ttnn::operations::reduction::ReduceType reduce_op,
     const MemoryConfig& output_mem_config,
     ttnn::ccl::Topology topology,
@@ -303,12 +294,6 @@ Tensor reduce_scatter(
     const auto mesh_view = mesh_device.get_view();
     auto devices = input_tensor.get_workers();
 
-    std::vector<GlobalSemaphore> from_remote_inputs_semaphores =
-        from_remote_multi_device_global_semaphore.global_semaphores;
-
-    std::vector<GlobalSemaphore> to_remote_inputs_semaphores =
-        to_remote_multi_device_global_semaphore.global_semaphores;
-
     std::vector<Tensor> output_tensors = {
         Tensor(operation::get_workers_for_op_output({input_tensor})),
         Tensor(operation::get_workers_for_op_output({input_tensor})),
@@ -320,8 +305,8 @@ Tensor reduce_scatter(
         "Reduce scatter requires 5 output tensors.
1 is real and the others are temporaries"); operation::launch_op( [binary_op_type, - from_remote_inputs_semaphores, - to_remote_inputs_semaphores, + from_remote_multi_device_global_semaphore, + to_remote_multi_device_global_semaphore, scatter_dim, output_mem_config, mesh_view, @@ -357,8 +342,8 @@ Tensor reduce_scatter( std::nullopt, std::nullopt, num_links_preferred, - from_remote_inputs_semaphores, - to_remote_inputs_semaphores, + from_remote_multi_device_global_semaphore, + to_remote_multi_device_global_semaphore, worker_subdevice_id_opt, fabric_handle), {input_tensor}); diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp index 1911a6a1160..a4257b94494 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/device/reduce_scatter_async_op.hpp @@ -136,8 +136,8 @@ namespace ccl { Tensor reduce_scatter( const Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, const MemoryConfig& output_mem_config = tt::tt_metal::operation::DEFAULT_OUTPUT_MEMORY_CONFIG, ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, @@ -149,8 +149,8 @@ Tensor reduce_scatter( const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, const MemoryConfig& output_mem_config = tt::tt_metal::operation::DEFAULT_OUTPUT_MEMORY_CONFIG, ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp index e9f66e1645b..da78145726c 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.cpp @@ -12,8 +12,8 @@ namespace ttnn::operations::experimental::ccl { ttnn::Tensor ExecuteReduceScatter::invoke( const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, @@ -37,8 +37,8 @@ ttnn::Tensor ExecuteReduceScatter::invoke( const int32_t dim, const uint32_t cluster_axis, const MeshDevice& 
mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp index 42de61e83cc..2cf0b9b21e7 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter.hpp @@ -21,8 +21,8 @@ struct ExecuteReduceScatter { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config = std::nullopt, ttnn::ccl::Topology topology = ttnn::ccl::Topology::Linear, @@ -34,8 +34,8 @@ struct ExecuteReduceScatter { const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const std::optional& memory_config, ttnn::ccl::Topology topology, diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp index 8ea38fab03b..054534e8c2b 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/reduce_scatter_async/reduce_scatter_pybind.cpp @@ -28,8 +28,8 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, const int32_t dim, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, @@ -63,8 +63,8 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, - const global_semaphore::MultiDeviceGlobalSemaphore& from_remote_multi_device_global_semaphore, - const global_semaphore::MultiDeviceGlobalSemaphore& to_remote_multi_device_global_semaphore, + const GlobalSemaphore& from_remote_multi_device_global_semaphore, + const GlobalSemaphore& 
to_remote_multi_device_global_semaphore, ttnn::operations::reduction::ReduceType math_op, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 5696e794b1b..ea793bbfc55 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -124,7 +124,6 @@ def manage_config(name, value): create_global_semaphore, get_global_semaphore_address, reset_global_semaphore_value, - create_global_semaphore_with_same_address, ) from ttnn.types import (
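
Migration note: after this patch, mesh-wide semaphores are created with the same ttnn.create_global_semaphore call that single-device code uses, and ttnn.get_global_semaphore_address returns one address for the whole mesh instead of a per-device list. The sketch below is illustrative only and not part of the patch: it assumes an already-opened mesh_device and a hypothetical 4x4 worker core grid, and uses only the APIs exercised in the diffs above.

    import ttnn

    # Cores that will use the semaphore for synchronization (hypothetical grid).
    ccl_cores = ttnn.CoreRangeSet(
        {
            ttnn.CoreRange(
                ttnn.CoreCoord(0, 0),
                ttnn.CoreCoord(3, 3),
            ),
        }
    )

    # One GlobalSemaphore now covers the whole mesh; its backing buffer is
    # allocated through the mesh allocator, so no same-address retry loop is needed.
    global_sem = ttnn.create_global_semaphore(mesh_device, ccl_cores, 0)

    # A single device address, identical across the mesh by construction.
    addr = ttnn.get_global_semaphore_address(global_sem)

    # Reset the semaphore value on every device in the mesh.
    ttnn.reset_global_semaphore_value(global_sem, 0)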