diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs index ceaac551..a272b302 100644 --- a/crates/ratchet-core/src/op.rs +++ b/crates/ratchet-core/src/op.rs @@ -77,14 +77,15 @@ pub trait Operation: Debug + 'static { ///Typically contains shapes or strides. type Meta: OpMetadata; + /// Return the file stem of the kernel source file. + fn kernel_name(&self) -> &'static str; + fn srcs(&self) -> RVec<&Tensor>; fn supports_inplace(&self) -> bool { false } - fn kernel_name(&self) -> &'static str; - /// # Kernel Element /// /// Determine the largest possible unit data type that can be used (e.g f32, vec2, vec4) @@ -130,12 +131,12 @@ pub trait Operation: Debug + 'static { entries: rvec![storage_layout, uniform_layout], })?; - let pipeline_handle = - device.get_or_create_compute_pipeline(&ComputePipelineDescriptor { - pipeline_layout, - kernel_name: self.kernel_name(), - kernel_element, - })?; + let pipeline_descriptor = ComputePipelineDescriptor { + pipeline_layout, + kernel_name: self.kernel_name(), + kernel_element, + }; + let pipeline_handle = device.get_or_create_compute_pipeline(&pipeline_descriptor)?; let storage_bind_groups = CompiledOp::create_storage_bind_groups( &self.srcs(), diff --git a/crates/ratchet-core/src/ops/binary.rs b/crates/ratchet-core/src/ops/binary.rs index 9c658ec5..184377f7 100644 --- a/crates/ratchet-core/src/ops/binary.rs +++ b/crates/ratchet-core/src/ops/binary.rs @@ -2,11 +2,9 @@ use derive_new::new; use encase::ShaderType; use crate::{ - gpu::{ - BindGroupLayoutDescriptor, WorkgroupCount, - }, - rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, - StorageView, Tensor, + gpu::{BindGroupLayoutDescriptor, WorkgroupCount}, + rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView, + Tensor, }; #[derive(Debug, Clone)] diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs index a780cefb..2a797488 100644 --- a/crates/ratchet-core/src/ops/matmul.rs +++ b/crates/ratchet-core/src/ops/matmul.rs @@ -4,11 +4,9 @@ use derive_new::new; use encase::ShaderType; use crate::{ - gpu::{ - BindGroupLayoutDescriptor, WorkgroupCount, - }, - rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError, - RVec, Shape, StorageView, Tensor, + gpu::{BindGroupLayoutDescriptor, WorkgroupCount}, + rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, Shape, + StorageView, Tensor, }; // Defines a matrix multiplication operation. diff --git a/crates/ratchet-core/src/ops/softmax.rs b/crates/ratchet-core/src/ops/softmax.rs index 129f9520..22a13fdd 100644 --- a/crates/ratchet-core/src/ops/softmax.rs +++ b/crates/ratchet-core/src/ops/softmax.rs @@ -2,11 +2,9 @@ use derive_new::new; use encase::ShaderType; use crate::{ - gpu::{ - BindGroupLayoutDescriptor, WorkgroupCount, - }, - rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, - StorageView, Tensor, + gpu::{BindGroupLayoutDescriptor, WorkgroupCount}, + rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView, + Tensor, }; #[derive(new, Debug, Clone)] diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index f733ee2d..f1e8d6b4 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -1,8 +1,7 @@ use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice}; use crate::{ - ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable, - GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, - TensorId, + ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable, GPUBuffer, + Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId, }; use crate::{BinaryOp, LazyOp};