Skip to content

Commit

Permalink
Optimized storage compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 1, 2024
1 parent 5262e32 commit 53dc4dd
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ impl Operator for MetalBatchMatmul2D {
}
}

#[derive(Default)]
#[derive(Default, Debug)]
pub struct MetalMatMulCompiler;

impl Compiler for MetalMatMulCompiler {
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mean_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl Operator for MetalMeanReduce {
}

/// Replace the mean reduce pattern with a special kernel. This is meant to be run **after** the FakeSumReduceCompiler.
#[derive(Default)]
#[derive(Default, Debug)]
pub struct MeanReduceCompiler;

impl Compiler for MeanReduceCompiler {
Expand Down
10 changes: 4 additions & 6 deletions src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use half::f16;

use crate::prelude::TimedCompiler;

mod matmul;
mod mean_reduce;
mod other;
Expand All @@ -10,10 +8,10 @@ mod rms_norm;
pub type MetalFp16Compiler = (
super::prim::PrimitiveCompiler<f16>,
(
TimedCompiler<super::binary::MetalSubtractionCompiler<f16>>,
TimedCompiler<super::binary::MetalEqualCompiler<f16>>,
TimedCompiler<super::other::ARangeCompiler<f16>>,
TimedCompiler<super::binary::MetalGatherCompiler<f16>>,
super::binary::MetalSubtractionCompiler<f16>,
super::binary::MetalEqualCompiler<f16>,
super::other::ARangeCompiler<f16>,
super::binary::MetalGatherCompiler<f16>,
),
other::MetalExpCompiler,
matmul::MetalMatMulCompiler,
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ impl Operator for MetalRMSNorm {
}

/// Replace the RMS norm pattern with a special kernel. This is meant to be run **after** the FakeSumReduceCompiler.
#[derive(Default)]
#[derive(Default, Debug)]
pub struct RMSNormCompiler;

impl Compiler for RMSNormCompiler {
Expand Down
31 changes: 15 additions & 16 deletions src/compilers/metal/storage_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
cell::UnsafeCell,
collections::{BTreeMap, BTreeSet, HashMap},
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
fmt::Debug,
sync::Arc,
};
Expand Down Expand Up @@ -35,13 +35,14 @@ impl Compiler for StorageBufferCompiler {
BTreeSet<NodeIndex>,
),
> = HashMap::new();
// Loop through starting nodes in graph
for node in toposort(&graph.graph, None).unwrap() {
let toposort = toposort(&graph.graph, None).unwrap();
// Loop through nodes in graph
for node in &toposort {
// Run through parents to build the new tentative sets and clear set
let (mut tenative_sets, mut clear_set) = (BTreeMap::default(), BTreeSet::default());
for parent in graph
.graph
.edges_directed(node, Direction::Incoming)
.edges_directed(*node, Direction::Incoming)
.filter(|e| !e.weight().is_schedule())
.map(|e| e.source())
{
Expand Down Expand Up @@ -77,7 +78,7 @@ impl Compiler for StorageBufferCompiler {
clear_set.extend(parent_clear_set);
}
}
first_pass.insert(node, (tenative_sets, clear_set));
first_pass.insert(*node, (tenative_sets, clear_set));
}

// Second pass - assign buffers
Expand Down Expand Up @@ -112,22 +113,23 @@ impl Compiler for StorageBufferCompiler {
// Loop through nodes in graph
let mut buffers = vec![];
let mut buffer_map = HashMap::new();
for node in toposort(&graph.graph, None).unwrap() {
let mut used = HashSet::<NodeIndex>::new();
for node in &toposort {
if graph.no_delete.contains(&node) {
continue;
}
let Some(Ok(wrapper)) = graph
.graph
.node_weight(node)
.node_weight(*node)
.unwrap()
.custom("metal")
.map(|e| e.downcast::<MetalKernelWrapper>())
else {
continue;
};
buffer_map.insert(node, (vec![], vec![]));
buffer_map.insert(*node, (vec![], vec![]));
let input_shapes = graph
.get_sources(node)
.get_sources(*node)
.into_iter()
.map(|(_, _, i)| i)
.collect::<Vec<_>>();
Expand All @@ -137,6 +139,7 @@ impl Compiler for StorageBufferCompiler {
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
available_buffers[i]
Expand All @@ -151,9 +154,7 @@ impl Compiler for StorageBufferCompiler {
let buffer = buffer_map.get(&source_node).unwrap().0[buffer_index];
buffer_map.get_mut(&node).unwrap().0.push(buffer);
// Mark the source node as used so its buffer can't be handed out again
for (_, v) in &mut first_pass {
v.1.retain(|i| *i != source_node);
}
used.insert(source_node);
} else {
// Allocate new buffer
buffer_map.get_mut(&node).unwrap().0.push(buffers.len());
Expand All @@ -166,6 +167,7 @@ impl Compiler for StorageBufferCompiler {
if let Some((buffer_index, source_node, _)) = first_pass[&node]
.1
.iter()
.filter(|i| !used.contains(i))
.filter(|i| available_buffers.contains_key(i))
.flat_map(|i| {
available_buffers[i]
Expand All @@ -179,9 +181,7 @@ impl Compiler for StorageBufferCompiler {
{
let buffer = buffer_map.get(&source_node).unwrap().1[buffer_index];
buffer_map.get_mut(&node).unwrap().1.push(buffer);
for (_, v) in &mut first_pass {
v.1.retain(|i| *i != source_node);
}
used.insert(source_node);
} else {
// Allocate new buffer
buffer_map.get_mut(&node).unwrap().1.push(buffers.len());
Expand Down Expand Up @@ -283,7 +283,6 @@ impl Operator for AllocateMetalBuffers {
for (size, buffer) in self.buffer_sizes.iter().zip(buffers) {
let size = size.exec(dyn_map).unwrap() as u64;
if buffer.length() != size {
// println!("reallocing {}", size);
buffer.set_purgeable_state(metal_rs::MTLPurgeableState::Empty);
*buffer = self
.dev
Expand Down

0 comments on commit 53dc4dd

Please sign in to comment.