From 633e0e26d561ced2312834acb676a7002e8c999e Mon Sep 17 00:00:00 2001 From: FL33TW00D Date: Sat, 27 Jan 2024 16:40:32 +0000 Subject: [PATCH] chore: refactor Operation trait --- crates/ratchet-core/src/ops/binary.rs | 5 ++--- crates/ratchet-core/src/ops/matmul.rs | 13 ++++++------- crates/ratchet-core/src/ops/softmax.rs | 5 ++--- crates/ratchet-core/src/tensor.rs | 2 +- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/crates/ratchet-core/src/ops/binary.rs b/crates/ratchet-core/src/ops/binary.rs index b87f2131..9c658ec5 100644 --- a/crates/ratchet-core/src/ops/binary.rs +++ b/crates/ratchet-core/src/ops/binary.rs @@ -3,10 +3,9 @@ use encase::ShaderType; use crate::{ gpu::{ - BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor, - WgpuDevice, WorkgroupCount, + BindGroupLayoutDescriptor, WorkgroupCount, }, - rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, + rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView, Tensor, }; diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs index 211c0921..a780cefb 100644 --- a/crates/ratchet-core/src/ops/matmul.rs +++ b/crates/ratchet-core/src/ops/matmul.rs @@ -5,10 +5,9 @@ use encase::ShaderType; use crate::{ gpu::{ - BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor, - WgpuDevice, WorkgroupCount, + BindGroupLayoutDescriptor, WorkgroupCount, }, - rvec, wgc, CompiledOp, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError, + rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, Shape, StorageView, Tensor, }; @@ -268,7 +267,7 @@ impl Operation for Matmul { fn storage_bind_group_layout( &self, - inplace: bool, + _inplace: bool, ) -> Result { let (A, B) = (&self.lhs, &self.rhs); let layout = match (A.dt(), B.dt()) { @@ -289,9 +288,9 @@ impl Operation for Matmul { let N = spec.n() as u32; let K = spec.k() as u32; - let a_offset = MatmulSpec::batch_offset(spec.a_stack() as _, M, K, &kernel_element); - let b_offset = MatmulSpec::batch_offset(spec.b_stack() as _, K, N, &kernel_element); - let c_offset = MatmulSpec::batch_offset(spec.c_stack() as _, M, N, &kernel_element); + let a_offset = MatmulSpec::batch_offset(spec.a_stack() as _, M, K, kernel_element); + let b_offset = MatmulSpec::batch_offset(spec.b_stack() as _, K, N, kernel_element); + let c_offset = MatmulSpec::batch_offset(spec.c_stack() as _, M, N, kernel_element); Ok(MatmulMeta::new(M, N, K, a_offset, b_offset, c_offset)) } diff --git a/crates/ratchet-core/src/ops/softmax.rs b/crates/ratchet-core/src/ops/softmax.rs index ca36d02d..129f9520 100644 --- a/crates/ratchet-core/src/ops/softmax.rs +++ b/crates/ratchet-core/src/ops/softmax.rs @@ -3,10 +3,9 @@ use encase::ShaderType; use crate::{ gpu::{ - BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor, - WgpuDevice, WorkgroupCount, + BindGroupLayoutDescriptor, WorkgroupCount, }, - rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, + rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView, Tensor, }; diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs index b6718713..f733ee2d 100644 --- a/crates/ratchet-core/src/tensor.rs +++ b/crates/ratchet-core/src/tensor.rs @@ -1,6 +1,6 @@ use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice}; use crate::{ - ops::*, rvec, shape, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable, + ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable, GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId, };