diff --git a/src/compilers/generic.rs b/src/compilers/generic.rs index ea270bf7..11910131 100644 --- a/src/compilers/generic.rs +++ b/src/compilers/generic.rs @@ -21,7 +21,7 @@ pub type GenericCompiler = (PreGenericCompiler, Inner, PostGeneri pub type PostGenericCompiler = ( UnarySequentialElimination, // RemoveUnusedNodes, // Broken right now, unclear why - // CSE, // This breaks compilers::metal::fp16::tests::test_encoder_block. I think it needs to take edge weights into account? + CSE, ); pub type PreGenericCompiler = (RemoveSingleReductions,); @@ -97,13 +97,16 @@ impl Compiler for CSE { fn compile(&self, graph: &mut Graph) { // Look for nodes that have the exact same srcs // Loop cause I'm lazy - loop { - let mut eliminated = false; + let mut eliminated = true; + while eliminated { + eliminated = false; let mut srcs_set = HashMap::new(); for node in graph.graph.node_indices().collect_vec() { - let mut srcs = graph + let srcs = graph .graph .edges_directed(node, petgraph::Direction::Incoming) + .filter(|e| !e.weight().is_schedule()) + .sorted_by_key(|e| e.weight().as_data().unwrap().0) .map(|e| e.source()) .collect_vec(); @@ -118,9 +121,6 @@ impl Compiler for CSE { continue; } - // If order doesn't matter, make sure different ordered srcs match by sorting - srcs.sort(); - if let Some(other_node) = srcs_set.get(&srcs) { let a = graph.graph.node_weight(node).unwrap(); let b = graph.graph.node_weight(*other_node).unwrap(); @@ -155,10 +155,6 @@ impl Compiler for CSE { } srcs_set.insert(srcs, node); } - - if !eliminated { - break; - } } } } diff --git a/src/compilers/metal/binary.rs b/src/compilers/metal/binary.rs index 51560dd5..c7a2a642 100644 --- a/src/compilers/metal/binary.rs +++ b/src/compilers/metal/binary.rs @@ -412,7 +412,7 @@ impl Compiler for MetalEqualCompiler { .map(|e| e.source()) .collect::>(); let (a, b) = (inputs[0], inputs[1]); - if check_no_delete(graph, &[a, b, less_than1, less_than2, add, one, sub]) { + if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) { continue; } let a_edge = graph diff --git a/src/compilers/metal/fp16/tests.rs b/src/compilers/metal/fp16/tests.rs index c2ab7b49..a38228ea 100644 --- a/src/compilers/metal/fp16/tests.rs +++ b/src/compilers/metal/fp16/tests.rs @@ -857,10 +857,13 @@ fn test_common_buffer() { fn test_embedding() { let mut cx = Graph::new(); let batch = cx - .tensor::>() + .named_tensor::>("Batch") .set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]) .keep(); - let a = cx.tensor::>().set(vec![1.0, 0.0, 1.0]).keep(); + let a = cx + .named_tensor::>("Single") + .set(vec![1.0, 0.0, 1.0]) + .keep(); let model: crate::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx); model diff --git a/src/compilers/metal/prim.rs b/src/compilers/metal/prim.rs index 31375723..be6be79c 100644 --- a/src/compilers/metal/prim.rs +++ b/src/compilers/metal/prim.rs @@ -126,7 +126,7 @@ impl MetalContiguous { #include using namespace metal; kernel void mkernel(device {} *inp [[buffer(0)]], device {} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{}) {{ - if (idx < n_elements && ({valid_exp} != 0)) {{ + if (idx < n_elements && {valid_exp} != 0) {{ out[idx] = inp[{idx_exp}]; }} }} diff --git a/src/hl_ops/movement.rs b/src/hl_ops/movement.rs index b54606e4..74ce1a26 100644 --- a/src/hl_ops/movement.rs +++ b/src/hl_ops/movement.rs @@ -54,7 +54,7 @@ impl GraphTensor { /// Dynamically reshape with annotations for the shape tracker pub fn dyn_reshape(self, shape: Vec) -> GraphTensor { - let id = if !self.shape.is_contiguous() { + let id = if !self.shape.indexes.iter().enumerate().all(|(a, b)| a == *b) { // Insert contiguous call self.graph() .add_op(op::Contiguous) diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index aa2f4466..e5f67ee0 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -35,7 +35,7 @@ impl Module> type Output = GraphTensor<(S, Const)>; fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output { - , S)>>>::forward(self, input.expand()).max_reduce() + self.weight.gather(input) } }