Added symbolic slicing
Joe Fioti authored and Joe Fioti committed Dec 26, 2023
1 parent 9e3bea8 commit 422fd32
Showing 9 changed files with 141 additions and 55 deletions.
5 changes: 1 addition & 4 deletions examples/llama/loader.rs
@@ -17,10 +17,7 @@ impl DfdxDeferredLoader {

 impl Loader for DfdxDeferredLoader {
     fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) {
-        let mut serializer = Serializer::default();
-        model.serialize(&mut serializer);
-
-        for (s, n) in serializer.state {
+        for (s, n) in state_dict(model) {
             let Some(n_elements) = graph
                 .graph
                 .edges_directed(n, petgraph::Direction::Outgoing)
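Note: the new state_dict helper itself is not shown in this commit. Judging only from the Serializer code it replaces, a minimal sketch might look like the following (hypothetical; the real helper's signature and return type may differ):

    // Sketch only: wraps the removed Serializer logic above.
    // Assumes Serializer exposes `state: HashMap<String, NodeIndex>`.
    use std::collections::HashMap;
    use petgraph::stable_graph::NodeIndex;

    fn state_dict<M: SerializeModule>(model: &M) -> HashMap<String, NodeIndex> {
        let mut serializer = Serializer::default();
        model.serialize(&mut serializer);
        serializer.state
    }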
6 changes: 3 additions & 3 deletions examples/llama/main.rs
@@ -45,8 +45,8 @@ fn main() {
     input.insert(0, 1); // Start token
 
     println!("Creating Graphs...");
-    let mut cx1 = Graph::new();
-    let mut cx2 = Graph::new();
+    let mut cx1 = Graph::new(); // Prompt processing graph
+    let mut cx2 = Graph::new(); // Token generation graph
     let model = Model::initialize(&mut cx1);
     let inp = cx1.named_tensor::<(Const<1>, Dyn<'s'>)>("Input").set_dyn(
         input.iter().map(|i| *i as f32).collect::<Vec<f32>>(),
@@ -133,7 +133,7 @@ fn main() {
     transfer_weights(&model, &mut cx1, &kv_model, &mut cx2);
 
     loop {
-        single_inp.set_dyn(vec![*input.last().unwrap() as f32], vec![1, 1]);
+        single_inp.set(vec![*input.last().unwrap() as f32]);
         cx2.set_dyn_dim('p', input.len() - 1);
         cx2.set_dyn_dim('t', input.len());

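Why plain set works here now: single_inp feeds exactly one token per step, so its shape can be fully static and set infers it from the tensor's type. A sketch under that assumption (the declaration is hypothetical, not shown in this commit):

    // Assumes single_inp is declared with a static 1x1 shape, e.g.:
    let single_inp = cx2.named_tensor::<R2<1, 1>>("Single input");
    single_inp.set(vec![0.0]); // shape comes from the type; no dims vec as set_dyn required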
10 changes: 7 additions & 3 deletions examples/llama/model.rs
@@ -5,7 +5,7 @@ use half::f16;
 use luminal::{
     nn::{embedding::Embedding, norm::RMSNorm},
     prelude::*,
-    shape::symbolic::BigExpression,
+    shape::symbolic::{BigExpression, Expression},
 };
 
 // Full LLaMa model implementation, heavily based off of https://github.com/coreylowman/llama-dfdx/blob/main/src/modeling.rs
@@ -113,8 +113,12 @@ impl<const HEAD_DIM: usize, const HEAD_DIM_OVER_2: usize>
     fn rotate_half<Batch: Dimension, NumHeads: Dimension, Seq: Dimension>(
         x: GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)>,
     ) -> GraphTensor<(Batch, NumHeads, Seq, Const<HEAD_DIM>)> {
-        let x1 = x.slice((.., .., .., ..HEAD_DIM_OVER_2)).contiguous();
-        let x2 = x.slice((.., .., .., HEAD_DIM_OVER_2..)).contiguous();
+        let x1 = x
+            .slice((.., .., .., ..Expression::from(HEAD_DIM_OVER_2)))
+            .contiguous();
+        let x2 = x
+            .slice((.., .., .., Expression::from(HEAD_DIM_OVER_2)..))
+            .contiguous();
         (-x2).concat_along::<(Batch, NumHeads, Seq, Const<HEAD_DIM>), Axis<3>, _>(x1)
     }
 }
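The point of the commit: slice bounds are now symbolic Expressions rather than plain usize, so a bound can reference a runtime dimension instead of a compile-time constant. A sketch of what that enables, not taken from this commit (the tensor and the 'p' dim are hypothetical; assumes Expression can be built from a dim char, in line with the Expression::from(16) - 'K' arithmetic in the tests below):

    // Trim a cache tensor up to the dynamic previous-sequence length 'p'.
    let cache = cx.named_tensor::<(Const<1>, Dyn<'t'>)>("KV cache");
    let trimmed = cache
        .slice((.., ..Expression::from('p'))) // symbolic upper bound
        .contiguous();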
15 changes: 12 additions & 3 deletions src/compilers/metal/fp16/tests.rs
@@ -895,7 +895,11 @@ fn test_slice() {
     let data = random_vec(256);
     let mut cx = Graph::new();
     let a = cx.tensor::<R1<256>>().set(data.clone());
-    let c: GraphTensor<R1<20>> = a.slice((..20,)).realize().contiguous().retrieve();
+    let c: GraphTensor<R1<20>> = a
+        .slice((..Expression::from(20),))
+        .realize()
+        .contiguous()
+        .retrieve();
 
     cx.compile(MetalFp16Compiler::default());
     cx.execute();
@@ -947,7 +951,8 @@ fn test_pad_contig() {
         .pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')])
         .contiguous()
         .retrieve();
-    let c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = (a.slice((.., ..k)).realize() / 1.0).retrieve();
+    let c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
+        (a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve();
 
     cx.compile(MetalFp16Compiler::default());
     cx.execute();
@@ -963,7 +968,11 @@ fn test_movement() {
     let mut cx = Graph::new();
     let a = cx.tensor::<R1<32>>().set(data.clone());
     let b: GraphTensor<R1<42>> = a.pad(&[(0, 10)]).contiguous().retrieve();
-    let c: GraphTensor<R1<25>> = b.slice((..25,)).realize().contiguous().retrieve();
+    let c: GraphTensor<R1<25>> = b
+        .slice((..Expression::from(25),))
+        .realize()
+        .contiguous()
+        .retrieve();
 
     cx.compile(MetalFp16Compiler::default());
     cx.execute();
109 changes: 87 additions & 22 deletions src/compilers/metal/storage_buffer.rs
@@ -1,15 +1,16 @@
 use std::{
     cell::UnsafeCell,
-    collections::{HashMap, HashSet},
+    collections::{BTreeMap, BTreeSet, HashMap, HashSet},
     fmt::Debug,
     sync::Arc,
 };
 
 use itertools::Itertools;
 use metal_rs::{Buffer, CommandBuffer, CommandQueue, Device, MTLResourceOptions};
 use petgraph::{
+    algo::toposort,
     stable_graph::NodeIndex,
-    visit::{Bfs, EdgeRef},
+    visit::Bfs,
     Direction::{self},
 };

@@ -23,26 +24,13 @@ pub struct StorageBufferCompiler;

 impl Compiler for StorageBufferCompiler {
     fn compile(&self, graph: &mut Graph) {
-        let is_metal: HashSet<NodeIndex> = graph
-            .graph
-            .node_indices()
-            .filter(|i| {
-                graph
-                    .graph
-                    .node_weight(*i)
-                    .unwrap()
-                    .custom("metal")
-                    .is_some()
-            })
-            .collect();
-        // First pass
+        // First pass - get clear sets for each node
         #[allow(clippy::type_complexity)]
         let mut first_pass: HashMap<
             NodeIndex,
-            (Vec<(NodeIndex, Vec<NodeIndex>)>, Vec<NodeIndex>),
+            (BTreeMap<NodeIndex, Vec<NodeIndex>>, BTreeSet<NodeIndex>),
         > = HashMap::new();
-        // Loop through starting nodes in graph
-        for node in graph
+        let starting_nodes = graph
             .graph
             .node_indices()
             .filter(|n| {
@@ -52,13 +40,90 @@ impl Compiler for StorageBufferCompiler {
                 .count()
                     == 0
             })
-            .collect_vec()
-        {
+            .collect_vec();
+        // Loop through starting nodes in graph
+        for node in &starting_nodes {
             // Breadth first search from starting nodes
-            let mut bfs = Bfs::new(&graph.graph, node);
+            let mut bfs = Bfs::new(&graph.graph, *node);
             while let Some(node) = bfs.next(&graph.graph) {
-                todo!();
+                // Run through parents to build the new tentative set and clear set
+                let (mut tentative_set, mut clear_set) = (BTreeMap::default(), BTreeSet::default());
+                for parent in graph.graph.neighbors_directed(node, Direction::Incoming) {
+                    if let Some((parent_tentative_set, parent_clear_set)) = first_pass.get(&parent) {
+                        let new_tentative_set = parent_tentative_set
+                            .iter()
+                            .map(|(n, c)| {
+                                let mut c = c.clone();
+                                c.retain(|n| *n != parent);
+                                (*n, c)
+                            })
+                            .collect::<BTreeMap<_, _>>();
+                        tentative_set.extend(new_tentative_set);
+                        clear_set.extend(
+                            tentative_set
+                                .iter()
+                                .filter(|(_, v)| v.is_empty())
+                                .map(|(n, _)| *n),
+                        );
+                        tentative_set.retain(|_, v| !v.is_empty());
+                        clear_set.extend(parent_clear_set);
+                    }
+                }
+                first_pass.insert(node, (tentative_set, clear_set));
             }
         }

+        // Second pass - assign buffers
+        let available_buffers = graph
+            .graph
+            .node_indices()
+            .map(|n| {
+                let input_shapes = graph
+                    .get_sources(n)
+                    .into_iter()
+                    .map(|(_, _, i)| i)
+                    .collect::<Vec<_>>();
+                let output_buffers = graph
+                    .graph
+                    .node_weight(n)
+                    .unwrap()
+                    .custom("metal")
+                    .unwrap()
+                    .downcast_ref::<MetalKernelWrapper>()
+                    .unwrap()
+                    .0
+                    .output_buffer_sizes(&input_shapes);
+                (n, output_buffers)
+            })
+            .collect::<HashMap<_, _>>();
+        // Loop through nodes in topological order
+        for node in toposort(&graph.graph, None).unwrap() {
+            let Some(Some(wrapper)) = graph
+                .graph
+                .node_weight(node)
+                .unwrap()
+                .custom("metal")
+                .map(|e| e.downcast_ref::<MetalKernelWrapper>().cloned())
+            else {
+                continue;
+            };
+            let input_shapes = graph
+                .get_sources(node)
+                .into_iter()
+                .map(|(_, _, i)| i)
+                .collect::<Vec<_>>();
+            // Assign output buffers
+            for required_buffer in wrapper.0.output_buffer_sizes(&input_shapes) {
+                // Find an applicable buffer
+                if let Some((source_node, applicable_buffer)) = first_pass[&node]
+                    .1
+                    .iter()
+                    .flat_map(|i| available_buffers[i].iter().cloned().map(|b| (*i, b)))
+                    .find(|(_, size)| *size == required_buffer)
+                {}
+            }
+            // Assign intermediate buffers
+            for required_buffer in wrapper.0.intermediate_buffer_sizes(&input_shapes) {}
+        }
     }
 }
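What the first pass is building: for each node, a "clear set" of upstream nodes whose outputs have been fully consumed by the time that node runs, so their buffers are candidates for reuse; the second pass then matches required buffer sizes against buffers of cleared nodes (the actual assignment bodies are still stubs in this commit). Below is a self-contained toy sketch of the same bookkeeping over a single linear execution order. It is simplified relative to the per-path BFS version above, and all names are hypothetical:

    use std::collections::{BTreeMap, BTreeSet};

    /// For each node in `order`, compute which earlier nodes are "clear":
    /// all of their consumers ran strictly before it, so their output
    /// buffers are free for reuse. Every node must appear as a key in
    /// `consumers` (an empty Vec means no consumers).
    fn clear_sets(
        order: &[usize],
        consumers: &BTreeMap<usize, Vec<usize>>,
    ) -> BTreeMap<usize, BTreeSet<usize>> {
        // producer -> consumers that have not run yet
        let mut pending: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
        let mut clear: BTreeSet<usize> = BTreeSet::new();
        let mut clear_at: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
        for &node in order {
            // Snapshot first: a node may not reuse the buffers it is reading.
            clear_at.insert(node, clear.clone());
            // Running `node` satisfies one pending read on each of its producers.
            for (producer, waiting) in pending.iter_mut() {
                waiting.remove(&node);
                if waiting.is_empty() {
                    clear.insert(*producer);
                }
            }
            pending.retain(|_, waiting| !waiting.is_empty());
            pending.insert(node, consumers[&node].iter().copied().collect());
        }
        clear_at
    }

For a chain 0 -> 1 -> 2 (consumers: {0: [1], 1: [2], 2: []}), clear_sets gives node 2 the clear set {0}: by the time 2 runs, node 0's only consumer has finished, so 0's buffer could hold 2's output if the sizes match.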
7 changes: 5 additions & 2 deletions src/core/op.rs
@@ -558,7 +558,10 @@ impl Operator for MaxReduce {

 #[cfg(test)]
 mod tests {
-    use crate::{prelude::*, tests::assert_close};
+    use crate::{
+        prelude::{symbolic::Expression, *},
+        tests::assert_close,
+    };
     use dfdx::prelude::*;
     use itertools::Itertools;

@@ -617,7 +620,7 @@ mod tests {
     let mut cx = Graph::new();
     let a = cx.tensor::<R2<2, 3>>();
     a.set(vec![1., 2., 3., 1., 2., 3.]);
-    let b = a.slice((1.., ..));
+    let b = a.slice((Expression::from(1).., ..));
     b.retrieve();
     cx.execute();

30 changes: 15 additions & 15 deletions src/core/shape/slice.rs
@@ -68,7 +68,7 @@ impl SliceOfShape<R0> for () {
     }
 }
 
-impl<A: Dimension, R: RangeBounds<usize> + RangeToDim<A>> SliceOfShape<(A,)> for (R,) {
+impl<A: Dimension, R: RangeBounds<Expression> + RangeToDim<A>> SliceOfShape<(A,)> for (R,) {
     type OutputShape = (R::Dimension,);
     fn to_range_vec(&self) -> Vec<(Expression, Expression)> {
         vec![(
@@ -81,8 +81,8 @@
 impl<
         A: Dimension,
         B: Dimension,
-        R1: RangeBounds<usize> + RangeToDim<A>,
-        R2: RangeBounds<usize> + RangeToDim<B>,
+        R1: RangeBounds<Expression> + RangeToDim<A>,
+        R2: RangeBounds<Expression> + RangeToDim<B>,
     > SliceOfShape<(A, B)> for (R1, R2)
 {
     type OutputShape = (R1::Dimension, R2::Dimension);
@@ -104,9 +104,9 @@ impl<
         A: Dimension,
         B: Dimension,
         C: Dimension,
-        R1: RangeBounds<usize> + RangeToDim<A>,
-        R2: RangeBounds<usize> + RangeToDim<B>,
-        R3: RangeBounds<usize> + RangeToDim<C>,
+        R1: RangeBounds<Expression> + RangeToDim<A>,
+        R2: RangeBounds<Expression> + RangeToDim<B>,
+        R3: RangeBounds<Expression> + RangeToDim<C>,
     > SliceOfShape<(A, B, C)> for (R1, R2, R3)
 {
     type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension);
@@ -133,10 +133,10 @@ impl<
         B: Dimension,
         C: Dimension,
         D: Dimension,
-        R1: RangeBounds<usize> + RangeToDim<A>,
-        R2: RangeBounds<usize> + RangeToDim<B>,
-        R3: RangeBounds<usize> + RangeToDim<C>,
-        R4: RangeBounds<usize> + RangeToDim<D>,
+        R1: RangeBounds<Expression> + RangeToDim<A>,
+        R2: RangeBounds<Expression> + RangeToDim<B>,
+        R3: RangeBounds<Expression> + RangeToDim<C>,
+        R4: RangeBounds<Expression> + RangeToDim<D>,
     > SliceOfShape<(A, B, C, D)> for (R1, R2, R3, R4)
 {
     type OutputShape = (R1::Dimension, R2::Dimension, R3::Dimension, R4::Dimension);
@@ -168,11 +168,11 @@ impl<
         C: Dimension,
         D: Dimension,
         E: Dimension,
-        R1: RangeBounds<usize> + RangeToDim<A>,
-        R2: RangeBounds<usize> + RangeToDim<B>,
-        R3: RangeBounds<usize> + RangeToDim<C>,
-        R4: RangeBounds<usize> + RangeToDim<D>,
-        R5: RangeBounds<usize> + RangeToDim<E>,
+        R1: RangeBounds<Expression> + RangeToDim<A>,
+        R2: RangeBounds<Expression> + RangeToDim<B>,
+        R3: RangeBounds<Expression> + RangeToDim<C>,
+        R4: RangeBounds<Expression> + RangeToDim<D>,
+        R5: RangeBounds<Expression> + RangeToDim<E>,
     > SliceOfShape<(A, B, C, D, E)> for (R1, R2, R3, R4, R5)
 {
     type OutputShape = (
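With the bounds moved from RangeBounds<usize> to RangeBounds<Expression>, a slice tuple can mix unbounded ranges with expression-valued ones. A usage sketch (assumes luminal's prelude; the tensor is hypothetical):

    let t = cx.tensor::<R2<4, 8>>();
    // Constant bound, lifted into an Expression explicitly:
    let head = t.slice((.., ..Expression::from(3)));
    // The same mechanism accepts bounds computed from symbolic dims,
    // e.g. ..(Expression::from(16) - 'K') as in the pad test above.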
6 changes: 6 additions & 0 deletions src/core/shape/symbolic.rs
@@ -602,6 +602,12 @@ impl From<Expression> for BigExpression {
     }
 }
 
+impl From<&Expression> for Expression {
+    fn from(value: &Expression) -> Self {
+        *value
+    }
+}
+
 impl From<BigExpression> for Expression {
     fn from(value: BigExpression) -> Self {
         let mut terms = ArrayVec::new();
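Since the impl dereferences with *value, Expression must be Copy; the impl exists so that impl Into<Expression> arguments accept references as well as owned values. A sketch (the helper is hypothetical):

    fn upper_bound(e: impl Into<Expression>) -> Expression {
        e.into()
    }

    let e = Expression::from(8);
    let _ = upper_bound(&e); // resolves via From<&Expression>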
8 changes: 5 additions & 3 deletions src/hl_ops/movement.rs
@@ -207,6 +207,8 @@ mod tests {
         tensor_ops::{RealizeTo, TryConcatAlong},
     };
 
+    use crate::prelude::symbolic::Expression;
+
     crate::test_imports!();

#[test]
@@ -317,7 +319,7 @@
     let mut cx = Graph::new();
     let a = cx.tensor::<R2<3, 2>>();
     a.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
-    let b = a.slice((.., ..1)).realize::<R2<3, 1>>();
+    let b = a.slice((.., ..Expression::from(1))).realize::<R2<3, 1>>();
     b.retrieve();
     cx.execute();

@@ -459,8 +461,8 @@
     let mut cx = Graph::new();
     let a = cx.tensor::<R2<3, 2>>();
     a.set(vec![1.4325, 2.492428, 3.127365, 33.2834, 4.18734, 23.854]);
-    let x1 = a.slice((.., ..1)).contiguous();
-    let x2 = a.slice((.., 1..)).contiguous();
+    let x1 = a.slice((.., ..Expression::from(1))).contiguous();
+    let x2 = a.slice((.., Expression::from(1)..)).contiguous();
     let c = (-x2).concat_along::<R2<3, 2>, LAxis<1>, _>(x1);
     c.retrieve();
     cx.execute();