diff --git a/examples/mistral/main.rs b/examples/mistral/main.rs index b6bcfd4e..80d6353d 100644 --- a/examples/mistral/main.rs +++ b/examples/mistral/main.rs @@ -132,6 +132,8 @@ fn main() { transfer_data(&kv_cache, &mut cx1, &cache_src_set, &mut cx2); drop(cx1); + // cx2.display_set(&[NodeIndex::new(51), NodeIndex::new(85), NodeIndex::new(87)]); + // Decode loop let mut token_decode_times = vec![]; for _ in 0..tokens_to_generate { diff --git a/src/compilers/metal/binary.rs b/src/compilers/metal/binary.rs index 1d145162..1d170b4d 100644 --- a/src/compilers/metal/binary.rs +++ b/src/compilers/metal/binary.rs @@ -127,12 +127,28 @@ impl Operator for MetalSub { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -332,12 +348,28 @@ impl Operator for MetalEqual { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } diff --git a/src/compilers/metal/command_buffer.rs b/src/compilers/metal/command_buffer.rs index 14be2eb1..b1c89f1f 100644 --- a/src/compilers/metal/command_buffer.rs +++ b/src/compilers/metal/command_buffer.rs @@ -1,4 +1,5 @@ use std::{ + any::Any, cell::UnsafeCell, collections::{HashMap, HashSet}, fmt::Debug, @@ -28,12 +29,14 @@ impl Compiler for CommandBufferCompiler { let is_metal: HashSet = graph .graph .node_indices() + .collect::>() + .into_iter() .filter(|i| { graph .graph - .node_weight(*i) + .node_weight_mut(*i) .unwrap() - .custom("metal") + .custom("metal", Box::new(())) .is_some() }) .collect(); @@ -170,9 +173,9 @@ impl Compiler for CommandBufferCompiler { // Wrap node in MetalKernelOperation let wrapper = graph .graph - .node_weight(*node) + .node_weight_mut(*node) .unwrap() - .custom("metal") + .custom("metal", Box::new(())) .unwrap() .downcast::() .unwrap(); @@ -280,7 +283,7 @@ impl Operator for CommandBufferWrapper { } #[allow(clippy::arc_with_non_send_sync)] - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), diff --git a/src/compilers/metal/fp16/matmul.rs b/src/compilers/metal/fp16/matmul.rs index 0c8d3eca..03ead265 100644 --- a/src/compilers/metal/fp16/matmul.rs +++ b/src/compilers/metal/fp16/matmul.rs @@ -223,7 +223,7 @@ impl Operator for MetalVecMat { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -404,7 +404,7 @@ impl Operator for MetalMatmul2D { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -580,7 +580,7 @@ impl Operator for MetalBatchMatmul2D { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), diff --git a/src/compilers/metal/fp16/mean_reduce.rs b/src/compilers/metal/fp16/mean_reduce.rs index afe30738..691031af 100644 --- a/src/compilers/metal/fp16/mean_reduce.rs +++ b/src/compilers/metal/fp16/mean_reduce.rs @@ -147,13 +147,28 @@ impl Operator for MetalMeanReduce { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { #[allow(clippy::arc_with_non_send_sync)] return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + self.2.clone(), + self.1.clone(), + self.3, + input_shapes[0], + self.5, + ); + } + } None } } diff --git a/src/compilers/metal/fp16/mod.rs b/src/compilers/metal/fp16/mod.rs index 20e0a73e..fc51f44c 100644 --- a/src/compilers/metal/fp16/mod.rs +++ b/src/compilers/metal/fp16/mod.rs @@ -18,7 +18,7 @@ pub type MetalFp16Compiler = ( mean_reduce::MeanReduceCompiler, rms_norm::RMSNormCompiler, super::other::CopyCompiler, - // super::other::ContiguousElimination, + super::other::ContiguousElimination, super::command_buffer::CommandBufferCompiler, super::storage_buffer::StorageBufferCompiler, ); diff --git a/src/compilers/metal/fp16/other.rs b/src/compilers/metal/fp16/other.rs index 35543245..95836096 100644 --- a/src/compilers/metal/fp16/other.rs +++ b/src/compilers/metal/fp16/other.rs @@ -91,12 +91,16 @@ impl Operator for MetalCos { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } None } } @@ -258,12 +262,16 @@ impl Operator for MetalExp { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } None } } diff --git a/src/compilers/metal/fp16/rms_norm.rs b/src/compilers/metal/fp16/rms_norm.rs index 8b2991f4..04cbbe81 100644 --- a/src/compilers/metal/fp16/rms_norm.rs +++ b/src/compilers/metal/fp16/rms_norm.rs @@ -19,6 +19,7 @@ pub struct MetalRMSNorm( ComputePipelineState, // RMSNorm kernel Device, ShapeTracker, // Input shape + f32, // Epsilon *const HashMap, ); @@ -79,6 +80,7 @@ kernel void mkernel(device float *inp [[buffer(0)]], device half *x [[buffer(1)] compile_function(&rms_norm_code_name, &rms_norm_code, &dev), dev, inp_shape, + epsilon, dyn_map, ) } @@ -127,7 +129,7 @@ impl MetalKernel for MetalRMSNorm { encoder.set_int(3, front_size as u32); encoder.set_int(4, back_size as u32); encoder.set_int(5, dim_size as u32); - input_dyn_dims(&[self.3], unsafe { self.4.as_ref().unwrap() }, encoder, 6); + input_dyn_dims(&[self.3], unsafe { self.5.as_ref().unwrap() }, encoder, 6); encoder.dispatch_1d(meaned_elements); encoder.end_encoding(); @@ -141,7 +143,7 @@ impl MetalKernel for MetalRMSNorm { encoder.set_buffer(1, Some(inputs[0].0), 0); encoder.set_buffer(2, Some(output_buffers[0]), 0); encoder.set_int(3, inputs[0].1.n_elements().to_usize().unwrap() as u32); - input_dyn_dims(&[self.3], unsafe { self.4.as_ref().unwrap() }, encoder, 4); + input_dyn_dims(&[self.3], unsafe { self.5.as_ref().unwrap() }, encoder, 4); // Execute encoder.dispatch_1d(inputs[0].1.n_elements().to_usize().unwrap()); @@ -183,13 +185,22 @@ impl Operator for MetalRMSNorm { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { #[allow(clippy::arc_with_non_send_sync)] return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), ))))); } + // This op can accept non contiguous inputs + if key == "non_contiguous" { + return Some(Box::new(())); + } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new(self.4, self.2.clone(), input_shapes[0], self.5) + } + } None } } diff --git a/src/compilers/metal/other.rs b/src/compilers/metal/other.rs index 2ec19f75..640bdae0 100644 --- a/src/compilers/metal/other.rs +++ b/src/compilers/metal/other.rs @@ -195,7 +195,7 @@ impl Operator for MetalARange { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { #[allow(clippy::arc_with_non_send_sync)] return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( @@ -315,7 +315,7 @@ impl Compiler for ContiguousElimination { .ptr(&mut contig) .edge( SelectOp::new() - .check(|op, _| op.custom("non_contiguous").is_some()) + .check(|op, _| op.custom("non_contiguous", Box::new(())).is_some()) .ptr(&mut op), ); let mut selector = pattern.search(graph); @@ -330,12 +330,19 @@ impl Compiler for ContiguousElimination { continue; } // Shape going from contig to op + // let first_shape = graph + // .graph + // .edges_directed(contig, Direction::Incoming) + // .find_map(|e| e.weight().as_data()) + // .unwrap() + // .2; let second_shape = graph .graph .edges_connecting(contig, op) .find_map(|e| e.weight().as_data()) .unwrap() .2; + // Here we should check if second shape and first shape are mergeable instead of just checking if second_shape is contiguous if second_shape.is_contiguous() && !second_shape.is_sliced() && !second_shape.is_padded() @@ -354,7 +361,16 @@ impl Compiler for ContiguousElimination { source, ); graph.graph.remove_node(contig); - selector.clear_cached_results(); + let new_shapes = graph + .get_sources(op) + .into_iter() + .map(|(_, _, s)| s) + .collect::>(); + graph + .graph + .node_weight_mut(op) + .unwrap() + .custom("recompile_shapes", Box::new(new_shapes)); } } } diff --git a/src/compilers/metal/prim.rs b/src/compilers/metal/prim.rs index 2de3355a..102fbb0b 100644 --- a/src/compilers/metal/prim.rs +++ b/src/compilers/metal/prim.rs @@ -87,7 +87,7 @@ impl Operator for MetalCopyFromDevice { }] } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { // This op can accept non contiguous inputs if key == "non_contiguous" { return Some(Box::new(())); @@ -226,7 +226,7 @@ impl Operator for MetalContiguous { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -236,6 +236,11 @@ impl Operator for MetalContiguous { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new(input_shapes[0], self.1.clone(), &mut HashMap::new(), self.4) + } + } None } } @@ -325,7 +330,7 @@ impl Operator for MetalLog2 { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -423,7 +428,7 @@ impl Operator for MetalExp2 { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -521,7 +526,7 @@ impl Operator for MetalSin { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -618,7 +623,7 @@ impl Operator for MetalSqrt { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -715,7 +720,7 @@ impl Operator for MetalRecip { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, _: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -845,7 +850,7 @@ impl Operator for MetalAdd { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -855,6 +860,18 @@ impl Operator for MetalAdd { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -975,7 +992,7 @@ impl Operator for MetalMul { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -985,6 +1002,18 @@ impl Operator for MetalMul { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -1117,7 +1146,7 @@ impl Operator for MetalLessThan { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -1127,6 +1156,18 @@ impl Operator for MetalLessThan { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -1245,7 +1286,7 @@ impl Operator for MetalMod { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -1255,6 +1296,18 @@ impl Operator for MetalMod { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + input_shapes[1], + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -1398,7 +1451,7 @@ impl Operator for MetalSumReduce { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -1408,6 +1461,18 @@ impl Operator for MetalSumReduce { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + self.3, + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } @@ -1554,7 +1619,7 @@ impl Operator for MetalMaxReduce { }) } - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { if key == "metal" { return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new( self.clone(), @@ -1564,6 +1629,18 @@ impl Operator for MetalMaxReduce { if key == "non_contiguous" { return Some(Box::new(())); } + if key == "recompile_shapes" { + if let Some(input_shapes) = input.downcast_ref::>() { + *self = Self::new( + input_shapes[0], + self.3, + self.2.clone(), + self.1.clone(), + &mut HashMap::new(), + self.6, + ) + } + } None } } diff --git a/src/compilers/metal/storage_buffer.rs b/src/compilers/metal/storage_buffer.rs index 3da99ef9..8b33e562 100644 --- a/src/compilers/metal/storage_buffer.rs +++ b/src/compilers/metal/storage_buffer.rs @@ -86,12 +86,14 @@ impl Compiler for StorageBufferCompiler { .graph .node_indices() .filter(|n| !graph.no_delete.contains(n)) + .collect::>() + .into_iter() .filter_map(|n| { if let Some(Ok(wrapper)) = graph .graph - .node_weight(n) + .node_weight_mut(n) .unwrap() - .custom("metal") + .custom("metal", Box::new(())) .map(|n| n.downcast::()) { Some((n, wrapper)) @@ -99,6 +101,8 @@ impl Compiler for StorageBufferCompiler { None } }) + .collect::>() + .into_iter() .map(|(n, wrapper)| { let input_shapes = graph .get_sources(n) @@ -120,9 +124,9 @@ impl Compiler for StorageBufferCompiler { } let Some(Ok(wrapper)) = graph .graph - .node_weight(*node) + .node_weight_mut(*node) .unwrap() - .custom("metal") + .custom("metal", Box::new(())) .map(|e| e.downcast::()) else { continue; @@ -234,9 +238,9 @@ impl Compiler for StorageBufferCompiler { { let wrapper = graph .graph - .node_weight(node) + .node_weight_mut(node) .unwrap() - .custom("metal") + .custom("metal", Box::new(())) .unwrap() .downcast::() .unwrap(); diff --git a/src/core/compiler_utils.rs b/src/core/compiler_utils.rs index 70d261bc..26668bc3 100644 --- a/src/core/compiler_utils.rs +++ b/src/core/compiler_utils.rs @@ -497,7 +497,7 @@ type SelectionGraph = petgraph::Graph>; pub struct GraphSearch { selector: SelectionGraph, - graph: *const Graph, + graph: *mut Graph, to_return: Vec>, returned_anchors: HashSet, anchor: NodeIndex, @@ -506,18 +506,18 @@ pub struct GraphSearch { impl GraphSearch { pub fn next_match(&mut self) -> bool { // Look through graph for pattern from selector - let graph = unsafe { self.graph.as_ref().unwrap() }; + let graph = unsafe { self.graph.as_mut().unwrap() }; if self.to_return.is_empty() { // Replenish to_return let select_op = self.selector.node_weight(self.anchor).unwrap(); - for node in graph.graph.node_indices() { + for node in graph.graph.node_indices().collect::>() { if !self.returned_anchors.contains(&node) - && test_node(select_op, &graph.graph, node) + && test_node(select_op, &mut graph.graph, node) { // Backtrack to check if this is a match if let Some(mapping) = - backtrack_match(self.anchor, &self.selector, node, &graph.graph) + backtrack_match(self.anchor, &self.selector, node, &mut graph.graph) { self.to_return.push(mapping); } @@ -553,7 +553,7 @@ fn backtrack_match( select_node: NodeIndex, select_graph: &SelectionGraph, main_node: NodeIndex, - graph: &MainGraph, + graph: &mut MainGraph, ) -> Option> { // Dfs backward through both the selector graph and the main graph let mut mapping = HashMap::new(); @@ -626,22 +626,22 @@ fn test_node( fake, pointers: _, }: &SelectOp, - graph: &MainGraph, + graph: &mut MainGraph, graph_node: NodeIndex, ) -> bool { - let current_weight = graph.node_weight(graph_node).unwrap(); - // Test type - if let Some(ty) = type_id { - if current_weight.as_any().type_id() != *ty { - return false; - } - } let input_shapes = graph .edges_directed(graph_node, petgraph::Direction::Incoming) .filter_map(|e| e.weight().as_data()) .sorted_by_key(|e| e.0) .map(|e| e.2) .collect::>(); + let current_weight = graph.node_weight_mut(graph_node).unwrap(); + // Test type + if let Some(ty) = type_id { + if current_weight.as_any().type_id() != *ty { + return false; + } + } // Test shape if let Some(shape) = shape { @@ -692,7 +692,7 @@ fn test_node( // Run check if let Some(check) = check { - if !check(current_weight.as_ref(), &input_shapes) { + if !check(current_weight.as_mut(), &input_shapes) { return false; } } @@ -705,7 +705,7 @@ pub struct SelectOp { type_id: Option, /// Check constraint #[allow(clippy::type_complexity)] - check: Option bool>, + check: Option bool>, /// Shape constraint shape: Option>>, /// Fake constraint @@ -747,7 +747,7 @@ impl SelectOp { self } /// Constrain the op to a checking function - pub fn check(mut self, check: fn(&dyn Operator, &[ShapeTracker]) -> bool) -> Self { + pub fn check(mut self, check: fn(&mut dyn Operator, &[ShapeTracker]) -> bool) -> Self { self.check = Some(check); self } @@ -818,7 +818,7 @@ impl SelectEdge { Self::internal_new(a, Some(out), b) } - pub fn search(self, graph: &Graph) -> GraphSearch { + pub fn search(self, graph: &mut Graph) -> GraphSearch { let anchor = *toposort(&self.graph, None).unwrap().last().unwrap(); GraphSearch { to_return: vec![], @@ -852,7 +852,7 @@ mod tests { #[test] fn test_graph_selector() { - let cx = Graph::default(); + let mut cx = Graph::default(); // Exp -> Log or Log -> Exp let (mut exp, mut log) = (NodeIndex::default(), NodeIndex::default()); let (exp_select, log_select) = ( @@ -862,13 +862,13 @@ mod tests { let selector1 = log_select.clone().edge(exp_select.clone()); let selector2 = exp_select.edge(log_select); - assert!(!selector1.search(&cx).next_match() && !selector2.search(&cx).next_match()); + assert!(!selector1.search(&mut cx).next_match() && !selector2.search(&mut cx).next_match()); // Matmul let s = SelectOp::new() .ty::() .edge(SelectOp::new().ty::()); - assert!(!s.search(&cx).next_match()); + assert!(!s.search(&mut cx).next_match()); } } diff --git a/src/core/op.rs b/src/core/op.rs index 356df721..83ffff68 100644 --- a/src/core/op.rs +++ b/src/core/op.rs @@ -33,7 +33,7 @@ impl<'a> InputTensor<'a> { pub trait Operator: Debug + TraitObjEq { fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec; #[allow(unused)] - fn custom(&self, key: &str) -> Option> { + fn custom(&mut self, key: &str, input: Box) -> Option> { None } }