From 53dc4dd9df59c0737215187d21233d22f1ec7de7 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Mon, 1 Jan 2024 00:02:06 -0500 Subject: [PATCH] Optimized storage compiler --- src/compilers/metal/fp16/matmul.rs | 2 +- src/compilers/metal/fp16/mean_reduce.rs | 2 +- src/compilers/metal/fp16/mod.rs | 10 ++++---- src/compilers/metal/fp16/rms_norm.rs | 2 +- src/compilers/metal/storage_buffer.rs | 31 ++++++++++++------------- 5 files changed, 22 insertions(+), 25 deletions(-) diff --git a/src/compilers/metal/fp16/matmul.rs b/src/compilers/metal/fp16/matmul.rs index 56c0a2c9..c91fbb62 100644 --- a/src/compilers/metal/fp16/matmul.rs +++ b/src/compilers/metal/fp16/matmul.rs @@ -590,7 +590,7 @@ impl Operator for MetalBatchMatmul2D { } } -#[derive(Default)] +#[derive(Default, Debug)] pub struct MetalMatMulCompiler; impl Compiler for MetalMatMulCompiler { diff --git a/src/compilers/metal/fp16/mean_reduce.rs b/src/compilers/metal/fp16/mean_reduce.rs index 456ae4c6..de3e0cbc 100644 --- a/src/compilers/metal/fp16/mean_reduce.rs +++ b/src/compilers/metal/fp16/mean_reduce.rs @@ -159,7 +159,7 @@ impl Operator for MetalMeanReduce { } /// Replace the mean reduce pattern with a special kernel. This is meant to be ran **after** the FakeSumReduceCompiler. -#[derive(Default)] +#[derive(Default, Debug)] pub struct MeanReduceCompiler; impl Compiler for MeanReduceCompiler { diff --git a/src/compilers/metal/fp16/mod.rs b/src/compilers/metal/fp16/mod.rs index ad769594..4fad7c52 100644 --- a/src/compilers/metal/fp16/mod.rs +++ b/src/compilers/metal/fp16/mod.rs @@ -1,7 +1,5 @@ use half::f16; -use crate::prelude::TimedCompiler; - mod matmul; mod mean_reduce; mod other; @@ -10,10 +8,10 @@ mod rms_norm; pub type MetalFp16Compiler = ( super::prim::PrimitiveCompiler, ( - TimedCompiler>, - TimedCompiler>, - TimedCompiler>, - TimedCompiler>, + super::binary::MetalSubtractionCompiler, + super::binary::MetalEqualCompiler, + super::other::ARangeCompiler, + super::binary::MetalGatherCompiler, ), other::MetalExpCompiler, matmul::MetalMatMulCompiler, diff --git a/src/compilers/metal/fp16/rms_norm.rs b/src/compilers/metal/fp16/rms_norm.rs index 345b70d1..4ce8eebc 100644 --- a/src/compilers/metal/fp16/rms_norm.rs +++ b/src/compilers/metal/fp16/rms_norm.rs @@ -195,7 +195,7 @@ impl Operator for MetalRMSNorm { } /// Replace the mean reduce pattern with a special kernel. This is meant to be ran **after** the FakeSumReduceCompiler. -#[derive(Default)] +#[derive(Default, Debug)] pub struct RMSNormCompiler; impl Compiler for RMSNormCompiler { diff --git a/src/compilers/metal/storage_buffer.rs b/src/compilers/metal/storage_buffer.rs index ad32f7b3..62310e92 100644 --- a/src/compilers/metal/storage_buffer.rs +++ b/src/compilers/metal/storage_buffer.rs @@ -1,6 +1,6 @@ use std::{ cell::UnsafeCell, - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, fmt::Debug, sync::Arc, }; @@ -35,13 +35,14 @@ impl Compiler for StorageBufferCompiler { BTreeSet, ), > = HashMap::new(); - // Loop through starting nodes in graph - for node in toposort(&graph.graph, None).unwrap() { + let toposort = toposort(&graph.graph, None).unwrap(); + // Loop through nodes in graph + for node in &toposort { // Run through parents to build new tenative set and clear set let (mut tenative_sets, mut clear_set) = (BTreeMap::default(), BTreeSet::default()); for parent in graph .graph - .edges_directed(node, Direction::Incoming) + .edges_directed(*node, Direction::Incoming) .filter(|e| !e.weight().is_schedule()) .map(|e| e.source()) { @@ -77,7 +78,7 @@ impl Compiler for StorageBufferCompiler { clear_set.extend(parent_clear_set); } } - first_pass.insert(node, (tenative_sets, clear_set)); + first_pass.insert(*node, (tenative_sets, clear_set)); } // Second pass - assign buffers @@ -112,22 +113,23 @@ impl Compiler for StorageBufferCompiler { // Loop through nodes in graph let mut buffers = vec![]; let mut buffer_map = HashMap::new(); - for node in toposort(&graph.graph, None).unwrap() { + let mut used = HashSet::::new(); + for node in &toposort { if graph.no_delete.contains(&node) { continue; } let Some(Ok(wrapper)) = graph .graph - .node_weight(node) + .node_weight(*node) .unwrap() .custom("metal") .map(|e| e.downcast::()) else { continue; }; - buffer_map.insert(node, (vec![], vec![])); + buffer_map.insert(*node, (vec![], vec![])); let input_shapes = graph - .get_sources(node) + .get_sources(*node) .into_iter() .map(|(_, _, i)| i) .collect::>(); @@ -137,6 +139,7 @@ impl Compiler for StorageBufferCompiler { if let Some((buffer_index, source_node, _)) = first_pass[&node] .1 .iter() + .filter(|i| !used.contains(i)) .filter(|i| available_buffers.contains_key(i)) .flat_map(|i| { available_buffers[i] @@ -151,9 +154,7 @@ impl Compiler for StorageBufferCompiler { let buffer = buffer_map.get(&source_node).unwrap().0[buffer_index]; buffer_map.get_mut(&node).unwrap().0.push(buffer); // Remove this buffer from first_pass so it can't be used again - for (_, v) in &mut first_pass { - v.1.retain(|i| *i != source_node); - } + used.insert(source_node); } else { // Allocate new buffer buffer_map.get_mut(&node).unwrap().0.push(buffers.len()); @@ -166,6 +167,7 @@ impl Compiler for StorageBufferCompiler { if let Some((buffer_index, source_node, _)) = first_pass[&node] .1 .iter() + .filter(|i| !used.contains(i)) .filter(|i| available_buffers.contains_key(i)) .flat_map(|i| { available_buffers[i] @@ -179,9 +181,7 @@ impl Compiler for StorageBufferCompiler { { let buffer = buffer_map.get(&source_node).unwrap().1[buffer_index]; buffer_map.get_mut(&node).unwrap().1.push(buffer); - for (_, v) in &mut first_pass { - v.1.retain(|i| *i != source_node); - } + used.insert(source_node); } else { // Allocate new buffer buffer_map.get_mut(&node).unwrap().1.push(buffers.len()); @@ -283,7 +283,6 @@ impl Operator for AllocateMetalBuffers { for (size, buffer) in self.buffer_sizes.iter().zip(buffers) { let size = size.exec(dyn_map).unwrap() as u64; if buffer.length() != size { - // println!("reallocing {}", size); buffer.set_purgeable_state(metal_rs::MTLPurgeableState::Empty); *buffer = self .dev