Add webgpu backend
zcbenz committed Feb 17, 2025
1 parent 1762793 commit c0900db
Showing 11 changed files with 1,192 additions and 0 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
    endif()
  endif()

  if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
    message(FATAL_ERROR "Can not build both webgpu and metal backends.")
  endif()

else()
  set(MLX_BUILD_METAL OFF)
  message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(MLX_BUILD_WEBGPU)
  FetchContent_Declare(
    betann
    GIT_REPOSITORY https://github.com/frost-beta/betann.git
    GIT_TAG db2d5c9bddb75d0d67f675a68ea79ae0fcba723e
    EXCLUDE_FROM_ALL)
  set(BETANN_BUILD_TESTS OFF)
  FetchContent_MakeAvailable(betann)
  target_link_libraries(mlx PRIVATE betann)
endif()

if(WIN32)
  if(MSVC)
    # GGUF does not build with MSVC.
2 changes: 2 additions & 0 deletions mlx/CMakeLists.txt
@@ -42,6 +42,8 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/io)

if(MLX_BUILD_METAL)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
elseif(MLX_BUILD_WEBGPU)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
else()
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
9 changes: 9 additions & 0 deletions mlx/array.cpp
@@ -132,6 +132,13 @@ bool array::is_tracer() const {
      detail::retain_graph();
}

void array::reset_data_ptr() {
  void* data_ptr = buffer().raw_ptr();
  auto char_offset = sizeof(char) * itemsize() * array_desc_->offset;
  array_desc_->data_ptr =
      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
}

void array::set_data(allocator::Buffer buffer, Deleter d) {
  array_desc_->data = std::make_shared<Data>(buffer, d);
  array_desc_->data_ptr = buffer.raw_ptr();
@@ -165,6 +172,7 @@ void array::copy_shared_buffer(
  array_desc_->strides = strides;
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  array_desc_->offset = offset;
  auto char_offset = sizeof(char) * itemsize() * offset;
  array_desc_->data_ptr = static_cast<void*>(
      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
@@ -184,6 +192,7 @@ void array::move_shared_buffer(
  array_desc_->strides = strides;
  array_desc_->flags = flags;
  array_desc_->data_size = data_size;
  array_desc_->offset = offset;
  auto char_offset = sizeof(char) * itemsize() * offset;
  auto data_ptr = other.array_desc_->data_ptr;
  other.array_desc_->data_ptr = nullptr;
11 changes: 11 additions & 0 deletions mlx/array.h
@@ -333,6 +333,12 @@ class array {
    return array_desc_->data_size;
  }

  /** The offset (in elements) of the underlying buffer the array points to.
   **/
  size_t offset() const {
    return array_desc_->offset;
  }

  allocator::Buffer& buffer() {
    return array_desc_->data->buffer;
  }
@@ -413,6 +419,8 @@
  // Check if the array is a tracer array
  bool is_tracer() const;

  void reset_data_ptr();

  void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

  void set_data(
@@ -477,6 +485,9 @@
    // The size in elements of the data buffer the array accesses
    size_t data_size;

    // Offset from the shared data in elements
    size_t offset{0};

    // Contains useful meta data about the array
    Flags flags;

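The array changes above record each shared-buffer view's element offset and add reset_data_ptr(), which recomputes the cached data pointer from the buffer's current base pointer, presumably because the WebGPU backend can replace the host allocation after views have already been created. The following standalone C++ sketch uses hypothetical names (ViewDesc, reset_view_ptr); it is an illustration of the pointer arithmetic, not MLX code.

#include <cstddef>

// Standalone sketch, not part of this commit: a view records its offset in
// elements so its data pointer can be recomputed from whatever base pointer
// the shared buffer currently exposes.
struct ViewDesc {
  void* data_ptr = nullptr; // cached pointer into the shared buffer
  size_t offset = 0;        // offset from the buffer start, in elements
  size_t itemsize = 4;      // bytes per element
};

// Same arithmetic as array::reset_data_ptr(): base plus itemsize * offset bytes.
inline void reset_view_ptr(ViewDesc& desc, void* base) {
  desc.data_ptr = static_cast<void*>(
      static_cast<char*>(base) + desc.itemsize * desc.offset);
}
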
7 changes: 7 additions & 0 deletions mlx/backend/webgpu/CMakeLists.txt
@@ -0,0 +1,7 @@
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/../no_metal/event.cpp)
199 changes: 199 additions & 0 deletions mlx/backend/webgpu/allocator.cpp
@@ -0,0 +1,199 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/webgpu/allocator.h"

#include "mlx/array.h"
#include "mlx/backend/webgpu/utils.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace allocator {

Allocator& allocator() {
  return webgpu::allocator();
}

void* Buffer::raw_ptr() {
  return static_cast<webgpu::DoubleBuffer*>(ptr_)->cpu_data();
}

} // namespace allocator

namespace webgpu {

DoubleBuffer::DoubleBuffer(size_t size)
    : size_(size), cpu_data_(std::malloc(size)) {}

DoubleBuffer::DoubleBuffer(betann::Device& device, Dtype dtype, size_t size)
    : size_(size),
      gpu_data_(device.CreateBuffer(
          size * gpu_size_factor(dtype),
          betann::BufferUsage::Storage | betann::BufferUsage::CopySrc)) {}

DoubleBuffer::~DoubleBuffer() {
  std::free(cpu_data_);
}

WgpuAllocator::WgpuAllocator() : device_(webgpu::device(Device::gpu)) {}

Buffer WgpuAllocator::malloc(size_t size, bool allow_swap) {
  return Buffer(new DoubleBuffer(size));
}

void WgpuAllocator::free(Buffer buffer) {
  delete static_cast<DoubleBuffer*>(buffer.ptr());
}

size_t WgpuAllocator::size(Buffer buffer) const {
  return static_cast<DoubleBuffer*>(buffer.ptr())->size();
}

void WgpuAllocator::ensure_cpu_data(array& arr, const void* data) {
  auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
  if (dbuf->cpu_data() || dbuf->size() == 0)
    return;
  void* cpu_data = std::malloc(dbuf->size());
  size_t num_elements = dbuf->size() / arr.itemsize();
  switch (arr.dtype()) {
    case int32:
    case uint32:
    case float16:
    case float32:
      std::memcpy(cpu_data, data, dbuf->size());
      break;
    case bool_:
      std::transform(
          static_cast<const uint32_t*>(data),
          static_cast<const uint32_t*>(data) + num_elements,
          static_cast<bool*>(cpu_data),
          [](uint32_t e) { return static_cast<bool>(e); });
      break;
    case uint8:
      std::transform(
          static_cast<const uint32_t*>(data),
          static_cast<const uint32_t*>(data) + num_elements,
          static_cast<uint8_t*>(cpu_data),
          [](uint32_t e) { return static_cast<uint8_t>(e); });
      break;
    case uint16:
      std::transform(
          static_cast<const uint32_t*>(data),
          static_cast<const uint32_t*>(data) + num_elements,
          static_cast<uint16_t*>(cpu_data),
          [](uint32_t e) { return static_cast<uint16_t>(e); });
      break;
    case int8:
      std::transform(
          static_cast<const int32_t*>(data),
          static_cast<const int32_t*>(data) + num_elements,
          static_cast<int8_t*>(cpu_data),
          [](int32_t e) { return static_cast<int8_t>(e); });
      break;
    case int16:
      std::transform(
          static_cast<const int32_t*>(data),
          static_cast<const int32_t*>(data) + num_elements,
          static_cast<int16_t*>(cpu_data),
          [](int32_t e) { return static_cast<int16_t>(e); });
      break;
    default:
      throw_unsupported_dtype_error(arr.dtype());
  }
  dbuf->set_cpu_data(cpu_data);
}

void WgpuAllocator::ensure_gpu_data(array& arr) {
  auto* dbuf = static_cast<DoubleBuffer*>(arr.buffer().ptr());
  if (dbuf->gpu_data() || dbuf->size() == 0)
    return;
  size_t num_elements = dbuf->size() / arr.itemsize();
  switch (arr.dtype()) {
    case int32:
    case uint32:
    case float16:
    case float32:
      dbuf->set_gpu_data(
          device_.CreateBufferFromData(dbuf->cpu_data(), dbuf->size()));
      break;
    case bool_:
      dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
          static_cast<bool*>(dbuf->cpu_data()), num_elements));
      break;
    case uint8:
      dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
          static_cast<uint8_t*>(dbuf->cpu_data()), num_elements));
      break;
    case uint16:
      dbuf->set_gpu_data(device_.CreateBufferTransformTo<uint32_t>(
          static_cast<uint16_t*>(dbuf->cpu_data()), num_elements));
      break;
    case int8:
      dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
          static_cast<int8_t*>(dbuf->cpu_data()), num_elements));
      break;
    case int16:
      dbuf->set_gpu_data(device_.CreateBufferTransformTo<int32_t>(
          static_cast<int16_t*>(dbuf->cpu_data()), num_elements));
      break;
    default:
      throw_unsupported_dtype_error(arr.dtype());
  }
}

Buffer WgpuAllocator::malloc_gpu(array& arr) {
  return malloc_gpu(arr, arr.nbytes());
}

Buffer WgpuAllocator::malloc_gpu(array& arr, size_t size) {
  return Buffer(new DoubleBuffer(device_, arr.dtype(), size));
}

WgpuAllocator& allocator() {
  static WgpuAllocator allocator_;
  return allocator_;
}

betann::Device& device(mlx::core::Device) {
  static betann::Device device;
  return device;
}

betann::Device& device(array& arr) {
  return device(arr.primitive().device());
}

} // namespace webgpu

namespace metal {

size_t get_active_memory() {
  return 0;
}
size_t get_peak_memory() {
  return 0;
}
void reset_peak_memory() {}
size_t get_cache_memory() {
  return 0;
}
size_t set_memory_limit(size_t, bool) {
  return 0;
}
size_t set_cache_limit(size_t) {
  return 0;
}
size_t set_wired_limit(size_t) {
  return 0;
}

std::unordered_map<std::string, std::variant<std::string, size_t>>
device_info() {
  throw std::runtime_error("[webgpu::device_info] Not implemented");
};

void clear_cache() {}

} // namespace metal

} // namespace mlx::core
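
One detail worth noting in ensure_cpu_data() and ensure_gpu_data() above: bool, 8-bit, and 16-bit elements are widened to 32-bit values before upload and narrowed back after download, presumably because WGSL storage buffers have no 8- or 16-bit integer element types. Below is a minimal standalone C++ sketch of the widening step; widen_to_u32 is a hypothetical helper for illustration, not part of the betann API.

#include <algorithm>
#include <cstdint>
#include <vector>

// Widens 1-byte host elements to the 4-byte elements the GPU buffer stores,
// mirroring the std::transform calls in ensure_gpu_data() above.
std::vector<uint32_t> widen_to_u32(const std::vector<uint8_t>& host) {
  std::vector<uint32_t> gpu_elems(host.size());
  std::transform(host.begin(), host.end(), gpu_elems.begin(),
                 [](uint8_t e) { return static_cast<uint32_t>(e); });
  return gpu_elems;
}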
77 changes: 77 additions & 0 deletions mlx/backend/webgpu/allocator.h
@@ -0,0 +1,77 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/allocator.h"
#include "mlx/device.h"

#include <betann/betann.h>

namespace mlx::core {
class array;
struct Dtype;
} // namespace mlx::core

namespace mlx::core::webgpu {

using allocator::Buffer;

// Holds data for both CPU and GPU.
class DoubleBuffer {
 public:
  // Allocates memory in CPU.
  explicit DoubleBuffer(size_t size);
  // Allocates memory in GPU.
  DoubleBuffer(betann::Device& device, Dtype dtype, size_t size);

  ~DoubleBuffer();

  void set_cpu_data(void* data) {
    assert(!cpu_data_);
    cpu_data_ = data;
  }
  void set_gpu_data(betann::Buffer buffer) {
    gpu_data_ = std::move(buffer);
  }

  void* cpu_data() const {
    return cpu_data_;
  }
  const betann::Buffer& gpu_data() const {
    return gpu_data_;
  }

  size_t size() const {
    return size_;
  }

 private:
  size_t size_;
  void* cpu_data_ = nullptr;
  betann::Buffer gpu_data_;
};

class WgpuAllocator : public allocator::Allocator {
 public:
  Buffer malloc(size_t size, bool allow_swap) override;
  void free(Buffer buffer) override;
  size_t size(Buffer buffer) const override;

  void ensure_cpu_data(array& arr, const void* data);
  void ensure_gpu_data(array& arr);
  Buffer malloc_gpu(array& arr);
  Buffer malloc_gpu(array& arr, size_t size);

 private:
  WgpuAllocator();
  friend WgpuAllocator& allocator();

  betann::Device& device_;
};

WgpuAllocator& allocator();

betann::Device& device(mlx::core::Device);
betann::Device& device(array& arr);

} // namespace mlx::core::webgpu
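
The GPU-side DoubleBuffer constructor scales the byte size by gpu_size_factor(dtype), which comes from utils.h and is not shown on this page. Judging from the widening behavior in allocator.cpp, the factor is presumably the per-element growth when small dtypes are stored as 32-bit GPU values; the C++ sketch below is an assumption about that helper, not its actual implementation, and ElemType is a stand-in for mlx::core::Dtype.

#include <cstddef>

// Assumed behavior only (gpu_size_factor itself is not part of this diff):
// 1-byte types grow 4x and 2-byte integer types grow 2x when widened to
// 32-bit GPU elements; 32-bit types, and float16 which is uploaded as-is
// above, need no scaling.
enum class ElemType { Bool, Int8, Uint8, Int16, Uint16, Float16, Int32, Uint32, Float32 };

constexpr size_t gpu_size_factor_sketch(ElemType t) {
  switch (t) {
    case ElemType::Bool:
    case ElemType::Int8:
    case ElemType::Uint8:
      return 4;
    case ElemType::Int16:
    case ElemType::Uint16:
      return 2;
    default:
      return 1;
  }
}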
(The remaining changed files are not shown on this page.)
