From e67d3e65987f06b551373848cb27631620863c75 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Mon, 1 Jan 2024 21:26:04 -0500 Subject: [PATCH] Small changes --- examples/mistral/main.rs | 6 +++--- src/compilers/generic.rs | 29 +++++++++++++++++++++++------ src/compilers/metal/binary.rs | 1 + 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/mistral/main.rs b/examples/mistral/main.rs index 9af62446..0ea2230e 100644 --- a/examples/mistral/main.rs +++ b/examples/mistral/main.rs @@ -16,7 +16,7 @@ type DeviceCompiler = CudaFp16Compiler; type DeviceCompiler = CPUCompiler; fn main() { - let prompt = "[INST]Please write a python implementation of merge sort.[/INST]\n"; + let prompt = "[INST]Here is a python implementation of merge sort:[/INST]\n"; let tokens_to_generate = 300; println!("Creating graph..."); @@ -121,7 +121,7 @@ fn main() { print!( "{}{}", prompt.white().bold(), - decode(&tokenizer, &[output_id]).green() + decode(&tokenizer, &[output_id]).bright_green() ); std::io::stdout().flush().unwrap(); @@ -148,7 +148,7 @@ fn main() { let output_id = sample_index(&decode_logits.data()); decode_logits.drop(); input_ids.push(output_id); - print!("{}", decode(&tokenizer, &[output_id]).green()); + print!("{}", decode(&tokenizer, &[output_id]).bright_green()); std::io::stdout().flush().unwrap(); // Swap caches diff --git a/src/compilers/generic.rs b/src/compilers/generic.rs index 11910131..bd547fb9 100644 --- a/src/compilers/generic.rs +++ b/src/compilers/generic.rs @@ -304,6 +304,7 @@ pub struct RemapDownstream(pub Vec); impl Compiler for RemapDownstream { fn compile(&self, graph: &mut Graph) { + let set = self.0.iter().copied().collect::>(); // Loop through state dict tensors marked as no_delete for mut node in self .0 @@ -324,12 +325,7 @@ impl Compiler for RemapDownstream { .next() .unwrap() .target(); - if graph - .graph - .edges_directed(new_node, Direction::Incoming) - .count() - > 1 - { + if !is_from_set(new_node, graph, &set) { break; } // Remap node to new node @@ -340,6 +336,27 @@ impl Compiler for RemapDownstream { } } +fn is_from_set(node: NodeIndex, graph: &Graph, set: &HashSet) -> bool { + // Reverse dfs upward + let mut stack = vec![node]; + while let Some(node) = stack.pop() { + if !set.contains(&node) { + let mut new_nodes = graph + .graph + .edges_directed(node, Direction::Incoming) + .filter(|e| !e.weight().is_schedule()) + .map(|e| e.source()) + .collect_vec(); + if new_nodes.is_empty() { + // Node isn't from set and doesn't have upstream nodes + return false; + } + stack.append(&mut new_nodes); + } + } + true +} + #[cfg(test)] mod tests { use crate::prelude::*; diff --git a/src/compilers/metal/binary.rs b/src/compilers/metal/binary.rs index c7a2a642..1077b2de 100644 --- a/src/compilers/metal/binary.rs +++ b/src/compilers/metal/binary.rs @@ -608,6 +608,7 @@ impl Compiler for MetalGatherCompiler { move_outgoing_edge(sum_reduce, gather, &mut graph.graph); graph.graph.remove_node(arange); graph.graph.remove_node(ind_copy); + graph.id_remap.retain(|_, v| *v != ind_copy); graph.graph.remove_node(mul); graph.graph.remove_node(sum_reduce); }