Add webgpu backend
zcbenz committed Feb 2, 2025
1 parent 2d8e667 commit 52db1fa
Showing 14 changed files with 1,034 additions and 10 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
endif()
endif()

if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
message(FATAL_ERROR "Cannot build both webgpu and metal backends.")
endif()

else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(MLX_BUILD_WEBGPU)
FetchContent_Declare(
betann
GIT_REPOSITORY https://github.com/frost-beta/betann.git
GIT_TAG db8cb414c81d05cbdb6827637733f6087e4d2049
EXCLUDE_FROM_ALL)
set(BETANN_BUILD_TESTS OFF)
FetchContent_MakeAvailable(betann)
target_link_libraries(mlx PRIVATE betann)
endif()

if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
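Note: the new backend is strictly opt-in. Per the options and the FATAL_ERROR check above, a WebGPU build is configured with MLX_BUILD_WEBGPU=ON and, on macOS, MLX_BUILD_METAL=OFF, since the two GPU backends are mutually exclusive. The betann dependency is pinned to a fixed revision and linked PRIVATE, so it does not leak into mlx's public link interface.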
2 changes: 2 additions & 0 deletions mlx/CMakeLists.txt
@@ -40,6 +40,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)

if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
elseif(MLX_BUILD_WEBGPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
8 changes: 4 additions & 4 deletions mlx/allocator.cpp
@@ -41,17 +41,17 @@ size_t CommonAllocator::size(Buffer buffer) const {
return *static_cast<size_t*>(buffer.ptr());
}

Buffer malloc_or_wait(size_t size) {
auto buffer = allocator().malloc(size);
Buffer malloc_or_wait(const Device& device, size_t size) {
auto buffer = allocator().malloc(device, size);

while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) {
scheduler::wait_for_one();
buffer = allocator().malloc(size);
buffer = allocator().malloc(device, size);
}

// Try swapping if needed
if (size && !buffer.ptr()) {
buffer = allocator().malloc(size, /* allow_swap = */ true);
buffer = allocator().malloc(device, size, /* allow_swap = */ true);
}

if (size && !buffer.ptr()) {
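The wait-and-retry structure of malloc_or_wait is unchanged: try the allocation, wait for an active task and retry while the scheduler still has work in flight, then fall back to allow_swap as a last resort. The only difference is that the target Device is now threaded through every attempt, so a multi-device allocator (such as the WebGPU one added below) can pick the matching allocation path.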
11 changes: 10 additions & 1 deletion mlx/allocator.h
@@ -4,6 +4,8 @@

#include <cstdlib>

#include "mlx/device.h"

namespace mlx::core::allocator {

// Simple wrapper around buffer pointers
@@ -34,12 +36,19 @@ void free(Buffer buffer);

// Wait for running tasks to finish and free up memory
// if allocation fails
Buffer malloc_or_wait(size_t size);
Buffer malloc_or_wait(const Device& device, size_t size);
inline Buffer malloc_or_wait(size_t size) {
return malloc_or_wait(Device::cpu, size);
}

class Allocator {
/** Abstract base class for a memory allocator. */
public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual Buffer
malloc(const Device& device, size_t size, bool allow_swap = false) {
return malloc(size, allow_swap);
}
virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;

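For concreteness, here is a minimal sketch of a backend allocator written against the declarations above. ToyAllocator is hypothetical and purely illustrative; the real WebGPU implementation appears later in this commit. The key point is that the new three-argument malloc is a virtual with a default body forwarding to the legacy two-argument overload, so existing allocators compile unchanged, while a device-aware backend overrides it to dispatch on the device.

```cpp
#include <cstdlib>

#include "mlx/allocator.h"

namespace mlx::core::allocator {

// Hypothetical allocator, shown only to illustrate the device-aware API.
class ToyAllocator : public Allocator {
 public:
  Buffer malloc(const Device& device, size_t size, bool allow_swap = false)
      override {
    // A real backend would branch on device.type here (as WgpuAllocator
    // does below); this toy version always allocates host memory.
    return malloc(size, allow_swap);
  }

  // Legacy pure-virtual entry point, still required by the base class.
  Buffer malloc(size_t size, bool allow_swap = false) override {
    return Buffer(std::malloc(size));
  }

  void free(Buffer buffer) override {
    std::free(buffer.ptr());
  }

  size_t size(Buffer buffer) const override {
    return 0; // Size bookkeeping omitted from this sketch.
  }
};

} // namespace mlx::core::allocator
```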
9 changes: 9 additions & 0 deletions mlx/array.cpp
@@ -109,6 +109,13 @@ bool array::is_tracer() const {
detail::retain_graph();
}

void array::reset_data_ptr() {
void* data_ptr = buffer().raw_ptr();
auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}

void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
@@ -142,6 +149,7 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset = offset;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -161,6 +169,7 @@ void array::move_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset = offset;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
5 changes: 5 additions & 0 deletions mlx/array.h
@@ -401,6 +401,8 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;

void reset_data_ptr();

void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

void set_data(
@@ -465,6 +467,9 @@ class array {
// The size in elements of the data buffer the array accesses
size_t data_size;

// Offset from the shared data in elements
size_t offset{0};

// Contains useful meta data about the array
Flags flags;

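Recording the element offset in ArrayDesc is what makes the new reset_data_ptr() possible: copy_shared_buffer and move_shared_buffer previously baked the offset into data_ptr and discarded it, leaving no way to recompute the pointer later. Now, when a backend swaps the buffer's host allocation out from under the array (as the WebGPU DoubleBuffer does after a GPU readback), reset_data_ptr() can rebuild data_ptr from buffer().raw_ptr() plus itemsize() * offset.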
15 changes: 10 additions & 5 deletions mlx/backend/common/binary.h
@@ -6,6 +6,7 @@
#include "mlx/allocator.h"
#include "mlx/array.h"
#include "mlx/backend/common/utils.h"
#include "mlx/primitives.h"

#include "mlx/backend/common/simd/simd.h"

@@ -47,10 +48,14 @@ void set_binary_op_output_data(
bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
const Device& device = out.primitive().device();
switch (bopt) {
case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
allocator::malloc_or_wait(device, out.itemsize()),
1,
a.strides(),
a.flags());
break;
case BinaryOpType::ScalarVector:
if (b_donatable) {
@@ -61,7 +66,7 @@ }
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
allocator::malloc_or_wait(device, b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
@@ -76,7 +81,7 @@ }
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc_or_wait(device, a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -97,7 +102,7 @@ }
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
allocator::malloc_or_wait(device, a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
@@ -118,7 +123,7 @@
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
out.set_data(allocator::malloc_or_wait(device, out.nbytes()));
}
break;
}
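The functional change in these output-setup paths is small: the destination device now comes from out.primitive().device() and is forwarded to every malloc_or_wait call, so a binary op's output is allocated by the backend that will execute the primitive. The donation fast paths, which reuse a or b's buffer when donatable, are untouched, presumably because a donated input already lives on the device the op runs on.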
7 changes: 7 additions & 0 deletions mlx/backend/webgpu/CMakeLists.txt
@@ -0,0 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)
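Note that the WebGPU backend compiles a metal.cpp shim so the rest of the codebase can keep calling metal:: entry points (the memory-related ones are stubbed at the bottom of allocator.cpp below), and reuses the no_metal backend's event implementation rather than defining its own.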
119 changes: 119 additions & 0 deletions mlx/backend/webgpu/allocator.cpp
@@ -0,0 +1,119 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/webgpu/allocator.h"

namespace mlx::core {

namespace allocator {

Allocator& allocator() {
return webgpu::allocator();
}

void* Buffer::raw_ptr() {
return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
}

} // namespace allocator

namespace webgpu {

DoubleBuffer::DoubleBuffer(size_t size)
: cpu_data_(std::malloc(size + sizeof(size_t))) {
*static_cast<size_t*>(cpu_data_) = size;
}

DoubleBuffer::DoubleBuffer(betann::Device& device, size_t size)
: gpu_data_(device.CreateBuffer(
size,
betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}

DoubleBuffer::~DoubleBuffer() {
std::free(cpu_data_);
}

void DoubleBuffer::copy_to_cpu(const void* data, size_t size) {
assert(!cpu_data_);
cpu_data_ = std::malloc(size + sizeof(size_t));
*static_cast<size_t*>(cpu_data_) = size;
std::memcpy(cpu_data(), data, size);
}

size_t DoubleBuffer::size() const {
if (cpu_data_)
return *static_cast<size_t*>(cpu_data_);
if (gpu_data_)
return gpu_data_.GetSize();
return 0;
}

WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}

Buffer
WgpuAllocator::malloc(const Device& device, size_t size, bool allow_swap) {
if (device.type == Device::gpu)
return Buffer(new DoubleBuffer(webgpu::device(device), size));
else
return Buffer(new DoubleBuffer(size));
}

void WgpuAllocator::free(Buffer buffer) {
delete static_cast<DoubleBuffer*>(buffer.ptr());
}

size_t WgpuAllocator::size(Buffer buffer) const {
return static_cast<DoubleBuffer*>(buffer.ptr())->size();
}

void WgpuAllocator::ensure_gpu_data(Buffer& buffer) {
auto* dbuf = static_cast<DoubleBuffer*>(buffer.ptr());
if (dbuf->gpu_data() || dbuf->size() == 0)
return;
dbuf->set_gpu_data(device_.CreateBufferFromData(
dbuf->cpu_data(), dbuf->size(), betann::BufferUsage::Storage));
}

WgpuAllocator& allocator() {
static WgpuAllocator allocator_;
return allocator_;
}

betann::Device& device(mlx::core::Device) {
static betann::Device device;
return device;
}

} // namespace webgpu

namespace metal {

size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}

std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
throw std::runtime_error("[webgpu::device_info] Not implemented");
}

void clear_cache() {}

} // namespace metal

} // namespace mlx::core
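The DoubleBuffer above is the heart of the backend: a Buffer may own a CPU block (std::malloc'd, with its size stored in a small header, mirroring CommonAllocator), a betann GPU buffer, or eventually both. size() answers from the CPU header when present and otherwise asks the GPU buffer; ensure_gpu_data() lazily uploads host bytes before dispatch; copy_to_cpu() fills in the host side after a readback. A rough lifecycle sketch follows, using only calls defined in this file; the surrounding flow is inferred for illustration and is not code from the commit.

```cpp
#include "mlx/allocator.h"
#include "mlx/backend/webgpu/allocator.h"

using namespace mlx::core;

void double_buffer_sketch() {
  // CPU allocation: a DoubleBuffer holding only a host block, whose size
  // lives in a header just before the data (like CommonAllocator).
  allocator::Buffer host_buf = allocator::malloc_or_wait(Device::cpu, 1024);
  void* host = host_buf.raw_ptr(); // valid immediately for CPU buffers
  (void)host;

  // Before a kernel dispatch, ensure a GPU copy exists. This uploads the
  // host bytes via CreateBufferFromData; it is a no-op for buffers that
  // already have GPU data or are empty.
  webgpu::allocator().ensure_gpu_data(host_buf);

  // GPU allocation: a DoubleBuffer holding only a betann buffer. The host
  // side is created later by DoubleBuffer::copy_to_cpu() after a readback,
  // and array::reset_data_ptr() then refreshes the array's cached pointer
  // into that fresh host allocation.
  allocator::Buffer gpu_buf = allocator::malloc_or_wait(Device::gpu, 1024);

  allocator::free(host_buf);
  allocator::free(gpu_buf);
}
```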