Skip to content

Commit

Permalink
Fixed CSE
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 1, 2024
1 parent 38acdf3 commit 6d9f917
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 17 deletions.
18 changes: 7 additions & 11 deletions src/compilers/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub type GenericCompiler<Inner = ((),)> = (PreGenericCompiler, Inner, PostGeneri
pub type PostGenericCompiler = (
UnarySequentialElimination,
// RemoveUnusedNodes, // Broken right now, unclear why
// CSE, // This breaks compilers::metal::fp16::tests::test_encoder_block. I think it needs to take edge weights into account?
CSE,
);

pub type PreGenericCompiler = (RemoveSingleReductions,);
Expand Down Expand Up @@ -97,13 +97,16 @@ impl Compiler for CSE {
fn compile(&self, graph: &mut Graph) {
// Look for nodes that have the exact same srcs
// Loop cause I'm lazy
loop {
let mut eliminated = false;
let mut eliminated = true;
while eliminated {
eliminated = false;
let mut srcs_set = HashMap::new();
for node in graph.graph.node_indices().collect_vec() {
let mut srcs = graph
let srcs = graph
.graph
.edges_directed(node, petgraph::Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.sorted_by_key(|e| e.weight().as_data().unwrap().0)
.map(|e| e.source())
.collect_vec();

Expand All @@ -118,9 +121,6 @@ impl Compiler for CSE {
continue;
}

// If order doesn't matter, make sure different ordered srcs match by sorting
srcs.sort();

if let Some(other_node) = srcs_set.get(&srcs) {
let a = graph.graph.node_weight(node).unwrap();
let b = graph.graph.node_weight(*other_node).unwrap();
Expand Down Expand Up @@ -155,10 +155,6 @@ impl Compiler for CSE {
}
srcs_set.insert(srcs, node);
}

if !eliminated {
break;
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ impl<T: MetalFloat> Compiler for MetalEqualCompiler<T> {
.map(|e| e.source())
.collect::<Vec<_>>();
let (a, b) = (inputs[0], inputs[1]);
if check_no_delete(graph, &[a, b, less_than1, less_than2, add, one, sub]) {
if check_no_delete(graph, &[less_than1, less_than2, add, one, sub]) {
continue;
}
let a_edge = graph
Expand Down
7 changes: 5 additions & 2 deletions src/compilers/metal/fp16/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,10 +857,13 @@ fn test_common_buffer() {
fn test_embedding() {
let mut cx = Graph::new();
let batch = cx
.tensor::<R2<2, 3>>()
.named_tensor::<R2<2, 3>>("Batch")
.set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0])
.keep();
let a = cx.tensor::<R1<3>>().set(vec![1.0, 0.0, 1.0]).keep();
let a = cx
.named_tensor::<R1<3>>("Single")
.set(vec![1.0, 0.0, 1.0])
.keep();

let model: crate::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx);
model
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/prim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ impl<T: MetalFloat> MetalContiguous<T> {
#include <metal_stdlib>
using namespace metal;
kernel void mkernel(device {} *inp [[buffer(0)]], device {} *out [[buffer(1)]], device int& n_elements [[buffer(2)]], uint idx [[thread_position_in_grid]]{}) {{
if (idx < n_elements && ({valid_exp} != 0)) {{
if (idx < n_elements && {valid_exp} != 0) {{
out[idx] = inp[{idx_exp}];
}}
}}
Expand Down
2 changes: 1 addition & 1 deletion src/hl_ops/movement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl<S: Shape> GraphTensor<S> {

/// Dynamically reshape with annotations for the shape tracker
pub fn dyn_reshape<N: Shape>(self, shape: Vec<Expression>) -> GraphTensor<N> {
let id = if !self.shape.is_contiguous() {
let id = if !self.shape.indexes.iter().enumerate().all(|(a, b)| a == *b) {
// Insert contiguous call
self.graph()
.add_op(op::Contiguous)
Expand Down
2 changes: 1 addition & 1 deletion src/nn/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<S: Dimension, const N: usize, const DIM: usize> Module<GraphTensor<(S,)>>
type Output = GraphTensor<(S, Const<DIM>)>;

fn forward(&self, input: GraphTensor<(S,)>) -> Self::Output {
<Self as Module<GraphTensor<(Const<1>, S)>>>::forward(self, input.expand()).max_reduce()
self.weight.gather(input)
}
}

Expand Down

0 comments on commit 6d9f917

Please sign in to comment.