[Experiment] WebGPU backend #1789

Open · wants to merge 1 commit into main
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -33,6 +33,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -65,6 +66,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
endif()
endif()

if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
message(FATAL_ERROR "Cannot build both webgpu and metal backends.")
endif()

else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -128,6 +133,17 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(MLX_BUILD_WEBGPU)
FetchContent_Declare(
betann
GIT_REPOSITORY https://github.com/frost-beta/betann.git
GIT_TAG 8aa2701caf63fb29bd4cd2454e656973342c1588
EXCLUDE_FROM_ALL)
set(BETANN_BUILD_TESTS OFF)
FetchContent_MakeAvailable(betann)
target_link_libraries(mlx PRIVATE betann)
endif()

if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
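With these hunks, the WebGPU backend is opt-in and mutually exclusive with Metal. On macOS, where MLX_BUILD_METAL defaults to ON, configuring a WebGPU build therefore presumably requires disabling Metal explicitly, along these lines:

cmake -B build -DMLX_BUILD_WEBGPU=ON -DMLX_BUILD_METAL=OFF

The betann dependency is fetched at a pinned commit and linked PRIVATE, so it does not leak into MLX's public link interface.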
2 changes: 2 additions & 0 deletions mlx/CMakeLists.txt
@@ -43,6 +43,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)

if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
elseif(MLX_BUILD_WEBGPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
9 changes: 9 additions & 0 deletions mlx/array.cpp
@@ -132,6 +132,13 @@ bool array::is_tracer() const {
detail::retain_graph();
}

void array::reset_data_ptr() {
void* data_ptr = buffer().raw_ptr();
auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
array_desc_->data_ptr =
static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}

void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
@@ -165,6 +172,7 @@ void array::copy_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset = offset;
auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -184,6 +192,7 @@ void array::move_shared_buffer(
array_desc_->strides = strides;
array_desc_->flags = flags;
array_desc_->data_size = data_size;
array_desc_->offset = offset;
auto char_offset = sizeof(char) * itemsize() * offset;
auto data_ptr = other.array_desc_->data_ptr;
other.array_desc_->data_ptr = nullptr;
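These array.cpp changes record a view's element offset explicitly so that reset_data_ptr() can recompute data_ptr from the current buffer base; the WebGPU backend needs this because Buffer::raw_ptr() now returns the CPU side of a DoubleBuffer, which can be (re)allocated lazily. A standalone sketch of the pointer arithmetic, in plain C++ rather than MLX types:

// Sketch only: mirrors the char_offset computation in copy_shared_buffer/
// move_shared_buffer above, outside of MLX.
#include <cstddef>
#include <cstdio>

int main() {
  float storage[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  void* base = storage;             // buffer().raw_ptr() of the parent array
  size_t offset = 3;                // element offset stored in ArrayDesc
  size_t itemsize = sizeof(float);  // itemsize() for float32
  void* data_ptr =
      static_cast<void*>(static_cast<char*>(base) + itemsize * offset);
  std::printf("%f\n", *static_cast<float*>(data_ptr));  // prints 3.000000
}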
11 changes: 11 additions & 0 deletions mlx/array.h
@@ -333,6 +333,12 @@ class array {
return array_desc_->data_size;
}

/** The offset (in elements) into the underlying buffer the array points to.
**/
size_t offset() const {
return array_desc_->offset;
}

allocator::Buffer& buffer() {
return array_desc_->data->buffer;
}
@@ -413,6 +419,8 @@
// Check if the array is a tracer array
bool is_tracer() const;

void reset_data_ptr();

void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

void set_data(
@@ -477,6 +485,9 @@
// The size in elements of the data buffer the array accesses
size_t data_size;

// Offset into the shared data buffer, in elements
size_t offset{0};

// Contains useful meta data about the array
Flags flags;

7 changes: 7 additions & 0 deletions mlx/backend/webgpu/CMakeLists.txt
@@ -0,0 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)
199 changes: 199 additions & 0 deletions mlx/backend/webgpu/allocator.cpp
@@ -0,0 +1,199 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/webgpu/allocator.h"

#include "mlx/array.h"
#include "mlx/backend/webgpu/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace allocator {

Allocator& allocator() {
return webgpu::allocator();
}

void* Buffer::raw_ptr() {
return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
}

} // namespace allocator

namespace webgpu {

DoubleBuffer::DoubleBuffer(size_t size)
: size_(size), cpu_data_(std::malloc(size)) {}

DoubleBuffer::DoubleBuffer(betann::Device& device, Dtype dtype, size_t size)
: size_(size),
gpu_data_(device.CreateBuffer(
size * gpu_size_factor(dtype),
betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}

DoubleBuffer::~DoubleBuffer() {
std::free(cpu_data_);
}

WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}

Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) {
return Buffer(new DoubleBuffer(size));
}

void WgpuAllocator::free(Buffer buffer) {
delete static_cast<DoubleBuffer*>(buffer.ptr());
}

size_t WgpuAllocator::size(Buffer buffer) const {
return static_cast<DoubleBuffer*>(buffer.ptr())->size();
}

void WgpuAllocator::ensure_cpu_data(array& arr, const void* data) {
auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
if (dbuf->cpu_data() || dbuf->size() == 0)
return;
void* cpu_data = std::malloc(dbuf->size());
size_t num_elements = dbuf->size() / arr.itemsize();
switch (arr.dtype()) {
case int32:
case uint32:
case float16:
case float32:
std::memcpy(cpu_data, data, dbuf->size());
break;
case bool_:
std::transform(
static_cast<const uint32_t*>(data),
static_cast<const uint32_t*>(data) + num_elements,
static_cast<bool*>(cpu_data),
[](uint32_t e) { return static_cast<bool>(e); });
break;
case uint8:
std::transform(
static_cast<const uint32_t*>(data),
static_cast<const uint32_t*>(data) + num_elements,
static_cast<uint8_t*>(cpu_data),
[](uint32_t e) { return static_cast<uint8_t>(e); });
break;
case uint16:
std::transform(
static_cast<const uint32_t*>(data),
static_cast<const uint32_t*>(data) + num_elements,
static_cast<uint16_t*>(cpu_data),
[](uint32_t e) { return static_cast<uint16_t>(e); });
break;
case int8:
std::transform(
static_cast<const int32_t*>(data),
static_cast<const int32_t*>(data) + num_elements,
static_cast<int8_t*>(cpu_data),
[](int32_t e) { return static_cast<int8_t>(e); });
break;
case int16:
std::transform(
static_cast<const int32_t*>(data),
static_cast<const int32_t*>(data) + num_elements,
static_cast<int16_t*>(cpu_data),
[](int32_t e) { return static_cast<int16_t>(e); });
break;
default:
throw_unsupported_dtype_error(arr.dtype());
}
dbuf->set_cpu_data(cpu_data);
}

void WgpuAllocator::ensure_gpu_data(array& arr) {
auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
if (dbuf->gpu_data() || dbuf->size() == 0)
return;
size_t num_elements = dbuf->size() / arr.itemsize();
switch (arr.dtype()) {
case int32:
case uint32:
case float16:
case float32:
dbuf->set_gpu_data(
device_.CreateBufferFromData(dbuf->cpu_data(), dbuf->size()));
break;
case bool_:
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
static_cast<bool*>(dbuf->cpu_data()), num_elements));
break;
case uint8:
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
static_cast<uint8_t*>(dbuf->cpu_data()), num_elements));
break;
case uint16:
dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
static_cast<uint16_t*>(dbuf->cpu_data()), num_elements));
break;
case int8:
dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
static_cast<int8_t*>(dbuf->cpu_data()), num_elements));
break;
case int16:
dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
static_cast<int16_t*>(dbuf->cpu_data()), num_elements));
break;
default:
throw_unsupported_dtype_error(arr.dtype());
}
}

Buffer WgpuAllocator::malloc_gpu(array& arr) {
return malloc_gpu(arr, arr.nbytes());
}

Buffer WgpuAllocator::malloc_gpu(array& arr, size_t size) {
return Buffer(new DoubleBuffer(device_, arr.dtype(), size));
}

WgpuAllocator& allocator() {
static WgpuAllocator allocator_;
return allocator_;
}

betann::Device& device(mlx::core::Device) {
static betann::Device device;
return device;
}

betann::Device& device(array& arr) {
return device(arr.primitive().device());
}

} // namespace webgpu

namespace metal {

size_t get_active_memory() {
return 0;
}
size_t get_peak_memory() {
return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
return 0;
}
size_t set_memory_limit(size_t, bool) {
return 0;
}
size_t set_cache_limit(size_t) {
return 0;
}
size_t set_wired_limit(size_t) {
return 0;
}

std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
throw std::runtime_error("[webgpu::device_info] Not implemented");
}

void clear_cache() {}

} // namespace metal

} // namespace mlx::core
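One detail worth calling out in the conversions above: 32-bit dtypes and float16 are memcpy'd between host and device unchanged, while bool and the 8/16-bit integer dtypes are widened to 32-bit elements for the GPU copy, because WGSL has no 8-bit or 16-bit integer types. The gpu_size_factor() helper used when sizing GPU buffers comes from mlx/backend/webgpu/utils.h (included above but not part of this diff); judging from those conversions it presumably scales a byte count by the widening ratio, roughly:

// Hypothetical reconstruction, not the PR's actual utils.h code: the factor
// by which a dtype's byte size grows when its elements are widened for GPU
// storage. Mirrors the switch-on-dtype style used in this file.
size_t gpu_size_factor(Dtype dtype) {
  switch (dtype) {
    case bool_:
    case uint8:
    case int8:
      return 4;  // 1-byte elements become 4-byte GPU elements
    case uint16:
    case int16:
      return 2;  // 2-byte elements become 4-byte GPU elements
    default:
      return 1;  // 32-bit types and float16 are stored as-is
  }
}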
77 changes: 77 additions & 0 deletions mlx/backend/webgpu/allocator.h
@@ -0,0 +1,77 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/device.h"

#include <betann/betann.h>

namespace mlx::core {
class array;
struct Dtype;
} // namespace mlx::core

namespace mlx::core::webgpu {

using allocator::Buffer;

// Holds data for both CPU and GPU.
class DoubleBuffer {
public:
// Allocates memory on the CPU.
explicit DoubleBuffer(size_t size);
// Allocates memory on the GPU.
DoubleBuffer(betann::Device& device, Dtype dtype, size_t size);

~DoubleBuffer();

void set_cpu_data(void* data) {
assert(!cpu_data_);
cpu_data_ = data;
}
void set_gpu_data(betann::Buffer buffer) {
gpu_data_ = std::move(buffer);
}

void* cpu_data() const {
return cpu_data_;
}
const betann::Buffer& gpu_data() const {
return gpu_data_;
}

size_t size() const {
return size_;
}

private:
size_t size_;
void* cpu_data_ = nullptr;
betann::Buffer gpu_data_;
};

class WgpuAllocator : public allocator::Allocator {
public:
Buffer malloc(size_t size, bool allow_swap) override;
void free(Buffer buffer) override;
size_t size(Buffer buffer) const override;

void ensure_cpu_data(array& arr, const void* data);
void ensure_gpu_data(array& arr);
Buffer malloc_gpu(array& arr);
Buffer malloc_gpu(array& arr, size_t size);

private:
WgpuAllocator();
friend WgpuAllocator& allocator();

betann::Device& device_;
};

WgpuAllocator& allocator();

betann::Device& device(mlx::core::Device);
betann::Device& device(array& arr);

} // namespace mlx::core::webgpu
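The header spells out the design: every allocation is a DoubleBuffer that may hold a CPU copy, a GPU copy, or both, and WgpuAllocator materializes the missing side on demand. A minimal, dependency-free sketch of that contract (MockDoubleBuffer and the ensure_* helpers below are illustrative stand-ins, not part of the PR):

#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <vector>

// std::vector<char> stands in for the betann::Buffer device handle.
struct MockDoubleBuffer {
  size_t size = 0;        // logical size in bytes
  void* cpu = nullptr;    // host copy, may be absent
  std::vector<char> gpu;  // device copy, may be absent
};

// Like ensure_gpu_data for memcpy-compatible dtypes: upload only if no
// device copy exists yet.
void ensure_gpu(MockDoubleBuffer& b) {
  if (!b.gpu.empty() || b.size == 0) return;
  auto* src = static_cast<char*>(b.cpu);
  b.gpu.assign(src, src + b.size);
}

// Like ensure_cpu_data: download only if no host copy exists yet.
void ensure_cpu(MockDoubleBuffer& b) {
  if (b.cpu || b.size == 0) return;
  b.cpu = std::malloc(b.size);
  std::memcpy(b.cpu, b.gpu.data(), b.size);
}

int main() {
  MockDoubleBuffer b;
  b.size = 4;
  b.cpu = std::malloc(4);
  std::memcpy(b.cpu, "mlx", 4);
  ensure_gpu(b);      // one-time upload
  std::free(b.cpu);
  b.cpu = nullptr;    // drop the host copy
  ensure_cpu(b);      // lazy download
  assert(std::memcmp(b.cpu, "mlx", 4) == 0);
  std::free(b.cpu);
}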