From 862af13096628872f513d1830122f3ae2b64db1e Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 26 Apr 2024 10:26:44 -0500 Subject: [PATCH] Fixed metal --- .../luminal_metal/src/elementwise_fusion.rs | 6 ++-- crates/luminal_nn/src/convolution.rs | 34 ++++++++----------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/crates/luminal_metal/src/elementwise_fusion.rs b/crates/luminal_metal/src/elementwise_fusion.rs index be78287e..dca20065 100644 --- a/crates/luminal_metal/src/elementwise_fusion.rs +++ b/crates/luminal_metal/src/elementwise_fusion.rs @@ -407,7 +407,7 @@ impl Compiler for ElementwiseFusionCompiler { }, ) .0; - if val_exp != true.into() { + if val_exp != true { *subexp = format!( "(({} != 0) ? {subexp} : 0.0)", expr_to_metal_string(val_exp) @@ -875,8 +875,8 @@ mod tests { .permute::<_, Axes4<0, 2, 1, 3>>(); // Rotary embed queries and keys - let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().into()); - let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().into()); + let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::const_size().big()); + let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::const_size().big()); // Add KV cache let (keys, values) = ( diff --git a/crates/luminal_nn/src/convolution.rs b/crates/luminal_nn/src/convolution.rs index 3cfd8565..29fce529 100644 --- a/crates/luminal_nn/src/convolution.rs +++ b/crates/luminal_nn/src/convolution.rs @@ -23,18 +23,15 @@ impl< for Conv1D { fn initialize(cx: &mut Graph) -> Self { - let conv = Self { - weight: cx.named_tensor("Weight"), - }; - // Init weight as uniform(-1, 1) let mut rng = thread_rng(); - conv.weight.set( - (0..(CHANNELS_IN * CHANNELS_OUT * KERNEL)) - .map(|_| rng.gen_range(-1_f32..1_f32)) - .collect::>(), - ); - conv + Self { + weight: cx.named_tensor("Weight").set( + (0..(CHANNELS_IN * CHANNELS_OUT * KERNEL)) + .map(|_| rng.gen_range(-1_f32..1_f32)) + .collect::>(), + ), + } } } @@ -118,18 +115,15 @@ impl< > { fn initialize(cx: &mut Graph) -> Self { - let conv = Self { - weight: cx.named_tensor("Weight"), - }; - // Init weight as uniform(-1, 1) let mut rng = thread_rng(); - conv.weight.set( - (0..(CHANNELS_IN * CHANNELS_OUT * KERNELX * KERNELY)) - .map(|_| rng.gen_range(-1_f32..1_f32)) - .collect::>(), - ); - conv + Self { + weight: cx.named_tensor("Weight").set( + (0..(CHANNELS_IN * CHANNELS_OUT * KERNELX * KERNELY)) + .map(|_| rng.gen_range(-1_f32..1_f32)) + .collect::>(), + ), + } } }