diff --git a/.github/workflows/ratbot.yml b/.github/workflows/ratbot.yml
index 4e5ca8d6..0a79de7c 100644
--- a/.github/workflows/ratbot.yml
+++ b/.github/workflows/ratbot.yml
@@ -23,13 +23,9 @@ jobs:
         uses: actions/github-script@v6
         with:
           script: |
-            const codeReport = `
-            \`\`\`\n
-            ```
-            ${{ steps.scc.outputs.scc }}
-            ```
-            \`\`\`
-            `;
+            const codeReport = "```\n" +
+              `${{ steps.scc.outputs.scc }}` +
+              "\n```";
             github.rest.issues.createComment({
               issue_number: context.issue.number,
               owner: context.repo.owner,
diff --git a/crates/ratchet-core/src/compiled_op.rs b/crates/ratchet-core/src/compiled_op.rs
index 77311911..6f1d0cd7 100644
--- a/crates/ratchet-core/src/compiled_op.rs
+++ b/crates/ratchet-core/src/compiled_op.rs
@@ -2,7 +2,7 @@ use crate::gpu::{
     BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice,
     WorkgroupCount,
 };
-use crate::{drvec, rvec, RVec, Tensor};
+use crate::{drvec, rvec, OperationError, RVec, Tensor};
 use derive_new::new;
 use wgpu::DynamicOffset;

@@ -20,14 +20,13 @@ pub struct CompiledOp {
 impl CompiledOp {
     const MAX_BINDINGS_PER_GROUP: usize = 4;

-    //TODO: Should return a Result
     pub fn create_storage_bind_groups(
         srcs: &[&Tensor],
         dst: &Tensor,
         bind_group_layouts: RVec<BindGroupLayoutHandle>,
         device: &WgpuDevice,
         inplace: bool,
-    ) -> RVec<GpuBindGroup> {
+    ) -> Result<RVec<GpuBindGroup>, OperationError> {
         let mut bind_group_entries = drvec![];

         for tensor in srcs.iter() {
@@ -44,12 +43,10 @@ impl CompiledOp {
             let entries = bind_group_entries[group_range].into();
             let layout = *bind_group_layout;

-            let bind_group = device
-                .get_or_create_bind_group(&BindGroupDescriptor { entries, layout })
-                .unwrap();
-            storage_groups.push(bind_group);
+            let bg = device.get_or_create_bind_group(&BindGroupDescriptor { entries, layout })?;
+            storage_groups.push(bg);
         }
-        storage_groups
+        Ok(storage_groups)
     }

     /// Determines which bindings belong to which bind group
diff --git a/crates/ratchet-core/src/dtype.rs b/crates/ratchet-core/src/dtype.rs
index e8be917c..f405aefb 100644
--- a/crates/ratchet-core/src/dtype.rs
+++ b/crates/ratchet-core/src/dtype.rs
@@ -31,6 +31,7 @@ impl DType {
         }
     }

+    //TODO: use a different method, total_bytes won't work with 256-byte padding
     pub fn segments(&self, total_bytes: usize) -> RVec<BufferSegment> {
         match self {
             DType::WQ8 => {
diff --git a/crates/ratchet-core/src/gpu/buffer_allocator.rs b/crates/ratchet-core/src/gpu/buffer_allocator.rs
index 96f24ce2..c7602587 100644
--- a/crates/ratchet-core/src/gpu/buffer_allocator.rs
+++ b/crates/ratchet-core/src/gpu/buffer_allocator.rs
@@ -119,18 +119,24 @@ impl BufferAllocator {
     /// from the actual source (i.e. the first non-inplace operation)
     ///
     /// On what conditions do we terminate the upward traversal?
-    /// 1. We reach a constant
-    /// 2. We reach an operation that does not support inplace
-    /// 3. We reach an operation that has more than one consumer
-    /// 4. We reach an operation that has more than one source
-    fn traverse_upwards_for_inplace(source: &Tensor) -> &Tensor {
+    /// 1. We reach an operation that does not support inplace
+    /// 2. We reach an operation that has more than one consumer
+    /// 3. We reach an operation that has more than one source
+    fn determine_tensor_source<'a>(source: &'a Tensor, execution_order: &[Tensor]) -> &'a Tensor {
         let mut true_source = source;
         loop {
-            let is_const = true_source.op().is_const();
             let cant_inplace = !true_source.op().supports_inplace();
             let multiple_sources = true_source.op().srcs().len() > 1;
-            let multiple_consumers = false; //TODO: implement
-            if cant_inplace || multiple_sources || multiple_consumers || is_const {
+            let ts_index = execution_order
+                .iter()
+                .position(|t| t.id() == true_source.id())
+                .unwrap();
+            let multiple_consumers = execution_order[ts_index + 1..]
+                .iter()
+                .filter(|t| t.op().srcs().contains(&true_source))
+                .count()
+                > 1;
+            if cant_inplace || multiple_sources || multiple_consumers {
                 break;
             }
@@ -169,7 +175,7 @@ impl BufferAllocator {
         // If the current tensor is an inplace operation,
         // we traverse upwards until we find a non-inplace operation.
         for source in t.op().srcs() {
-            let true_source = Self::traverse_upwards_for_inplace(source);
+            let true_source = Self::determine_tensor_source(source, execution_order);
             assignments.entry(true_source.id()).or_insert_with(|| {
                 self.graph_allocate(
                     BufferDescriptor::new(
@@ -194,7 +200,7 @@ impl BufferAllocator {
         //We know we need an allocation for the output.
         //We traverse upwards until we find the first non-inplace operation, and use its buffer.
         let output = execution_order.last().unwrap();
-        let output_source = Self::traverse_upwards_for_inplace(output);
+        let output_source = Self::determine_tensor_source(output, execution_order);

         //If output source is allocated, we can use its buffer
         //Otherwise, we need to allocate a new buffer
diff --git a/crates/ratchet-core/src/gpu/uniform.rs b/crates/ratchet-core/src/gpu/uniform.rs
index dbe4b2c5..25e4466d 100644
--- a/crates/ratchet-core/src/gpu/uniform.rs
+++ b/crates/ratchet-core/src/gpu/uniform.rs
@@ -2,7 +2,7 @@ use std::num::NonZeroU64;

 use crate::{
     gpu::{BindGroupEntry, BindGroupLayoutDescriptor},
-    rvec,
+    rvec, OperationError,
 };

 use super::{BindGroupDescriptor, GpuBindGroup, PooledGPUBuffer, WgpuDevice};
@@ -32,26 +32,20 @@ impl CpuUniform {
     }

     /// Consumes the CPU repr of the uniform buffer and writes to the GPU.
-    pub(crate) fn into_gpu(self, device: &WgpuDevice) -> GpuUniform {
-        let uniform_buf = device.create_uniform_init(self);
-        let bind_group_layout = device
-            .get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())
-            .unwrap();
-        let bind_group = device
-            .get_or_create_bind_group(&BindGroupDescriptor {
-                entries: rvec![BindGroupEntry {
-                    handle: uniform_buf.handle,
-                    offset: 0,
-                    size: NonZeroU64::new(uniform_buf.size()),
-                }],
-                layout: bind_group_layout,
-            })
-            .unwrap();
+    pub(crate) fn into_gpu(self, device: &WgpuDevice) -> Result<GpuUniform, OperationError> {
+        let buf = device.create_uniform_init(self);
+        let layout =
+            device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
+        let bind_group = device.get_or_create_bind_group(&BindGroupDescriptor {
+            entries: rvec![BindGroupEntry {
+                handle: buf.handle,
+                offset: 0,
+                size: NonZeroU64::new(buf.size()),
+            }],
+            layout,
+        })?;

-        GpuUniform {
-            buf: uniform_buf,
-            bind_group,
-        }
+        Ok(GpuUniform { buf, bind_group })
     }
 }
diff --git a/crates/ratchet-core/src/op.rs b/crates/ratchet-core/src/op.rs
index 1e3fc3f5..0064ac65 100644
--- a/crates/ratchet-core/src/op.rs
+++ b/crates/ratchet-core/src/op.rs
@@ -7,8 +7,8 @@ use crate::gpu::{CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
 use crate::{rvec, Binary, CompiledOp, InvariantError, Matmul, RVec, Softmax, StorageView, Tensor};

 #[derive(Clone, Debug)]
+#[non_exhaustive]
 pub enum LazyOp {
-    Dummy(Tensor),
     Matmul(Matmul),
     Binary(Binary),
     Softmax(Softmax),
@@ -21,7 +21,6 @@ impl LazyOp {
             LazyOp::Binary(b) => b.srcs(),
             LazyOp::Matmul(m) => m.srcs(),
             LazyOp::Softmax(s) => s.srcs(),
-            LazyOp::Dummy(t) => rvec![t],
             LazyOp::Const => rvec![], //end of the line kid
             _ => unimplemented!(),
         }
@@ -32,7 +31,7 @@ impl LazyOp {
             LazyOp::Binary(b) => b.supports_inplace(),
             LazyOp::Matmul(m) => m.supports_inplace(),
             LazyOp::Softmax(s) => s.supports_inplace(),
-            LazyOp::Const => true,
+            LazyOp::Const => false,
             _ => false,
         }
     }
diff --git a/crates/ratchet-core/src/ops/binary.rs b/crates/ratchet-core/src/ops/binary.rs
index a81d81d5..2bfe3a39 100644
--- a/crates/ratchet-core/src/ops/binary.rs
+++ b/crates/ratchet-core/src/ops/binary.rs
@@ -97,7 +97,7 @@ impl Operation for Binary {
             rvec![storage_layout],
             device,
             false,
-        );
+        )?;

         Ok(CompiledOp::new(
             pipeline_handle,
diff --git a/crates/ratchet-core/src/ops/matmul.rs b/crates/ratchet-core/src/ops/matmul.rs
index b6866df2..8e9e91b3 100644
--- a/crates/ratchet-core/src/ops/matmul.rs
+++ b/crates/ratchet-core/src/ops/matmul.rs
@@ -179,6 +179,8 @@ pub struct Matmul {
     rhs: Tensor,
 }

+impl Matmul {}
+
 #[allow(clippy::too_many_arguments)]
 #[derive(Debug, Clone, ShaderType)]
 pub struct MatmulMeta {
@@ -297,7 +299,7 @@ impl Operation for Matmul {
             rvec![storage_layout],
             device,
             false,
-        );
+        )?;

         Ok(CompiledOp::new(
             pipeline_handle,
@@ -308,6 +310,9 @@ impl Operation for Matmul {
     }

     fn infer_output(&self, srcs: &[&Tensor]) -> Result<StorageView, OperationError> {
+        let (_a, _b) = (srcs[0], srcs[1]);
+        //let c_shape = Matmul::compute_output_shape(a.clone(), b.clone()).unwrap();
+        //TODO: THIS IS WRONG 🚨
         Ok(srcs[0].view().clone())
     }
diff --git a/crates/ratchet-core/src/ops/softmax.rs b/crates/ratchet-core/src/ops/softmax.rs
index 7435eb73..de536aae 100644
--- a/crates/ratchet-core/src/ops/softmax.rs
+++ b/crates/ratchet-core/src/ops/softmax.rs
@@ -87,7 +87,7 @@ impl Operation for Softmax {
             rvec![storage_layout],
             device,
             can_inplace,
-        );
+        )?;

         Ok(CompiledOp::new(
             pipeline_handle,
@@ -98,7 +98,6 @@ impl Operation for Softmax {
     }

     fn infer_output(&self, srcs: &[&Tensor]) -> Result<StorageView, OperationError> {
-        //TODO: FIX
         Ok(srcs[0].view().clone())
     }

diff --git a/crates/ratchet-core/src/tensor.rs b/crates/ratchet-core/src/tensor.rs
index 1149fcb3..b6718713 100644
--- a/crates/ratchet-core/src/tensor.rs
+++ b/crates/ratchet-core/src/tensor.rs
@@ -1,6 +1,6 @@
 use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice};
 use crate::{
-    ops::*, rvec, shape, strides, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
+    ops::*, rvec, shape, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
     GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides,
     TensorDType, TensorId,
 };
@@ -55,15 +55,6 @@ impl Tensor {
         Self::new(op, meta, None, device)
     }

-    pub fn dummy(src: Tensor) -> Self {
-        Self::new(
-            LazyOp::Dummy(src),
-            StorageView::new(shape![], DType::F32, Strides::default()),
-            None,
-            Device::CPU,
-        )
-    }
-
     fn update_storage(&self, storage: Storage) {
         *self.inner.storage.write() = Some(storage);
     }
@@ -270,12 +261,12 @@ impl Tensor {
         let handle = gpu_buf.inner().handle;
         let segments = self.dt().segments(gpu_buf.inner().size() as usize);
         segments.iter().fold(rvec![], |mut entries, segment| {
-            let entry = BindGroupEntry {
+            let (offset, size) = (segment.offset, segment.size);
+            entries.push(BindGroupEntry {
                 handle,
-                offset: segment.offset,
-                size: segment.size,
-            };
-            entries.push(entry);
+                offset,
+                size,
+            });
             entries
         })
     }
@@ -298,19 +289,14 @@ impl Tensor {
             if visited.contains(&tensor) {
                 continue;
             }
-            match &tensor.inner.op {
-                LazyOp::Const => {}
-                LazyOp::Binary(b) => {
-                    stack.extend(b.srcs().into_iter().cloned());
-                }
-                LazyOp::Matmul(m) => {
-                    stack.extend(m.srcs().into_iter().cloned());
-                }
-                LazyOp::Softmax(s) => {
-                    stack.extend(s.srcs().into_iter().cloned());
-                }
+            let srcs = match &tensor.inner.op {
+                LazyOp::Const => rvec![],
+                LazyOp::Binary(b) => b.srcs(),
+                LazyOp::Matmul(m) => m.srcs(),
+                LazyOp::Softmax(s) => s.srcs(),
                 _ => unimplemented!(),
-            }
+            };
+            stack.extend(srcs.into_iter().cloned());
             visited.push(tensor);
         }
         visited.reverse();
@@ -337,7 +323,6 @@ impl Tensor {
         let device = self.device().try_gpu()?;
         let execution_order = self.execution_order();
-        println!("EXECUTION ORDER: \n{:#?}", execution_order);
         let mut compiled_ops = Vec::with_capacity(execution_order.len());
         let allocations = device.allocate_cfg(&execution_order, device)?;
@@ -353,6 +338,7 @@ impl Tensor {
                 t.update_storage(Storage::GPU(storage));
             }

+            //Can inplace: op must support it and have exactly one consumer
             let can_inplace = t.op().supports_inplace()
                 && execution_order[tix + 1..]
                     .iter()
@@ -364,7 +350,7 @@ impl Tensor {
                 compiled_ops.push(compiled_op);
             }
         }
-        let executable = Executable::new(compiled_ops, uniform.into_gpu(device));
+        let executable = Executable::new(compiled_ops, uniform.into_gpu(device)?);
         let index = executable.dispatch_operations(device).unwrap();
         device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index));
         Ok(())
     }
@@ -470,6 +456,7 @@ impl Tensor {
     }
 }

+#[derive(Default)]
 struct CloseStats {
     total_error: f32,
     max_abs_error: f32,
@@ -483,13 +470,9 @@ struct CloseStats {
 impl CloseStats {
     fn new(atol: f32, rtol: f32) -> Self {
         Self {
-            total_error: 0.0,
-            max_abs_error: 0.0,
-            max_abs_error_idxs: None,
-            element_count: 0,
-            fail_count: 0,
             atol,
             rtol,
+            ..Default::default()
         }
     }
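
A few sketches to make the intent of the larger hunks concrete. First, the consumer-counting walk that replaces `traverse_upwards_for_inplace` in buffer_allocator.rs. This is a minimal stand-alone model, not the real code: `Node` and `NodeId` are illustrative stand-ins for `Tensor`/`TensorId`, and the loop's continuation (stepping up to the lone source) is cut off in the hunk above, so that part is an assumption. Note how `LazyOp::Const => false` in `supports_inplace` (op.rs) lets constants terminate the walk via `cant_inplace`, which is why the old `is_const` check could be dropped.

```rust
// Toy stand-ins for the ratchet-core tensor graph; illustrative only.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct NodeId(usize);

struct Node {
    id: NodeId,
    srcs: Vec<NodeId>,      // inputs, analogous to t.op().srcs()
    supports_inplace: bool, // analogous to t.op().supports_inplace()
}

/// Walk upwards from `source` until we hit a node that cannot share its
/// buffer: it forbids inplace, has multiple sources, or is read by more
/// than one later node in the execution order.
fn determine_source(source: NodeId, execution_order: &[Node]) -> NodeId {
    let mut true_source = source;
    loop {
        let idx = execution_order
            .iter()
            .position(|n| n.id == true_source)
            .expect("node must appear in execution order");
        let node = &execution_order[idx];
        let cant_inplace = !node.supports_inplace;
        let multiple_sources = node.srcs.len() > 1;
        // Only nodes scheduled *after* this one can still consume its output.
        let multiple_consumers = execution_order[idx + 1..]
            .iter()
            .filter(|n| n.srcs.contains(&true_source))
            .count()
            > 1;
        if cant_inplace || multiple_sources || multiple_consumers {
            break;
        }
        // Assumed continuation (elided in the hunk): single source, single
        // consumer, inplace-friendly, so keep walking up to the lone source.
        true_source = node.srcs[0];
    }
    true_source
}

fn main() {
    // a -> b -> c, where b and c are elementwise ops that could run in place.
    let order = vec![
        Node { id: NodeId(0), srcs: vec![], supports_inplace: false },
        Node { id: NodeId(1), srcs: vec![NodeId(0)], supports_inplace: true },
        Node { id: NodeId(2), srcs: vec![NodeId(1)], supports_inplace: true },
    ];
    // c's buffer resolves all the way up to a: b and c reuse a's allocation.
    assert_eq!(determine_source(NodeId(2), &order), NodeId(0));
}
```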
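Second, the error-propagation shape this diff applies to `create_storage_bind_groups` and `into_gpu`: the fallible pool lookup bubbles up with `?` instead of being unwrapped inside a loop. `Pool`, `Handle`, and `PoolError` below are hypothetical stand-ins for `WgpuDevice`, `GpuBindGroup`, and `OperationError`, chosen only to show the pattern.

```rust
// Illustrative stand-ins; not the ratchet-core types.
#[derive(Debug)]
struct PoolError(String);

struct Handle;
struct Pool {
    capacity: usize,
}

impl Pool {
    // A fallible get-or-create, like device.get_or_create_bind_group(..).
    fn get_or_create(&self, key: usize) -> Result<Handle, PoolError> {
        if key >= self.capacity {
            return Err(PoolError(format!("no slot for key {key}")));
        }
        Ok(Handle)
    }
}

// Before: `let h = pool.get_or_create(k).unwrap();` panicked on failure.
// After: the whole function is fallible and the caller decides what to do.
fn create_all(pool: &Pool, keys: &[usize]) -> Result<Vec<Handle>, PoolError> {
    let mut out = Vec::with_capacity(keys.len());
    for &k in keys {
        out.push(pool.get_or_create(k)?);
    }
    Ok(out)
}

fn main() {
    let pool = Pool { capacity: 4 };
    assert!(create_all(&pool, &[0, 1, 2]).is_ok());
    assert!(create_all(&pool, &[0, 9]).is_err()); // propagated, not panicked
}
```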
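Finally, the `#[derive(Default)]` plus struct-update pattern adopted for `CloseStats`. The struct is reproduced only in part here (`max_abs_error_idxs` is omitted), but the constructor matches the shape of the hunk above.

```rust
// Partial sketch of CloseStats; fields mirror the diff.
#[derive(Default)]
struct CloseStats {
    total_error: f32,
    max_abs_error: f32,
    element_count: usize,
    fail_count: usize,
    atol: f32,
    rtol: f32,
}

impl CloseStats {
    fn new(atol: f32, rtol: f32) -> Self {
        // Only the tolerances are caller-supplied; every counter starts at
        // its Default (zero), replacing five hand-written zero initializers.
        Self {
            atol,
            rtol,
            ..Default::default()
        }
    }
}

fn main() {
    let stats = CloseStats::new(1e-5, 1e-4);
    assert_eq!(stats.fail_count, 0);
    assert_eq!(stats.atol, 1e-5);
}
```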