From 18f24fb887635f362c281e461e50e604175af90d Mon Sep 17 00:00:00 2001 From: Aswin C Date: Sun, 23 Jun 2024 18:59:57 +0530 Subject: [PATCH] Use `Array::

>` while registering storage and add more conditions for `KernelElement` --- crates/ratchet-core/src/ops/cache.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/crates/ratchet-core/src/ops/cache.rs b/crates/ratchet-core/src/ops/cache.rs index f0e1a4b1..126509f3 100644 --- a/crates/ratchet-core/src/ops/cache.rs +++ b/crates/ratchet-core/src/ops/cache.rs @@ -35,9 +35,9 @@ impl Cache { builder: &mut WgslKernelBuilder, _: bool, ) -> Result<(), OperationError> { - builder.register_storage("C", BindingMode::ReadWrite, Array::>::default()); - builder.register_storage("S", BindingMode::ReadOnly, Array::>::default()); - builder.register_storage("D", BindingMode::ReadWrite, Array::>::default()); + builder.register_storage("C", BindingMode::ReadWrite, Array::

::default()); + builder.register_storage("S", BindingMode::ReadOnly, Array::

::default()); + builder.register_storage("D", BindingMode::ReadWrite, Array::

::default()); builder.register_uniform(); Ok(()) @@ -150,10 +150,13 @@ impl MetaOperation for Cache { rvec![&self.cache, &self.source] } - fn kernel_element(&self, _dst: &Tensor) -> KernelElement { - let numel = self.input.shape().numel(); + fn kernel_element(&self, dst: &Tensor) -> KernelElement { + let numel = dst.shape().numel(); + if numel % 4 == 0 { KernelElement::Vec4 + } else if numel % 2 == 0 { + KernelElement::Vec2 } else { KernelElement::Scalar }