Skip to content

Commit

Permalink
Merge pull request #11 from TheSeamau5/fix_swish
Browse files Browse the repository at this point in the history
Fix swish
  • Loading branch information
jafioti authored Jan 10, 2024
2 parents 4dd7cd7 + ff1da67 commit fa67608
Showing 1 changed file with 54 additions and 48 deletions.
102 changes: 54 additions & 48 deletions src/compilers/metal/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use petgraph::{stable_graph::NodeIndex, visit::EdgeRef, Direction};
use crate::{
compilers::metal::{prim::*, *},
constant_select_op,
op::Operator,
op::{ConstantValue, Operator},
prelude::*,
};

Expand Down Expand Up @@ -776,21 +776,7 @@ impl<T: MetalFloat> Compiler for MetalSwishCompiler<T> {
let dev = Device::system_default().unwrap();
let queue = dev.new_command_queue();
// Look for the swish pattern
let (
mut x,
mut neg_one,
mut mul1,
mut mul2,
mut mul3,
mut exp,
mut one,
mut one2,
mut add,
mut recip,
) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
let (mut neg_one, mut mul1, mut mul2, mut exp, mut one, mut add, mut recip) = (
NodeIndex::default(),
NodeIndex::default(),
NodeIndex::default(),
Expand All @@ -799,55 +785,75 @@ impl<T: MetalFloat> Compiler for MetalSwishCompiler<T> {
NodeIndex::default(),
NodeIndex::default(),
);
let mut searcher = constant_select_op!(1.0, T)
.ptr(&mut one)
.edge(
constant_select_op!(1.0, T)
.ptr(&mut one2)
.edge(
SelectOp::new()
.ptr(&mut x)
.edge(
constant_select_op!(-1.0, T)
.ptr(&mut neg_one)
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul1)),
)
.edge(SelectOp::new().ty::<MetalExp<T>>().ptr(&mut exp))
.edge(SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add)),
)
.edge(SelectOp::new().ty::<MetalRecip<T>>().ptr(&mut recip))
.edge(SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul2)),
)
.edge(SelectOp::new().ty::<MetalMul<f16>>().ptr(&mut mul3))

let neg_one_node = constant_select_op!(-1.0, T).ptr(&mut neg_one);
let mul1_node = SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul1);
let mul2_node = SelectOp::new().ty::<MetalMul<T>>().ptr(&mut mul2);
let exp_node = SelectOp::new().ty::<MetalExp<T>>().ptr(&mut exp);
let recip_node = SelectOp::new().ty::<MetalRecip<T>>().ptr(&mut recip);
let add_node = SelectOp::new().ty::<MetalAdd<T>>().ptr(&mut add);
let mut searcher = neg_one_node
.edge(mul1_node)
.edge(exp_node)
.edge(add_node)
.edge(recip_node)
.edge(mul2_node)
.search(graph);

while searcher.next_match() {
if check_no_delete(graph, &[neg_one, mul1, mul2, mul3, exp, one, add, recip])
|| one != one2
{
if check_no_delete(graph, &[neg_one, mul1, mul2, exp, one, add, recip]) {
// An intermediate node can't be deleted
continue;
}

// Check the if input to add is one
let add_sources = graph.get_sources(add);
let (src1_index, _, _) = add_sources[0];
let (src2_index, _, _) = add_sources[1];

let src_index = if src1_index == exp {
src2_index
} else {
src1_index
};

let test_op = graph.graph.node_weight(src_index).unwrap();

// If test op is not 1, we continue
let test_op = test_op.as_any().downcast_ref::<MetalConstant<T>>();
if let Some(test_op) = test_op {
if test_op.0 != ConstantValue::Float(1.0) {
continue;
} else {
one = src_index;
}
} else {
continue;
}

// Now we look for the input
let mul1_sources = graph.get_sources(mul1);
let (src1_index, _, shape1) = mul1_sources[0];
let (src2_index, _, shape2) = mul1_sources[1];

let (src_index, shape) = if src1_index == neg_one {
(src2_index, shape2)
} else {
(src1_index, shape1)
};

// Insert swish op
let shape = graph
.graph
.edges_connecting(x, mul1)
.find_map(|e| e.weight().as_data())
.unwrap()
.2;
let swish = graph
.add_op(MetalSwish::<T>::new(dev.clone(), queue.clone()))
.input(x, 0, shape)
.input(src_index, 0, shape)
.finish();

// Create edges to dests
move_outgoing_edge(mul3, swish, &mut graph.graph);
move_outgoing_edge(mul2, swish, &mut graph.graph);

// Remove the old ops
graph.graph.remove_node(mul1);
graph.graph.remove_node(mul2);
graph.graph.remove_node(mul3);
graph.graph.remove_node(neg_one);
graph.graph.remove_node(exp);
graph.graph.remove_node(one);
Expand Down

0 comments on commit fa67608

Please sign in to comment.