Skip to content

Commit

Permalink
chore: refactor Operation trait
Browse files Browse the repository at this point in the history
  • Loading branch information
FL33TW00D committed Jan 27, 2024
1 parent 633e0e2 commit 0df9377
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 26 deletions.
17 changes: 9 additions & 8 deletions crates/ratchet-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,15 @@ pub trait Operation: Debug + 'static {
///Typically contains shapes or strides.
type Meta: OpMetadata;

/// Return the file stem of the kernel source file.
fn kernel_name(&self) -> &'static str;

fn srcs(&self) -> RVec<&Tensor>;

fn supports_inplace(&self) -> bool {
false
}

fn kernel_name(&self) -> &'static str;

/// # Kernel Element
///
/// Determine the largest possible unit data type that can be used (e.g f32, vec2<f32>, vec4<f32>)
Expand Down Expand Up @@ -130,12 +131,12 @@ pub trait Operation: Debug + 'static {
entries: rvec![storage_layout, uniform_layout],
})?;

let pipeline_handle =
device.get_or_create_compute_pipeline(&ComputePipelineDescriptor {
pipeline_layout,
kernel_name: self.kernel_name(),
kernel_element,
})?;
let pipeline_descriptor = ComputePipelineDescriptor {
pipeline_layout,
kernel_name: self.kernel_name(),
kernel_element,
};
let pipeline_handle = device.get_or_create_compute_pipeline(&pipeline_descriptor)?;

let storage_bind_groups = CompiledOp::create_storage_bind_groups(
&self.srcs(),
Expand Down
8 changes: 3 additions & 5 deletions crates/ratchet-core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ use derive_new::new;
use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
gpu::{BindGroupLayoutDescriptor, WorkgroupCount},
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView,
Tensor,
};

#[derive(Debug, Clone)]
Expand Down
8 changes: 3 additions & 5 deletions crates/ratchet-core/src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@ use derive_new::new;
use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError,
RVec, Shape, StorageView, Tensor,
gpu::{BindGroupLayoutDescriptor, WorkgroupCount},
rvec, wgc, DType, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, Shape,
StorageView, Tensor,
};

// Defines a matrix multiplication operation.
Expand Down
8 changes: 3 additions & 5 deletions crates/ratchet-core/src/ops/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ use derive_new::new;
use encase::ShaderType;

use crate::{
gpu::{
BindGroupLayoutDescriptor, WorkgroupCount,
},
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
gpu::{BindGroupLayoutDescriptor, WorkgroupCount},
rvec, wgc, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec, StorageView,
Tensor,
};

#[derive(new, Debug, Clone)]
Expand Down
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice};
use crate::{
ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType,
TensorId,
ops::*, rvec, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable, GPUBuffer,
Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType, TensorId,
};
use crate::{BinaryOp, LazyOp};

Expand Down

0 comments on commit 0df9377

Please sign in to comment.