Skip to content

Commit

Permalink
Small opt
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Apr 26, 2024
1 parent 358dceb commit 45dc0f6
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
25 changes: 13 additions & 12 deletions crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ use metal_rs::{
ComputePipelineState, Device, MTLResourceOptions,
};

use luminal::{
op::{InputTensor, Operator},
prelude::{
petgraph::{visit::EdgeRef, Direction},
*,
},
use luminal::prelude::{
petgraph::{visit::EdgeRef, Direction},
*,
};

use crate::{
Expand Down Expand Up @@ -384,11 +381,15 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
*subexp = re
.replace_all(
subexp,
&format!(
"({} != 0 ? (float)input{i}[{}] : 0.0)$1",
expr_to_metal_string(val_exp.clone()),
expr_to_metal_string(ind_exp.clone())
),
&if *val_exp != true {
format!(
"({} != 0 ? (float)input{i}[{}] : 0.0)$1",
expr_to_metal_string(val_exp),
expr_to_metal_string(ind_exp)
)
} else {
format!("(float)input{i}[{}]$1", expr_to_metal_string(ind_exp))
},
)
.to_string();
}
Expand All @@ -410,7 +411,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
if val_exp != true {
*subexp = format!(
"(({} != 0) ? {subexp} : 0.0)",
expr_to_metal_string(val_exp)
expr_to_metal_string(&val_exp)
);
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/luminal_metal/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker], offset: usize) -> (Vec<char>,
)
}

fn expr_to_metal_string(expr: BigExpression) -> String {
fn expr_to_metal_string(expr: &BigExpression) -> String {
let mut symbols = vec![];
for term in expr.terms {
for term in expr.terms.clone() {
let new_symbol = match term {
Term::Num(n) => n.to_string(),
Term::Var(c) => {
Expand Down Expand Up @@ -440,8 +440,8 @@ fn expr_to_metal_string(expr: BigExpression) -> String {

fn get_idx_valid_exps(shape: ShapeTracker) -> (String, String) {
(
expr_to_metal_string(shape.index_expression()),
expr_to_metal_string(shape.valid_expression()),
expr_to_metal_string(&shape.index_expression()),
expr_to_metal_string(&shape.valid_expression()),
)
}

Expand Down
5 changes: 4 additions & 1 deletion crates/luminal_symbolic/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,12 @@ impl<S: ExpressionStorage> GenericExpression<S> {
/// Maximum
pub fn max<E: Into<Self>>(self, rhs: E) -> Self {
let mut rhs = rhs.into();
if rhs == self {
if rhs == self || rhs == 0 {
return self;
}
if self == 0 {
return rhs;
}
rhs.terms.extend(self.terms);
rhs.terms.push(Term::Max);
rhs.simplify()
Expand Down
5 changes: 0 additions & 5 deletions src/shape/tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@ impl ShapeTracker {
let mut ind_expr = BigExpression::from(0); // The final index expression
let mut current_elem_size = BigExpression::from(1); // Keep track of the size of each element of the current dim (last dim elem size: 1)

// For combined dims
// divide by last dims (smallest) element size
// mod by combined dim size
// multiply by last dims (smallest) stride

// Loop through all dims in reverse order
for i in shape.indexes.into_iter().rev() {
// Get logical dimension size with padding and mask
Expand Down

0 comments on commit 45dc0f6

Please sign in to comment.