Feature/qmm #42

Merged: 8 commits, Jan 25, 2024
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
@@ -34,7 +34,7 @@ jobs:
with:
tool: cargo-nextest

- - name: (linux) install llvmpipe, lavapipe, vulkan sdk, alsa
+ - name: (linux) install lavapipe, vulkan sdk, alsa
if: matrix.os == 'ubuntu-latest'
shell: bash
run: |
1 change: 1 addition & 0 deletions README.md
@@ -3,3 +3,4 @@
### [Discord](https://discord.gg/XFe33KQTG4)

A web-first, cross-platform ML framework.

2 changes: 1 addition & 1 deletion crates/ratchet-core/kernels/qgemm_vec4.wgsl
@@ -28,7 +28,7 @@ struct Meta {
@group(1) @binding(0)
var<uniform> metadata: Meta;

- @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
+ @compute @workgroup_size(8,8,1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
48 changes: 48 additions & 0 deletions crates/ratchet-core/kernels/sgemm_vec2.wgsl
@@ -0,0 +1,48 @@
//Unoptimized, only achieves ~500 GFLOP/s
@group(0) @binding(0)
var<storage, read> A: array<vec2<f32>>;

@group(0) @binding(1)
var<storage, read> B: array<vec2<f32>>;

@group(0) @binding(2)
var<storage, read_write> C: array<vec2<f32>>;

struct Meta {
M: u32,
N: u32,
K: u32,
MD2: u32,
ND2: u32,
KD2: u32,
MD4: u32,
ND4: u32,
KD4: u32,
A_OFFSET: u32,
B_OFFSET: u32,
C_OFFSET: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

@compute @workgroup_size(8,8,1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let a_offset = global_id.z * metadata.A_OFFSET;
let b_offset = global_id.z * metadata.B_OFFSET;
let c_offset = global_id.z * metadata.C_OFFSET;

let cRow = global_id.x;
let cCol = global_id.y;
if (cRow < metadata.M && cCol < metadata.ND2) {
var tmp = vec2<f32>();
for (var k = 0u; k < metadata.KD2; k++) {
let a = A[a_offset + (cRow * metadata.KD2 + k)];
tmp += vec2<f32>(a.x) * B[b_offset + (k * metadata.N + cCol)];
tmp += vec2<f32>(a.y) * B[b_offset + (k * metadata.N + cCol + (1u * metadata.ND2))];
}
C[c_offset + (cRow * metadata.ND2 + cCol)] = tmp;
}
}
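Both new sgemm kernels compute a plain row-major `C = A × B`. In the vec2 variant above, each invocation owns one `vec2` of `C`; because `B` is viewed as `array<vec2<f32>>` with `ND2` (= N / 2) elements per row, the index `k * metadata.N + cCol` lands on row `2k` of `B` and the `+ metadata.ND2` term on row `2k + 1`. As a reference point, a minimal scalar GEMM (a sketch for sanity-checking only, not part of this PR; it ignores the batch offsets) that the vec2 kernel above and the vec4 kernel below should agree with:

```rust
/// Row-major reference GEMM: C[m][n] = sum_k A[m][k] * B[k][n],
/// with A of shape (M, K), B of shape (K, N), C of shape (M, N).
fn sgemm_ref(a: &[f32], b: &[f32], c: &mut [f32], m: usize, n: usize, k: usize) {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    for row in 0..m {
        for col in 0..n {
            let mut acc = 0.0f32;
            for i in 0..k {
                acc += a[row * k + i] * b[i * n + col];
            }
            c[row * n + col] = acc;
        }
    }
}
```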
53 changes: 53 additions & 0 deletions crates/ratchet-core/kernels/sgemm_vec4.wgsl
@@ -0,0 +1,53 @@
//Unoptimized, only achieves ~500 GFLOP/s
@group(0) @binding(0)
var<storage, read> A: array<vec4<f32>>;

@group(0) @binding(1)
var<storage, read> B: array<vec4<f32>>;

@group(0) @binding(2)
var<storage, read_write> C: array<vec4<f32>>;

struct Meta {
M: u32,
N: u32,
K: u32,
MD2: u32,
ND2: u32,
KD2: u32,
MD4: u32,
ND4: u32,
KD4: u32,
A_OFFSET: u32,
B_OFFSET: u32,
C_OFFSET: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

@compute @workgroup_size(8,8,1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let a_offset = global_id.z * metadata.A_OFFSET;
let b_offset = global_id.z * metadata.B_OFFSET;
let c_offset = global_id.z * metadata.C_OFFSET;

let cRow = global_id.x;
let cCol = global_id.y;
if (cRow < metadata.M && cCol < metadata.ND4) {
var tmp = vec4<f32>();
for (var k = 0u; k < metadata.KD4; k++) {
let a = A[a_offset + (cRow * metadata.KD4 + k)];
let b_step = k * metadata.N + cCol; //4 rows per iter
let b_stride = metadata.ND4;

tmp = fma(vec4<f32>(a.x), B[b_offset + b_step], tmp);
tmp = fma(vec4<f32>(a.y), B[b_offset + (b_step + b_stride)], tmp);
tmp = fma(vec4<f32>(a.z), B[b_offset + (b_step + (2u * b_stride))], tmp);
tmp = fma(vec4<f32>(a.w), B[b_offset + (b_step + (3u * b_stride))], tmp);
}
C[c_offset + (cRow * metadata.ND4 + cCol)] = tmp;
}
}
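With the workgroup size now hardcoded to `(8, 8, 1)`, the host has to dispatch `ceil(M / 8)` workgroups along x (rows of `C`), `ceil(ND4 / 8)` along y (`vec4` columns, i.e. N / 4), and one per batch entry along z; the in-kernel bounds check absorbs any padding. A hedged sketch of that calculation (the actual dispatch helper in ratchet is not shown in this diff, so the name and signature here are assumptions):

```rust
/// Hypothetical workgroup-count helper mirroring the kernel's (8, 8, 1) size:
/// x covers rows of C, y covers vec4 columns (N / 4), z covers the batch.
fn sgemm_vec4_workgroups(m: u32, n: u32, batch: u32) -> (u32, u32, u32) {
    const WGS_X: u32 = 8;
    const WGS_Y: u32 = 8;
    let div_ceil = |a: u32, b: u32| (a + b - 1) / b;
    let nd4 = n / 4; // the kernel reads B and C as array<vec4<f32>>
    (div_ceil(m, WGS_X), div_ceil(nd4, WGS_Y), batch)
}
```

For example, M = 100, N = 1000, batch = 1 gives (13, 32, 1) workgroups, i.e. a 104 × 256 grid of invocations guarding the 100 × 250 live outputs.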
22 changes: 7 additions & 15 deletions crates/ratchet-core/src/compiled_op.rs
@@ -1,13 +1,14 @@
use crate::gpu::{
- BindGroupDescriptor, BindGroupEntry, BindGroupLayoutHandle, ComputePipelineHandle,
- GpuBindGroup, WgpuDevice, WorkgroupCount,
+ BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice,
+ WorkgroupCount,
};
use crate::{drvec, rvec, RVec, Tensor};
use derive_new::new;
use wgpu::DynamicOffset;

//Compiled op represents a single kernel invocation
- //TODO: We need to be more general here, enum with encoder.copy_buffer_to_buffer as a COPY
+ //TODO: We need to be more general here, enum with encoder.copy_buffer_to_buffer as a COPY, and
+ //compiledOp as compute
#[derive(Debug, new)]
pub struct CompiledOp {
pipeline_handle: ComputePipelineHandle,
@@ -19,30 +20,21 @@ pub struct CompiledOp {
impl CompiledOp {
const MAX_BINDINGS_PER_GROUP: usize = 4;

+ //TODO: Should return a Result
pub fn create_storage_bind_groups(
srcs: &[&Tensor],
dst: &Tensor,
bind_group_layouts: RVec<BindGroupLayoutHandle>,
device: &WgpuDevice,
) -> RVec<GpuBindGroup> {
- let mut binding_counter: usize = 0;
let mut bind_group_entries = drvec![];

for tensor in srcs.iter().chain(std::iter::once(&dst)) {
- let storage_guard = tensor.storage();
- let storage = storage_guard.as_ref().unwrap();
- let gpu_buf = &storage.try_gpu().unwrap().inner;
- bind_group_entries.push(BindGroupEntry {
- handle: gpu_buf.handle,
- offset: 0,
- size: Some(gpu_buf.size().try_into().unwrap()),
- });
- binding_counter += 1;
+ bind_group_entries.append(&mut tensor.bindings());
}

let mut storage_groups = rvec![];
for (group_index, bind_group_layout) in bind_group_layouts.iter().enumerate() {
- let group_range = Self::group_range(group_index, binding_counter);
+ let group_range = Self::group_range(group_index, bind_group_entries.len());
let entries = bind_group_entries[group_range].into();
let layout = *bind_group_layout;

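The shape of this change: the old code built one `BindGroupEntry` per tensor by hand and counted them, while the new code asks each tensor for its own entries, so a quantized tensor can contribute more than one binding from a single buffer (see the `DType::segments` addition in dtype.rs below). A plausible sketch of what `Tensor::bindings` could look like under that assumption; the real implementation, the `dt()` accessor, and the return container are not shown in this diff:

```rust
// Hypothetical sketch only: one BindGroupEntry per DType segment, so a WQ8
// tensor yields two entries (packed weights + absmax) from one GPU buffer.
impl Tensor {
    pub fn bindings(&self) -> RVec<BindGroupEntry> {
        let storage_guard = self.storage();
        let storage = storage_guard.as_ref().unwrap();
        let gpu_buf = &storage.try_gpu().unwrap().inner;
        self.dt() // assumed DType accessor
            .segments(gpu_buf.size() as usize)
            .iter()
            .map(|segment| BindGroupEntry {
                handle: gpu_buf.handle,
                offset: segment.offset,
                size: segment.size,
            })
            .collect()
    }
}
```

This also explains why `group_range` now takes `bind_group_entries.len()` instead of a per-tensor counter: the number of entries is no longer equal to the number of tensors.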
41 changes: 41 additions & 0 deletions crates/ratchet-core/src/dtype.rs
@@ -1,4 +1,9 @@
use std::num::NonZeroU64;

use half::{bf16, f16};
use wgpu::{BufferAddress, BufferSize};

use crate::{rvec, RVec};

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
pub enum DType {
@@ -9,6 +14,7 @@ pub enum DType {
F32,
I32,
U32,
WQ8, //Packed Q8 (|--4xQ8(u32)--| |--f32--|)
}

impl DType {
@@ -21,7 +27,42 @@ impl DType {
DType::F32 => 4,
DType::I32 => 4,
DType::U32 => 4,
DType::WQ8 => 4, //Only works because they're both 4 bytes
}
}

pub fn segments(&self, total_bytes: usize) -> RVec<BufferSegment> {
match self {
DType::WQ8 => {
let weights_size = total_bytes / 5 * 4;
assert!(weights_size % 256 == 0); //storage buffer alignment
let weights = BufferSegment::new(0, Some(weights_size as u64));

let absmax_size = total_bytes - weights_size;
assert!(absmax_size % 256 == 0); //storage buffer alignment
let absmax = BufferSegment::new(weights_size as u64, Some(absmax_size as u64));
rvec![weights, absmax]
}
_ => {
rvec![BufferSegment::new(0, Some(total_bytes as u64))]
}
}
}
}

#[derive(Debug)]
pub struct BufferSegment {
pub offset: BufferAddress,
pub size: Option<BufferSize>,
}

impl BufferSegment {
pub fn new(offset: BufferAddress, size: Option<u64>) -> Self {
if let Some(size) = size {
assert!(size % 256 == 0); //storage buffer alignment
}
let size = size.map(NonZeroU64::new).unwrap();
Self { offset, size }
}
}

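For WQ8, `segments` encodes the 4:1 byte split used here: the packed Q8 weights take 4/5 of the buffer and the f32 absmax values the remaining 1/5 (which works out to one f32 scale per 16 quantized bytes), with both segments kept 256-byte aligned for storage-buffer binding. A worked check of the arithmetic (a standalone sketch, not the crate's API):

```rust
// Mirrors the WQ8 arm of DType::segments: (weights_bytes, absmax_bytes).
fn wq8_split(total_bytes: usize) -> (usize, usize) {
    let weights = total_bytes / 5 * 4; // packed Q8 weights: 4/5 of the buffer
    let absmax = total_bytes - weights; // f32 absmax values: 1/5 of the buffer
    assert_eq!(weights % 256, 0, "storage buffer alignment");
    assert_eq!(absmax % 256, 0, "storage buffer alignment");
    (weights, absmax)
}

fn main() {
    // 1280 total bytes -> 1024 bytes of weights + 256 bytes of absmax,
    // i.e. 1024 Q8 values and 64 f32 scales (one scale per 16 weights).
    assert_eq!(wq8_split(1280), (1024, 256));
}
```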
1 change: 1 addition & 0 deletions crates/ratchet-core/src/gpu/pools/buffer_pool.rs
@@ -15,6 +15,7 @@ impl BufferDescriptor {
}
}

//All slotmap keys are COPY
slotmap::new_key_type! { pub struct GpuBufferHandle; }

/// A reference-counter baked buffer.
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/gpu/pools/pipeline_pool.rs
@@ -42,7 +42,7 @@ impl ComputePipelinePool {
) -> ComputePipelineHandle {
self.inner.get_or_create(desc, |desc| {
let kernel_key = desc.build_kernel_key();
- let shader = KERNELS.get(kernel_key.as_str()).unwrap();
+ let shader = KERNELS.get(kernel_key.as_str()).expect("Kernel not found");
let label = Some(kernel_key.as_str());

let shader_module_desc = wgpu::ShaderModuleDescriptor {
12 changes: 12 additions & 0 deletions crates/ratchet-core/src/kernels.rs
@@ -22,6 +22,18 @@ lazy_static! {
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl"
),
);
m.insert(
"sgemm_vec2",
include_str!(
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec2.wgsl"
),
);
m.insert(
"sgemm_vec4",
include_str!(
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec4.wgsl"
),
);
m
};
}
4 changes: 1 addition & 3 deletions crates/ratchet-core/src/op.rs
@@ -3,7 +3,7 @@ use std::fmt::Debug;
use encase::internal::WriteInto;
use encase::ShaderType;

- use crate::gpu::{BindGroupLayoutHandle, CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
+ use crate::gpu::{CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
use crate::{Binary, CompiledOp, InvariantError, Matmul, RVec, StorageView, Tensor};

#[derive(Debug, Clone)]
@@ -83,8 +83,6 @@ pub trait Operation: Debug + 'static {
device: &WgpuDevice,
) -> Result<CompiledOp, OperationError>;

- fn storage_layout(&self, device: &WgpuDevice) -> Result<BindGroupLayoutHandle, OperationError>;

fn check_invariants(srcs: &[&Tensor]) -> Result<(), OperationError>;

/// # Output Inference
11 changes: 4 additions & 7 deletions crates/ratchet-core/src/ops/binary.rs
@@ -3,8 +3,8 @@ use encase::ShaderType;

use crate::{
gpu::{
- BindGroupLayoutDescriptor, BindGroupLayoutHandle, ComputePipelineDescriptor, CpuUniform,
- PipelineLayoutDescriptor, WgpuDevice, WorkgroupCount,
+ BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor,
+ WgpuDevice, WorkgroupCount,
},
rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
@@ -61,10 +61,6 @@ impl Operation for Binary {
rvec![&self.lhs, &self.rhs]
}

- fn storage_layout(&self, device: &WgpuDevice) -> Result<BindGroupLayoutHandle, OperationError> {
- Ok(device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::binary())?)
- }

//TODO: we can refactor this into composite methods and share a single `compile` impl on the
//trait
fn compile(
@@ -79,7 +75,8 @@
let offset = uniform.write(&BinaryMeta { M, N })?;
let wgcx = WorkgroupCount::div_ceil(M as _, 64);

- let storage_layout = self.storage_layout(device)?;
+ let storage_layout =
+ device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::binary())?;
let uniform_layout =
device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
let pipeline_layout = device.get_or_create_pipeline_layout(&PipelineLayoutDescriptor {