Skip to content

Commit

Permalink
StorageBufferCompiler causing segfault
Browse files: browse the repository at this point in the history
  • Loading branch information
jafioti committed Jun 5, 2024
1 parent 3c47c9f commit f5c9f6d
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion crates/luminal_metal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ license = "MIT OR Apache-2.0"
[dependencies]
itertools = "0.12.1"
luminal = { path = "../.." }
metal-rs = { version = "0.27.0", package = "metal", features = ["mps"] }
metal-rs = { version = "0.28.0", package = "metal", features = ["mps"] }
num-traits = "0.2.18"
regex = "1.10.4"
rustc-hash = "1.1.0"
Expand Down
13 changes: 10 additions & 3 deletions crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,10 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.sorted_by_key(|(i, _, _)| *i)
.map(|(_, _, s)| s)
.collect::<Vec<_>>(),
);
)
.into_iter()
.map(|s| s.simplify())
.collect();
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
Expand Down Expand Up @@ -283,7 +286,10 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let output_buffer_sizes = graph
.node_custom::<MetalKernelWrapper, _>(op, "metal", ())
.unwrap()
.output_buffer_sizes(&input_shapes);
.output_buffer_sizes(&input_shapes)
.into_iter()
.map(|s| s.simplify())
.collect();
let new_op = graph
.add_op(FusedElementwiseOp::<T> {
kernel_str: "".to_string(),
Expand Down Expand Up @@ -317,7 +323,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let op = graph.get_op_mut::<FusedElementwiseOp<T>>(fused_op);
// Stack index expressions and replace them in the subexpressions
// Track all shapes used, will pull dyn dims from these
op.pre_compile(inputs, &mut input_regexes, &intermediate_match);
op.pre_compile(inputs, &mut input_regexes, &intermediate_match, fused_op);
op.compile(&device);
}
}
Expand All @@ -343,6 +349,7 @@ impl<T: MetalFloat> FusedElementwiseOp<T> {
input_shapes: Vec<ShapeTracker>,
input_regexes: &mut FxHashMap<usize, Regex>,
intermediate_match: &Regex,
node: NodeIndex,
) {
let mut subexpressions = self.subexpressions.clone();
let shapes_used = subexpressions
Expand Down
2 changes: 2 additions & 0 deletions crates/luminal_metal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub mod quantized;
pub mod storage_buffer;
pub mod unary;

pub use objc::rc::autoreleasepool;

use itertools::Itertools;
use metal_rs::*;
use prim::MetalConstant;
Expand Down
7 changes: 3 additions & 4 deletions crates/luminal_metal/src/storage_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ impl Compiler for StorageBufferCompiler {
for required_buffer in wrapper.output_buffer_sizes(&input_shapes) {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().0.push(buffers.len());
buffers.push(required_buffer);
buffers.push(required_buffer.simplify());
}
// Assign intermediate buffers
for required_buffer in wrapper.intermediate_buffer_sizes(&input_shapes) {
// Allocate new buffer
buffer_map.get_mut(node).unwrap().1.push(buffers.len());
buffers.push(required_buffer);
buffers.push(required_buffer.simplify());
}
}

Expand Down Expand Up @@ -335,10 +335,9 @@ impl Operator for AllocateMetalBuffers {
// while length < size {
// length *= 2;
// }
let length = size;
*buffer = self
.dev
.new_buffer(length, MTLResourceOptions::StorageModeShared);
.new_buffer(size, MTLResourceOptions::StorageModeShared);
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion examples/whisper/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ fn main() {
(
GenericCompiler::default(),
#[cfg(feature = "metal")]
luminal_metal::MetalCompiler::<f16>::default(),
(
luminal_metal::MetalCompilerPreBuffer::<f16>::default(),
luminal_metal::command_buffer::CommandBufferCompiler, // For some reason storage buffer causes a segfault on decoder
),
#[cfg(feature = "cuda")]
luminal_cuda::CudaCompiler::<f32>::default(),
#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
Expand Down
2 changes: 2 additions & 0 deletions src/shape/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,8 @@ fn make_rules() -> Vec<Rewrite> {
rewrite!("cancel-div"; "(/ ?a ?a)" => "1" if is_not_zero("?a")),
// Other
rewrite!("distribute"; "(* ?a (+ ?b ?c))" => "(+ (* ?a ?b) (* ?a ?c))"),
rewrite!("distribute-max"; "(* ?a (max ?b ?c))" => "(max (* ?a ?b) (* ?a ?c))"),
rewrite!("distribute-min"; "(* ?a (min ?b ?c))" => "(min (* ?a ?b) (* ?a ?c))"),
rewrite!("factor" ; "(+ (* ?a ?b) (* ?a ?c))" => "(* ?a (+ ?b ?c))"),
]
}
Expand Down

0 comments on commit f5c9f6d

Please sign in to comment.