Skip to content

Commit

Permalink
Remove MultiDeviceGlobalSemaphore
Browse files Browse the repository at this point in the history
  - TT-Metal already supports Global Semaphores backed by MeshBuffer
  - Expose this concept to TTNN, and remove the concept of the
    MultiDeviceGlobalSemaphore
  - create_global_semaphore(mesh_device, ...) is now hooked to
    GlobalSemaphore
  - Remove create_global_semaphore_with_same_address, since global
    sem buffer allocation is now through the MeshAllocator
  • Loading branch information
tt-asaigal committed Mar 9, 2025
1 parent 2833697 commit 8f65743
Show file tree
Hide file tree
Showing 31 changed files with 127 additions and 383 deletions.
30 changes: 13 additions & 17 deletions tests/ttnn/unit_tests/gtests/ccl/test_fabric_edm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2029,14 +2029,12 @@ void run_all_gather_with_persistent_fabric(const size_t dim, const size_t num_li
num_links);
log_info(tt::LogTest, "Lauching op");

ttnn::global_semaphore::MultiDeviceGlobalSemaphore multi_device_global_semaphore =
ttnn::global_semaphore::create_global_semaphore_with_same_address(
test_fixture.mesh_device_.get(),
devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1, // buffer type
10 // attempts
);
GlobalSemaphore multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore(
test_fixture.mesh_device_.get(),
test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1 // buffer type
);

auto output_tensor = ttnn::operations::experimental::ccl::all_gather_async(
input_mesh_tensor,
Expand Down Expand Up @@ -2161,18 +2159,16 @@ void RunWriteThroughputStabilityTestWithPersistentFabric(

std::vector<tt::tt_metal::DeviceAddr> global_semaphore_addrs;
global_semaphore_addrs.reserve(line_size + 1);
std::vector<ttnn::global_semaphore::MultiDeviceGlobalSemaphore> global_semaphore_handles;
std::vector<GlobalSemaphore> global_semaphore_handles;
for (size_t i = 0; i < line_size * 4; i++) {
auto global_semaphores = ttnn::global_semaphore::create_global_semaphore_with_same_address(
auto global_semaphore = ttnn::global_semaphore::create_global_semaphore(
test_fixture.mesh_device_.get(),
devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1, // buffer type
1000 // attempts
test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1 // buffer type
);
global_semaphore_handles.push_back(global_semaphores);
auto global_semaphore_addr =
ttnn::global_semaphore::get_global_semaphore_address(global_semaphores.global_semaphores.at(0));
global_semaphore_handles.push_back(global_semaphore);
auto global_semaphore_addr = ttnn::global_semaphore::get_global_semaphore_address(global_semaphore);
global_semaphore_addrs.push_back(global_semaphore_addr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -857,23 +857,19 @@ TEST(CclAsyncOp, ReduceScatterSmall_PersistentFabric) {
enable_persistent_fabric,
num_links);

ttnn::global_semaphore::MultiDeviceGlobalSemaphore from_remote_multi_device_global_semaphore =
ttnn::global_semaphore::create_global_semaphore_with_same_address(
test_fixture.mesh_device_.get(),
devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1, // buffer type
10 // attempts
);

ttnn::global_semaphore::MultiDeviceGlobalSemaphore to_remote_multi_device_global_semaphore =
ttnn::global_semaphore::create_global_semaphore_with_same_address(
test_fixture.mesh_device_.get(),
devices[0]->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1, // buffer type
10 // attempts
);
GlobalSemaphore from_remote_multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore(
test_fixture.mesh_device_.get(),
test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1 // buffer type
);

GlobalSemaphore to_remote_multi_device_global_semaphore = ttnn::global_semaphore::create_global_semaphore(
test_fixture.mesh_device_.get(),
test_fixture.mesh_device_.get()->worker_cores(HalProgrammableCoreType::TENSIX, SubDeviceId{0}),
0, // initial value
tt::tt_metal::BufferType::L1 // buffer type
);

auto output_tensor = ttnn::operations::experimental::ccl::reduce_scatter(
input_mesh_tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)
from models.perf.benchmarking_utils import BenchmarkProfiler
from tracy import signpost
Expand Down Expand Up @@ -281,7 +280,7 @@ def run_line_all_gather_on_TG_with_mesh_tensor_along_rows(

# create global semaphore handles
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(NUM_BUFFERS)
]
try:
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)

from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
Expand Down
13 changes: 6 additions & 7 deletions tests/ttnn/unit_tests/operations/ccl/test_all_reduce_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)


Expand Down Expand Up @@ -63,9 +62,9 @@ def run_all_reduce_test(
sub_device_stall_group = [worker_sub_device_id]
mesh_device.set_sub_device_stall_group(sub_device_stall_group)
# create global semaphore handles
from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
gather_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
from_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
gather_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)

debug = False

Expand Down Expand Up @@ -341,9 +340,9 @@ def run_all_reduce_with_mesh_tensor_along_row(
sub_device_stall_group = [worker_sub_device_id]
mesh_device.set_sub_device_stall_group(sub_device_stall_group)
# create global semaphore handles
from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
gather_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
from_remote_semaphore_handles = create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
gather_semaphore_handles = create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)

try:
debug = False
Expand Down
8 changes: 0 additions & 8 deletions tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,3 @@ def teardown_fabric_interface(mesh_device):
logger.debug(f"Tearing down fabric (this may take a while if context switch interval is large)")
ttnn.teardown_edm_fabric(mesh_device)
ttnn.synchronize_device(mesh_device)


def create_global_semaphore_with_same_address(mesh_device, cores, initial_value):
semaphore_handles = ttnn.create_global_semaphore_with_same_address(mesh_device, cores, initial_value)
addrs = ttnn.get_global_semaphore_address(semaphore_handles)
# assert all addresses are the same
assert len(set(addrs)) == 1
return semaphore_handles
5 changes: 1 addition & 4 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)

from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
Expand Down Expand Up @@ -177,9 +176,7 @@ def run_all_gather_impl(
mesh_device.set_sub_device_stall_group(sub_device_stall_group)

# create global semaphore handles
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
]
ccl_semaphore_handles = [ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)]

logger.info(f"Output shape: {output_shape}")
logger.info(f"dim: {dim}")
Expand Down
3 changes: 1 addition & 2 deletions tests/ttnn/unit_tests/operations/ccl/test_new_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)

from tests.tt_eager.python_api_testing.unit_testing.misc.test_matmul_1d_gather_in0 import (
Expand Down Expand Up @@ -88,7 +87,7 @@ def run_all_reduce_impl(
# create global semaphore handles
num_buffers = 8
ccl_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_buffers)
ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_buffers)
]

logger.info(f"Output shape: {output_shape}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)
from ttnn import ShardTensor2dMesh, ConcatMesh2dToTensor

Expand Down Expand Up @@ -172,8 +171,8 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
sub_device_stall_group = [worker_sub_device_id]
mesh_device.set_sub_device_stall_group(sub_device_stall_group)
# create global semaphore handles
from_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0)
from_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
to_remote_semaphore_handles = ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0)
else:
worker_sub_device_id = None
##
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)


Expand Down Expand Up @@ -168,10 +167,10 @@ def run_reduce_scatter_test(

# create global semaphore handles
from_remote_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
]
to_remote_semaphore_handles = [
create_global_semaphore_with_same_address(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
ttnn.create_global_semaphore(mesh_device, ccl_sub_device_crs, 0) for _ in range(num_iters)
]
mesh_device.set_sub_device_stall_group([worker_sub_device_id])
debug = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from tests.ttnn.unit_tests.operations.ccl.test_ccl_common import (
create_and_load_sub_device_manager_with_fabric_interface,
teardown_fabric_interface,
create_global_semaphore_with_same_address,
)


Expand Down
31 changes: 0 additions & 31 deletions tests/ttnn/unit_tests/test_global_semaphore.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,34 +41,3 @@ def test_global_semaphore(device, enable_async_mode):
@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_semaphore_mesh(mesh_device, enable_async_mode):
run_global_semaphore(mesh_device)


def run_global_semaphore_same_address(mesh_device):
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
),
}
)
global_sem0 = ttnn.create_global_semaphore(mesh_device.get_devices()[0], tensix_cores0, 0)
global_sem1 = ttnn.create_global_semaphore(mesh_device.get_devices()[1], tensix_cores0, 0)
global_sem2 = ttnn.create_global_semaphore(mesh_device.get_devices()[0], tensix_cores0, 0)

global_sem3 = ttnn.create_global_semaphore_with_same_address(
mesh_device, tensix_cores0, 0, attempts=10, search_max=False
)
addrs0 = ttnn.get_global_semaphore_address(global_sem0)
addrs1 = ttnn.get_global_semaphore_address(global_sem1)
addrs2 = ttnn.get_global_semaphore_address(global_sem2)
addrs3 = ttnn.get_global_semaphore_address(global_sem3)
logger.debug(f"addrs0: {addrs0}, addrs1: {addrs1}, addrs2: {addrs2}, addrs3: {addrs3}")
assert len(set(addrs3)) == 1


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_global_semaphore_mesh_same_address(mesh_device, enable_async_mode):
if len(mesh_device.get_devices()) < 4:
pytest.skip("requires at least 4 devices to run")
run_global_semaphore_same_address(mesh_device)
56 changes: 4 additions & 52 deletions ttnn/cpp/pybind11/global_semaphore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ namespace ttnn::global_semaphore {

void py_module_types(py::module& module) {
py::class_<GlobalSemaphore, std::shared_ptr<GlobalSemaphore>>(module, "global_sempahore");
py::class_<MultiDeviceGlobalSemaphore>(module, "multi_device_global_semaphore");
}

void py_module(py::module& module) {
// Single Device APIs
// Single Device Creation API
module.def(
"create_global_semaphore",
py::overload_cast<IDevice*, const CoreRangeSet&, uint32_t, BufferType>(
Expand All @@ -36,31 +35,7 @@ void py_module(py::module& module) {
buffer_type (BufferType): The type of buffer to use for the global semaphore.
)doc");

module.def(
"get_global_semaphore_address",
py::overload_cast<const GlobalSemaphore&>(&get_global_semaphore_address),
py::arg("global_semaphore"),
R"doc(
Get the address of the global semaphore.
Args:
global_semaphore (GlobalSemaphore): The global semaphore object.
)doc");

module.def(
"reset_global_semaphore_value",
py::overload_cast<const GlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
py::arg("global_semaphore"),
py::arg("reset_value"),
R"doc(
Reset the value of the global semaphore.
Args:
global_semaphore (GlobalSemaphore): The global semaphore object.
reset_value (int): The value to reset the global semaphore to.
)doc");

// Multi Device APIs
// MeshDevice Creation API
module.def(
"create_global_semaphore",
py::overload_cast<MeshDevice*, const CoreRangeSet&, uint32_t, BufferType>(
Expand All @@ -79,32 +54,9 @@ void py_module(py::module& module) {
buffer_type (BufferType): The type of buffer to use for the global semaphore.
)doc");

module.def(
"create_global_semaphore_with_same_address",
&ttnn::global_semaphore::create_global_semaphore_with_same_address,
py::arg("mesh_device"),
py::arg("cores"),
py::arg("initial_value"),
py::arg("buffer_type") = tt::tt_metal::BufferType::L1,
py::arg("attempts") = 1000,
py::arg("search_max") = false,
R"doc(
Create a GlobalSemaphore Object on multiple devices with the same address by iteratively creating global semaphore until alignment is found.
Fails if the address is not the same on all devices after the specified number of attempts.
Note: Temperary API until mesh allocator is implemented.
Args:
mesh_device (MeshDevice): The mesh device on which to create the global semaphore.
cores (CoreRangeSet): The cores on which the global semaphore will be used for synchronization.
initial_value (int): The initial value of the global semaphore.
buffer_type (BufferType): The type of buffer to use for the global semaphore.
attempts (int): The number of attempts to create the global semaphore with the same address.
search_max (bool): Whether to search for the maximum address. (default: False, which searches for the minimum address)
)doc");

module.def(
"get_global_semaphore_address",
py::overload_cast<const MultiDeviceGlobalSemaphore&>(&get_global_semaphore_address),
py::overload_cast<const GlobalSemaphore&>(&get_global_semaphore_address),
py::arg("global_semaphore"),
R"doc(
Get the address of the global semaphore.
Expand All @@ -115,7 +67,7 @@ void py_module(py::module& module) {

module.def(
"reset_global_semaphore_value",
py::overload_cast<const MultiDeviceGlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
py::overload_cast<const GlobalSemaphore&, uint32_t>(&reset_global_semaphore_value),
py::arg("global_semaphore"),
py::arg("reset_value"),
R"doc(
Expand Down
Loading

0 comments on commit 8f65743

Please sign in to comment.