Skip to content

Commit

Permalink
Use Array::<P>> while registering storage and add more conditions f…
Browse files Browse the repository at this point in the history
…or `KernelElement`
  • Loading branch information
officialcjunior committed Jun 23, 2024
1 parent 5bd5740 commit 18f24fb
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions crates/ratchet-core/src/ops/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ impl Cache {
builder: &mut WgslKernelBuilder,
_: bool,
) -> Result<(), OperationError> {
builder.register_storage("C", BindingMode::ReadWrite, Array::<vec4<f32>>::default());
builder.register_storage("S", BindingMode::ReadOnly, Array::<vec4<f32>>::default());
builder.register_storage("D", BindingMode::ReadWrite, Array::<vec4<f32>>::default());
builder.register_storage("C", BindingMode::ReadWrite, Array::<P>::default());
builder.register_storage("S", BindingMode::ReadOnly, Array::<P>::default());
builder.register_storage("D", BindingMode::ReadWrite, Array::<P>::default());

builder.register_uniform();
Ok(())
Expand Down Expand Up @@ -150,10 +150,13 @@ impl MetaOperation for Cache {
rvec![&self.cache, &self.source]
}

fn kernel_element(&self, _dst: &Tensor) -> KernelElement {
let numel = self.input.shape().numel();
fn kernel_element(&self, dst: &Tensor) -> KernelElement {
let numel = dst.shape().numel();

if numel % 4 == 0 {
KernelElement::Vec4
} else if numel % 2 == 0 {
KernelElement::Vec2
} else {
KernelElement::Scalar
}
Expand Down

0 comments on commit 18f24fb

Please sign in to comment.