Skip to content

Commit

Permalink
chore: refactor Operation trait
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 27, 2024
1 parent 7e96ed3 commit 633e0e2
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor,
WgpuDevice, WorkgroupCount,
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
};

Expand Down
13 changes: 6 additions & 7 deletions crates/ratchet-core/src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor,
WgpuDevice, WorkgroupCount,
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, CompiledOp, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError,
rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError,
RVec, Shape, StorageView, Tensor,
};

Expand Down Expand Up @@ -268,7 +267,7 @@ impl Operation for Matmul {

fn storage_bind_group_layout(
&self,
inplace: bool,
_inplace: bool,
) -> Result<BindGroupLayoutDescriptor, OperationError> {
let (A, B) = (&self.lhs, &self.rhs);
let layout = match (A.dt(), B.dt()) {
Expand All @@ -289,9 +288,9 @@ impl Operation for Matmul {
let N = spec.n() as u32;
let K = spec.k() as u32;

let a_offset = MatmulSpec::batch_offset(spec.a_stack() as _, M, K, &kernel_element);
let b_offset = MatmulSpec::batch_offset(spec.b_stack() as _, K, N, &kernel_element);
let c_offset = MatmulSpec::batch_offset(spec.c_stack() as _, M, N, &kernel_element);
let a_offset = MatmulSpec::batch_offset(spec.a_stack() as _, M, K, kernel_element);
let b_offset = MatmulSpec::batch_offset(spec.b_stack() as _, K, N, kernel_element);
let c_offset = MatmulSpec::batch_offset(spec.c_stack() as _, M, N, kernel_element);

Ok(MatmulMeta::new(M, N, K, a_offset, b_offset, c_offset))
}
Expand Down
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/ops/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor,
WgpuDevice, WorkgroupCount,
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
};

Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice};
use crate::{
ops::*, rvec, shape, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType,
TensorId,
};
Expand Down

0 comments on commit 633e0e2

Please sign in to comment.