
Commit

Add webgpu backend with binary ops support
zcbenz committed Jan 24, 2025
1 parent e6a7ab9 commit c693ad8
Showing 12 changed files with 554 additions and 91 deletions.
16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -16,6 +16,7 @@ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_WEBGPU "Build webgpu backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
@@ -52,6 +53,10 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
endif()
endif()

if(MLX_BUILD_WEBGPU AND MLX_BUILD_METAL)
message(FATAL_ERROR "Can not build both webgpu and metal backends.")
endif()

else()
set(MLX_BUILD_METAL OFF)
message(WARNING "MLX is prioritised for Apple silicon systems using macOS.")
@@ -114,6 +119,17 @@ elseif(MLX_BUILD_METAL)
target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(MLX_BUILD_WEBGPU)
FetchContent_Declare(
betann
GIT_REPOSITORY https://github.com/frost-beta/betann.git
GIT_TAG ba11e5367a08bd28c6095e95be7d55710390354d
EXCLUDE_FROM_ALL)
set(BETANN_BUILD_TESTS OFF)
FetchContent_MakeAvailable(betann)
target_link_libraries(mlx PRIVATE betann)
endif()

if(WIN32)
if(MSVC)
# GGUF does not build with MSVC.
94 changes: 3 additions & 91 deletions examples/cpp/tutorial.cpp
@@ -1,99 +1,11 @@
// Copyright © 2023 Apple Inc.

#include <cassert>
#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

void array_basics() {
// Make a scalar array:
mx::array x(1.0);

// Get the value out of it:
auto s = x.item<float>();
assert(s == 1.0);

// Scalars have a size of 1:
size_t size = x.size();
assert(size == 1);

// Scalars have 0 dimensions:
int ndim = x.ndim();
assert(ndim == 0);

// The shape should be an empty vector:
auto shape = x.shape();
assert(shape.empty());

// The datatype should be float32:
auto dtype = x.dtype();
assert(dtype == mx::float32);

// Specify the dtype when constructing the array:
x = mx::array(1, mx::int32);
assert(x.dtype() == mx::int32);
x.item<int>(); // OK
// x.item<float>(); // Undefined!

// Make a multidimensional array:
x = mx::array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
// mlx is row-major by default so the first row of this array
// is [1.0, 2.0] and the second row is [3.0, 4.0]

// Make an array of shape {2, 2} filled with ones:
auto y = mx::ones({2, 2});

// Pointwise add x and y:
auto z = mx::add(x, y);

// Same thing:
z = x + y;

// mlx is lazy by default. At this point `z` only
// has a shape and a type but no actual data:
assert(z.dtype() == mx::float32);
assert(z.shape(0) == 2);
assert(z.shape(1) == 2);

// To actually run the computation you must evaluate `z`.
// Under the hood, mlx records operations in a graph.
// The variable `z` is a node in the graph which points to its operation
// and inputs. When `eval` is called on an array (or arrays), the array and
// all of its dependencies are recursively evaluated to produce the result.
// Once an array is evaluated, it has data and is detached from its inputs.
mx::eval(z);

// Of course the array can still be an input to other operations. You can
// even call eval on the array again, this will just be a no-op:
mx::eval(z); // no-op

// Some functions or methods on arrays implicitly evaluate them. For example
// accessing a value in an array or printing the array implicitly evaluate it:
z = mx::ones({1});
z.item<float>(); // implicit evaluation

z = mx::ones({2, 2});
std::cout << z << std::endl; // implicit evaluation
}

void automatic_differentiation() {
auto fn = [](mx::array x) { return mx::square(x); };

// Computing the derivative function of a function
auto grad_fn = mx::grad(fn);
// Call grad_fn on the input to get the derivative
auto x = mx::array(1.5);
auto dfdx = grad_fn(x);
// dfdx is 2 * x

// Get the second derivative by composing grad with grad
auto d2fdx2 = mx::grad(mx::grad(fn))(x);
// d2fdx2 is 2
}

int main() {
array_basics();
automatic_differentiation();
mx::array x({1.0, 2.0, 3.0});
mx::array y({4.0, 5.0, 6.0});
std::cout << x + y << std::endl;
}
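
The slimmed-down tutorial above relies on the implicit evaluation performed by the stream insertion at the end. For reference, the sketch below (not part of this commit) restates the lazy-evaluation behavior the removed comments described, using only calls that appeared in the original tutorial (operator+, mx::eval, operator<<):

#include <iostream>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  mx::array a({1.0f, 2.0f, 3.0f});
  mx::array b({4.0f, 5.0f, 6.0f});
  auto z = a + b;               // lazy: z has a shape and dtype but no data yet
  mx::eval(z);                  // explicitly run the computation
  std::cout << z << std::endl;  // printing alone would also evaluate implicitly
}
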
2 changes: 2 additions & 0 deletions mlx/CMakeLists.txt
@@ -47,6 +47,8 @@ endif()

if(MLX_BUILD_METAL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/metal)
elseif(MLX_BUILD_WEBGPU)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/webgpu)
else()
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/no_metal)
endif()
23 changes: 23 additions & 0 deletions mlx/array.cpp
@@ -109,6 +109,29 @@ bool array::is_tracer() const {
detail::retain_graph();
}

void array::set_gpu_data(allocator::Buffer buffer, Deleter d) {
array_desc_->gpu_data = std::make_shared<Data>(buffer, d);
}

void array::set_cpu_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
}

void array::set_data_info(size_t data_size, Strides strides, Flags flags) {
array_desc_->data_size = data_size;
array_desc_->strides = std::move(strides);
array_desc_->flags = flags;
}

void array::set_data_info() {
array_desc_->data_size = size();
array_desc_->flags.contiguous = true;
array_desc_->flags.row_contiguous = true;
auto max_dim = std::max_element(shape().begin(), shape().end());
array_desc_->flags.col_contiguous = size() <= 1 || size() == *max_dim;
}

void array::set_data(allocator::Buffer buffer, Deleter d) {
array_desc_->data = std::make_shared<Data>(buffer, d);
array_desc_->data_ptr = buffer.raw_ptr();
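
For reference, the col_contiguous heuristic in the new zero-argument set_data_info() treats the array as column-contiguous only when it is a scalar or all of its elements lie along a single axis, so that row-major and column-major order coincide. A standalone illustration of that check, not taken from this commit:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Mirrors the flag logic in array::set_data_info(): densely packed data is
// column-contiguous when the array is a scalar or effectively one-dimensional.
bool col_contiguous_like(const std::vector<int>& shape) {
  std::size_t size = 1;
  for (int d : shape) size *= d;
  auto max_dim = std::max_element(shape.begin(), shape.end());
  return size <= 1 ||
      (max_dim != shape.end() && size == static_cast<std::size_t>(*max_dim));
}

int main() {
  assert(col_contiguous_like({}));       // scalar
  assert(col_contiguous_like({5}));      // 1-D vector
  assert(col_contiguous_like({1, 4}));   // row vector is also column-contiguous
  assert(!col_contiguous_like({2, 3}));  // a genuine 2-D matrix is not
}
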
11 changes: 11 additions & 0 deletions mlx/array.h
@@ -332,6 +332,10 @@ class array {
return allocator::allocator().size(buffer());
}

const std::shared_ptr<Data>& gpu_data_shared_ptr() const {
return array_desc_->gpu_data;
}

// Return a copy of the shared pointer
// to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const {
@@ -401,6 +405,11 @@ class array {
// Check if the array is a tracer array
bool is_tracer() const;

void set_gpu_data(allocator::Buffer buffer, Deleter d);
void set_cpu_data(allocator::Buffer buffer, Deleter d = allocator::free);
void set_data_info();
void set_data_info(size_t data_size, Strides strides, Flags flags);

void set_data(allocator::Buffer buffer, Deleter d = allocator::free);

void set_data(
@@ -455,6 +464,8 @@ class array {
// and should not be detached from the graph
bool is_tracer{false};

std::shared_ptr<Data> gpu_data;

// This is a shared pointer so that *different* arrays
// can share the underlying data buffer.
std::shared_ptr<Data> data;
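
The new setters split what set_data() did into separate buffer attachment (CPU or GPU) and metadata updates. A hypothetical sketch of how a host-side primitive might use them, assuming they are accessible to backend code and mirror set_data(); the webgpu primitives added by this commit are not shown above:

#include "mlx/allocator.h"
#include "mlx/array.h"

namespace mx = mlx::core;

// Hypothetical helper, not from this commit: allocate a dense host buffer for
// `out` and mark it as contiguous, row-contiguous data of out.size() elements.
void alloc_contiguous_cpu_output(mx::array& out) {
  out.set_cpu_data(mx::allocator::malloc(out.nbytes()));  // allocator::free is the default deleter
  out.set_data_info();
}
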
7 changes: 7 additions & 0 deletions mlx/backend/webgpu/CMakeLists.txt
@@ -0,0 +1,7 @@
target_sources(
mlx
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp)
16 changes: 16 additions & 0 deletions mlx/backend/webgpu/allocator.cpp
@@ -0,0 +1,16 @@
// Copyright © 2023 Apple Inc.

#include "mlx/allocator.h"

namespace mlx::core::allocator {

Allocator& allocator() {
static CommonAllocator allocator_;
return allocator_;
}

void* Buffer::raw_ptr() {
return static_cast<size_t*>(ptr_) + 1;
}

} // namespace mlx::core::allocator
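
raw_ptr() here skips one size_t before returning the data pointer, which implies each allocation carries a small header. Assuming CommonAllocator records the requested size in that header (which is what the skipped size_t suggests), the layout looks roughly like the standalone sketch below; this is an illustration, not code from this commit:

#include <cstdlib>

// Assumed layout: [ size_t size | user data ... ]
struct SketchBuffer {
  void* ptr;  // points at the size header
  void* raw_ptr() { return static_cast<std::size_t*>(ptr) + 1; }  // skip the header, as above
  std::size_t size() const { return *static_cast<std::size_t*>(ptr); }  // read the stored size back
};

SketchBuffer sketch_malloc(std::size_t size) {
  void* ptr = std::malloc(size + sizeof(std::size_t));
  *static_cast<std::size_t*>(ptr) = size;  // record the size in the header
  return SketchBuffer{ptr};
}
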
13 changes: 13 additions & 0 deletions mlx/backend/webgpu/device.cpp
@@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/webgpu/device.h"

namespace mlx::core::metal {

betann::Device& device(mlx::core::Device) {
// FIXME(zcbenz): Make it live longer than StreamThread.
static betann::Device* device = new betann::Device;
return *device;
}

} // namespace mlx::core::metal
13 changes: 13 additions & 0 deletions mlx/backend/webgpu/device.h
@@ -0,0 +1,13 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include <betann/betann.h>

#include "mlx/device.h"

namespace mlx::core::metal {

betann::Device& device(mlx::core::Device);

} // namespace mlx::core::metal
47 changes: 47 additions & 0 deletions mlx/backend/webgpu/event.cpp
@@ -0,0 +1,47 @@
// Copyright © 2024 Apple Inc.

#include "mlx/event.h"

#include <condition_variable>
#include <mutex>

namespace mlx::core {

struct EventCounter {
uint64_t value{0};
std::mutex mtx;
std::condition_variable cv;
};

Event::Event(const Stream& stream) : stream_(stream) {
auto dtor = [](void* ptr) { delete static_cast<EventCounter*>(ptr); };
event_ = std::shared_ptr<void>(new EventCounter{}, dtor);
}

void Event::wait() {
auto ec = static_cast<EventCounter*>(raw_event().get());
std::unique_lock<std::mutex> lk(ec->mtx);
if (ec->value >= value()) {
return;
}
ec->cv.wait(lk, [value = value(), ec] { return ec->value >= value; });
}

void Event::signal() {
auto ec = static_cast<EventCounter*>(raw_event().get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
ec->value = value();
}
ec->cv.notify_all();
}

bool Event::is_signaled() const {
auto ec = static_cast<EventCounter*>(raw_event().get());
{
std::lock_guard<std::mutex> lk(ec->mtx);
return (ec->value > value());
}
}

} // namespace mlx::core
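
This Event implementation is a plain counter guarded by a mutex and condition variable: wait() blocks until the counter reaches the event's target value, and signal() raises the counter and wakes waiters. A standalone sketch of the same pattern (not the mlx Event API):

#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>

struct Counter {
  uint64_t value{0};
  std::mutex mtx;
  std::condition_variable cv;
};

int main() {
  Counter c;
  const uint64_t target = 1;

  std::thread waiter([&] {
    // Corresponds to Event::wait(): block until the counter reaches target.
    std::unique_lock<std::mutex> lk(c.mtx);
    c.cv.wait(lk, [&] { return c.value >= target; });
    std::cout << "signaled" << std::endl;
  });

  {
    // Corresponds to Event::signal(): bump the counter under the lock...
    std::lock_guard<std::mutex> lk(c.mtx);
    c.value = target;
  }
  c.cv.notify_all();  // ...then wake any waiters.

  waiter.join();
}
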
