From f5c9f6d56b04fe7fd4f443872b8885036971d4ff Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Wed, 5 Jun 2024 14:11:31 -0500 Subject: [PATCH] StorageBufferCompiler causing segfault --- crates/luminal_metal/Cargo.toml | 2 +- crates/luminal_metal/src/elementwise_fusion.rs | 13 ++++++++++--- crates/luminal_metal/src/lib.rs | 2 ++ crates/luminal_metal/src/storage_buffer.rs | 7 +++---- examples/whisper/src/main.rs | 5 ++++- src/shape/symbolic.rs | 2 ++ 6 files changed, 22 insertions(+), 9 deletions(-) diff --git a/crates/luminal_metal/Cargo.toml b/crates/luminal_metal/Cargo.toml index b1b4da58..696497ce 100644 --- a/crates/luminal_metal/Cargo.toml +++ b/crates/luminal_metal/Cargo.toml @@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0" [dependencies] itertools = "0.12.1" luminal = { path = "../.." } -metal-rs = { version = "0.27.0", package = "metal", features = ["mps"] } +metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] } num-traits = "0.2.18" regex = "1.10.4" rustc-hash = "1.1.0" diff --git a/crates/luminal_metal/src/elementwise_fusion.rs b/crates/luminal_metal/src/elementwise_fusion.rs index c65b908f..020c075f 100644 --- a/crates/luminal_metal/src/elementwise_fusion.rs +++ b/crates/luminal_metal/src/elementwise_fusion.rs @@ -229,7 +229,10 @@ impl Compiler for ElementwiseFusionCompiler { .sorted_by_key(|(i, _, _)| *i) .map(|(_, _, s)| s) .collect::>(), - ); + ) + .into_iter() + .map(|s| s.simplify()) + .collect(); let new_op = graph .add_op(FusedElementwiseOp:: { kernel_str: "".to_string(), @@ -283,7 +286,10 @@ impl Compiler for ElementwiseFusionCompiler { let output_buffer_sizes = graph .node_custom::(op, "metal", ()) .unwrap() - .output_buffer_sizes(&input_shapes); + .output_buffer_sizes(&input_shapes) + .into_iter() + .map(|s| s.simplify()) + .collect(); let new_op = graph .add_op(FusedElementwiseOp:: { kernel_str: "".to_string(), @@ -317,7 +323,7 @@ impl Compiler for ElementwiseFusionCompiler { let op = graph.get_op_mut::>(fused_op); // Stack index expressions and replace them in the subexpressions // Track all shapes used, will pull dyn dims from these - op.pre_compile(inputs, &mut input_regexes, &intermediate_match); + op.pre_compile(inputs, &mut input_regexes, &intermediate_match, fused_op); op.compile(&device); } } @@ -343,6 +349,7 @@ impl FusedElementwiseOp { input_shapes: Vec, input_regexes: &mut FxHashMap, intermediate_match: &Regex, + node: NodeIndex, ) { let mut subexpressions = self.subexpressions.clone(); let shapes_used = subexpressions diff --git a/crates/luminal_metal/src/lib.rs b/crates/luminal_metal/src/lib.rs index 3ef8d9c0..caf73df2 100644 --- a/crates/luminal_metal/src/lib.rs +++ b/crates/luminal_metal/src/lib.rs @@ -18,6 +18,8 @@ pub mod quantized; pub mod storage_buffer; pub mod unary; +pub use objc::rc::autoreleasepool; + use itertools::Itertools; use metal_rs::*; use prim::MetalConstant; diff --git a/crates/luminal_metal/src/storage_buffer.rs b/crates/luminal_metal/src/storage_buffer.rs index cab91b1f..69e12a30 100644 --- a/crates/luminal_metal/src/storage_buffer.rs +++ b/crates/luminal_metal/src/storage_buffer.rs @@ -221,13 +221,13 @@ impl Compiler for StorageBufferCompiler { for required_buffer in wrapper.output_buffer_sizes(&input_shapes) { // Allocate new buffer buffer_map.get_mut(node).unwrap().0.push(buffers.len()); - buffers.push(required_buffer); + buffers.push(required_buffer.simplify()); } // Assign intermediate buffers for required_buffer in wrapper.intermediate_buffer_sizes(&input_shapes) { // Allocate new buffer buffer_map.get_mut(node).unwrap().1.push(buffers.len()); - buffers.push(required_buffer); + buffers.push(required_buffer.simplify()); } } @@ -335,10 +335,9 @@ impl Operator for AllocateMetalBuffers { // while length < size { // length *= 2; // } - let length = size; *buffer = self .dev - .new_buffer(length, MTLResourceOptions::StorageModeShared); + .new_buffer(size, MTLResourceOptions::StorageModeShared); } } } diff --git a/examples/whisper/src/main.rs b/examples/whisper/src/main.rs index 9a4d5c63..4569a87e 100644 --- a/examples/whisper/src/main.rs +++ b/examples/whisper/src/main.rs @@ -75,7 +75,10 @@ fn main() { ( GenericCompiler::default(), #[cfg(feature = "metal")] - luminal_metal::MetalCompiler::::default(), + ( + luminal_metal::MetalCompilerPreBuffer::::default(), + luminal_metal::command_buffer::CommandBufferCompiler, // For some reason storage buffer causes a segfault on decoder + ), #[cfg(feature = "cuda")] luminal_cuda::CudaCompiler::::default(), #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index 4286c401..c734b40c 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -931,6 +931,8 @@ fn make_rules() -> Vec { rewrite!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")), // Other rewrite!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"), + rewrite!("distribute-max"; "(* ?a (max ?b ?c))" => "(max (* ?a ?b) (* ?a ?c))"), + rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"), rewrite!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"), ] }