diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index f5502231a..3482fe9a8 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -1,12 +1,39 @@ // Copyright © 2024 Apple Inc. -#include "mlx/fence.h" +#include + #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/metal_impl.h" +#include "mlx/fence.h" #include "mlx/scheduler.h" #include "mlx/utils.h" namespace mlx::core { +void signal_handler(int signum); + +MTL::Buffer* signal_buffer() { + auto init = []() { + signal(SIGTERM, signal_handler); + auto dtor = [](void* buf) { + allocator::free(static_cast(buf)); + }; + auto buf = std::shared_ptr( + allocator::malloc_or_wait(sizeof(uint32_t)).ptr(), dtor); + static_cast( + static_cast(buf.get())->contents())[0] = 0; + return buf; + }; + static std::shared_ptr buf = init(); + return static_cast(buf.get()); +} + +void signal_handler(int signum) { + auto buf = signal_buffer(); + static_cast(buf->contents())[0] = 1; + signal(signum, SIG_DFL); + raise(signum); +} + struct FenceImpl { FenceImpl() { auto d = metal::device(Device::gpu).mtl_device(); @@ -94,6 +121,7 @@ void Fence::wait(Stream stream, const array& x) { auto buf = static_cast(f.fence); compute_encoder.set_buffer(buf, 0); compute_encoder.set_bytes(f.count, 1); + compute_encoder.set_buffer(signal_buffer(), 2); compute_encoder.dispatch_threads(kernel_dims, kernel_dims); d.get_command_buffer(idx)->addCompletedHandler( diff --git a/mlx/backend/metal/kernels/fence.metal b/mlx/backend/metal/kernels/fence.metal index f736b33b0..0cf73686f 100644 --- a/mlx/backend/metal/kernels/fence.metal +++ b/mlx/backend/metal/kernels/fence.metal @@ -39,13 +39,14 @@ constexpr constant metal::thread_scope thread_scope_system = // single thread kernel to spin wait for timestamp value [[kernel]] void fence_wait( volatile coherent(system) device uint* timestamp [[buffer(0)]], - constant uint& value [[buffer(1)]]) { + constant uint& value [[buffer(1)]], + volatile coherent(system) device uint* sig_handler [[buffer(2)]]) { while (1) { metal::atomic_thread_fence( metal::mem_flags::mem_device, metal::memory_order_seq_cst, metal::thread_scope_system); - if (timestamp[0] >= value) { + if (timestamp[0] >= value || sig_handler[0] > 0) { break; } }