Skip to content

Commit

Permalink
Enabled elementwise on metal prims
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 16, 2024
1 parent 69c207b commit 8f2d13d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
6 changes: 6 additions & 0 deletions src/compilers/metal/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ impl<T: MetalFloat> Operator for MetalSub<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 - input1".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down Expand Up @@ -337,6 +340,9 @@ impl<T: MetalFloat> Operator for MetalEqual<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 == input1 ? 1.0 : 0.0".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down
5 changes: 3 additions & 2 deletions src/compilers/metal/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
// Render a into b as input to_input
fused_op.equation = fused_op
.equation
.replace(&format!("input{to_input}"), &a_equation);
.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// Since we are removing the input from a, we must decrement all inputs larger than that
for i in to_input + 1..n_edges {
fused_op.equation = fused_op
Expand All @@ -130,7 +130,8 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
let mut b_equation = graph
.node_custom::<String, _>(b, "elementwise", ())
.unwrap();
b_equation = b_equation.replace(&format!("input{to_input}"), &a_equation);
b_equation =
b_equation.replace(&format!("input{to_input}"), &format!("({a_equation})"));
// Since we are removing the input from a, we must decrement all inputs larger than that
for i in to_input + 1..n_edges {
b_equation =
Expand Down
11 changes: 11 additions & 0 deletions src/compilers/metal/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,9 @@ impl<T: MetalFloat> Operator for MetalExp<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("exp(input0)".to_string()));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down Expand Up @@ -588,6 +591,11 @@ impl<T: MetalFloat> Operator for MetalSwish<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new(
"input0 * (1.0h / (1.0h + exp(-input0)))".to_string(),
));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down Expand Up @@ -777,6 +785,9 @@ impl<T: MetalFloat> Operator for MetalCos<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("cos(input0)".to_string()));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down
21 changes: 21 additions & 0 deletions src/compilers/metal/prim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ impl<T: MetalFloat> Operator for MetalLog2<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("log2(input0)".to_string()));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down Expand Up @@ -625,6 +628,9 @@ impl<T: MetalFloat> Operator for MetalSqrt<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("sqrt(input0)".to_string()));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down Expand Up @@ -717,6 +723,9 @@ impl<T: MetalFloat> Operator for MetalRecip<T> {
self.clone(),
)))));
}
if key == "elementwise" {
return Some(Box::new("1.0 / input0".to_string()));
}
// This op can accept non contiguous inputs
if key == "non_contiguous" {
return Some(Box::new(()));
Expand Down Expand Up @@ -839,6 +848,9 @@ impl<T: MetalFloat> Operator for MetalAdd<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 + input1".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down Expand Up @@ -969,6 +981,9 @@ impl<T: MetalFloat> Operator for MetalMul<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 * input1".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down Expand Up @@ -1111,6 +1126,9 @@ impl<T: MetalFloat> Operator for MetalLessThan<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("input0 < input1 ? 1.0 : 0.0".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down Expand Up @@ -1239,6 +1257,9 @@ impl<T: MetalFloat> Operator for MetalMod<T> {
if key == "non_contiguous" {
return Some(Box::new(()));
}
if key == "elementwise" {
return Some(Box::new("fmod(input0, input1)".to_string()));
}
if key == "recompile_shapes" {
if let Some(input_shapes) = input.downcast_ref::<Vec<ShapeTracker>>() {
*self = Self::new(
Expand Down

0 comments on commit 8f2d13d

Please sign in to comment.