Skip to content

Commit

Permalink
Vectorize KV-Cache by using Vec4
Browse files Browse the repository at this point in the history
  • Loading branch information
officialcjunior committed Jun 21, 2024
1 parent 0d19508 commit 5bd5740
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions crates/ratchet-core/src/ops/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ impl Cache {
builder: &mut WgslKernelBuilder,
_: bool,
) -> Result<(), OperationError> {
- builder.register_storage("C", BindingMode::ReadWrite, Array::<P>::default());
- builder.register_storage("S", BindingMode::ReadOnly, Array::<P>::default());
- builder.register_storage("D", BindingMode::ReadWrite, Array::<P>::default());
+ builder.register_storage("C", BindingMode::ReadWrite, Array::<vec4<f32>>::default());
+ builder.register_storage("S", BindingMode::ReadOnly, Array::<vec4<f32>>::default());
+ builder.register_storage("D", BindingMode::ReadWrite, Array::<vec4<f32>>::default());

builder.register_uniform();
Ok(())
Expand Down Expand Up @@ -68,29 +68,29 @@ impl Cache {
kernel_builder.write_index_to_offset();

kernel_builder.write_main(wgsl! {
- //Dispatch 1 thread per output element
+ //Dispatch 1 thread per output element (vec4<f32>)
//dst_offset is index into the output buffer (1D)
let x_offset = workgroup_id.x * 64u;
let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
- if (dst_offset >= metadata.dst_numel) {
+ if (dst_offset >= metadata.dst_numel / 4u) {
return;
}
- //Convert 1D offset into 4D index
- var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);
+ // Convert 1D offset into 4D index
+ var dst_index = offsetToNdIndex(dst_offset * 4u, metadata.dst_stride);

let dim = metadata.dim;
if (dst_index[dim] < metadata.cum0) {
//Inside cache, just copy from cache to DST
- let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
+ let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u;
D[dst_offset] = C[src_offset];
return;
}

if (dst_index[dim] < metadata.cum1) {
//Inside src, copy from src to cache and then to DST
- let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
+ let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u;
dst_index[dim] -= metadata.cum0;
- let src_offset = ndIndexToOffset(dst_index, metadata.src_stride);
+ let src_offset = ndIndexToOffset(dst_index, metadata.src_stride) / 4u;
let val = S[src_offset];
C[cache_offset] = val;
D[dst_offset] = val;
Expand Down Expand Up @@ -151,7 +151,12 @@ impl MetaOperation for Cache {
}

/// Selects the per-thread element width for the cache kernel.
///
/// Returns [`KernelElement::Vec4`] when the input tensor's element count is
/// divisible by 4, so each invocation can load/store one `vec4<f32>`;
/// otherwise falls back to [`KernelElement::Scalar`].
///
/// NOTE(review): divisibility is checked on `self.input` only — presumably
/// the cache and destination tensors share the same innermost-dimension
/// divisibility; confirm against the callers that build this op.
fn kernel_element(&self, _dst: &Tensor) -> KernelElement {
    let numel = self.input.shape().numel();
    if numel % 4 == 0 {
        KernelElement::Vec4
    } else {
        KernelElement::Scalar
    }
}

fn calculate_dispatch(&self, dst: &Tensor) -> Result<Workload, OperationError> {
Expand Down

0 comments on commit 5bd5740

Please sign in to comment.