diff --git a/crates/ratchet-core/src/ops/cache.rs b/crates/ratchet-core/src/ops/cache.rs
index 5ff00d7b..f0e1a4b1 100644
--- a/crates/ratchet-core/src/ops/cache.rs
+++ b/crates/ratchet-core/src/ops/cache.rs
@@ -35,9 +35,9 @@ impl Cache {
builder: &mut WgslKernelBuilder,
_: bool,
) -> Result<(), OperationError> {
- builder.register_storage("C", BindingMode::ReadWrite, Array::<P>::default());
- builder.register_storage("S", BindingMode::ReadOnly, Array::<P>::default());
- builder.register_storage("D", BindingMode::ReadWrite, Array::<P>::default());
+ builder.register_storage("C", BindingMode::ReadWrite, Array::>::default());
+ builder.register_storage("S", BindingMode::ReadOnly, Array::>::default());
+ builder.register_storage("D", BindingMode::ReadWrite, Array::>::default());
builder.register_uniform();
Ok(())
@@ -68,29 +68,29 @@ impl Cache {
kernel_builder.write_index_to_offset();
kernel_builder.write_main(wgsl! {
- //Dispatch 1 thread per output element
+ //Dispatch 1 thread per vec4 of the output, i.e. per 4 consecutive output elements
//dst_offset is index into the output buffer (1D)
let x_offset = workgroup_id.x * 64u;
let dst_offset = (workgroup_id.y * num_workgroups.x * 64u) + x_offset + local_invocation_index;
- if (dst_offset >= metadata.dst_numel) {
+ if (dst_offset >= metadata.dst_numel / 4u) {
return;
}
- //Convert 1D offset into 4D index
- var dst_index = offsetToNdIndex(dst_offset, metadata.dst_stride);
+ // Scale the vec4 offset back to a scalar element offset (* 4u), then convert it into a 4D index
+ var dst_index = offsetToNdIndex(dst_offset * 4u, metadata.dst_stride);
let dim = metadata.dim;
if (dst_index[dim] < metadata.cum0) {
//Inside cache, just copy from cache to DST
- let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
+ let src_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u;
D[dst_offset] = C[src_offset];
return;
}
if (dst_index[dim] < metadata.cum1) {
//Inside src, copy from src to cache and then to DST
- let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride);
+ let cache_offset = ndIndexToOffset(dst_index, metadata.cache_stride) / 4u;
dst_index[dim] -= metadata.cum0;
- let src_offset = ndIndexToOffset(dst_index, metadata.src_stride);
+ let src_offset = ndIndexToOffset(dst_index, metadata.src_stride) / 4u;
let val = S[src_offset];
C[cache_offset] = val;
D[dst_offset] = val;
@@ -151,7 +151,12 @@ impl MetaOperation for Cache {
}
fn kernel_element(&self, _dst: &Tensor) -> KernelElement {
- KernelElement::Scalar
+ let numel = self.input.shape().numel();
+ if numel % 4 == 0 {
+ KernelElement::Vec4
+ } else {
+ KernelElement::Scalar
+ }
}
fn calculate_dispatch(&self, dst: &Tensor) -> Result {