diff --git a/crates/luminal_cuda/Cargo.toml b/crates/luminal_cuda/Cargo.toml index 964f1fad..3a0d263f 100644 --- a/crates/luminal_cuda/Cargo.toml +++ b/crates/luminal_cuda/Cargo.toml @@ -9,17 +9,19 @@ license = "MIT OR Apache-2.0" [dependencies] luminal = { path = "../.." } -cudarc = { version="0.11.1", features = [ +cudarc = { version = "0.11.1", features = [ "f16", "cuda-version-from-build-system", -]} +] } itertools = "0.12.1" rustc-hash = "1.1.0" num-traits = "0.2.18" regex = "1.10.4" +indicatif = "0.17.8" [dev-dependencies] dfdx = { version = "0.13", features = ["f16"] } rand = "0.8.5" paste = "1.0.14" -luminal_nn = {path="../../crates/luminal_nn"} +luminal_nn = { path = "../../crates/luminal_nn" } +candle-core = "0.5.0" diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs index 166ee6a1..d34e0ce8 100644 --- a/crates/luminal_cuda/src/binary.rs +++ b/crates/luminal_cuda/src/binary.rs @@ -376,7 +376,7 @@ impl Compiler for GatherCompiler { .as_data() .unwrap() .2; - let embed_dim = emb_shape.shape().last().unwrap().to_usize().unwrap(); + let embed_dim = emb_shape.dims().last().unwrap().to_usize().unwrap(); let index_shape = graph .edges_connecting(s.get(&indexes), s.get(&ind_copy)) .next() @@ -402,27 +402,21 @@ mod tests { use super::*; luminal::test_imports!(); - type TR0 = GraphTensor; - type TR1 = GraphTensor>; - type TR2 = GraphTensor>; - #[test] fn test_gather_compiler_r0() { const CLASSES: usize = 2; const TARGET: usize = 1; let mut cx = Graph::new(); - let mut input: TR0 = cx.tensor(); - let embedder: TR2 = cx.tensor(); + let mut input = cx.tensor(()); + let embedder = cx.tensor((CLASSES, TARGET)); - let input_one_hot: TR1 = input + let input_one_hot = input .graph() - .arange::>() - .equals(input.expand()); - let input_embedding: TR1 = (input_one_hot.expand::, _>() - * embedder) - .sum_reduce::<_, LAxis<0>>(); - let mut loss: TR0 = input_embedding.sum_reduce(); + .arange(CLASSES) + .equals(input.expand(0, CLASSES)); + let input_embedding = (input_one_hot.expand(1, TARGET) * embedder).sum_reduce(0); + let mut loss = input_embedding.sum_reduce(0); let mut weights = vec![embedder.id]; cx.compile( @@ -437,18 +431,17 @@ mod tests { const TARGET: usize = 1; let mut cx = Graph::new(); - let mut input: TR1<1> = cx.tensor(); - let embedder: TR2 = cx.tensor(); + let mut input = cx.tensor(1); + let embedder = cx.tensor((CLASSES, TARGET)); - let input_one_hot: TR2<1, CLASSES> = input + let input_one_hot = input .graph() - .arange::>() - .expand::, _>() - .equals(input.expand()); - let input_embedding: TR2<1, TARGET> = (input_one_hot.expand::, _>() - * embedder.expand()) - .sum_reduce::<_, LAxis<1>>(); - let mut loss: TR0 = input_embedding.sum_reduce(); + .arange(CLASSES) + .expand(0, 1) + .equals(input.expand(1, CLASSES)); + let input_embedding = + (input_one_hot.expand(2, TARGET) * embedder.expand(0, 1)).sum_reduce(1); + let mut loss = input_embedding.sum_reduce(0); let mut weights = vec![embedder.id]; cx.compile( diff --git a/crates/luminal_cuda/src/elementwise_fusion.rs b/crates/luminal_cuda/src/elementwise_fusion.rs index 7a818d13..32962c4d 100644 --- a/crates/luminal_cuda/src/elementwise_fusion.rs +++ b/crates/luminal_cuda/src/elementwise_fusion.rs @@ -1,4 +1,5 @@ use cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; +use indicatif::{ProgressBar, ProgressStyle}; use regex::Regex; use rustc_hash::{FxHashMap, FxHashSet}; use std::{any::Any, fmt::Debug, iter::once, marker::PhantomData, mem::size_of, sync::Arc}; @@ -78,6 +79,7 @@ impl Compiler for ElementwiseFusionCompiler { } let mut intermediate_regexes = FxHashMap::default(); let mut input_regexes = FxHashMap::default(); + let mut n_fused_ops = 0; while matched { matched = false; for edge in graph.edge_indices().collect::>() { @@ -113,7 +115,7 @@ impl Compiler for ElementwiseFusionCompiler { let mut subexpressions_b = graph .try_get_op::>(b) .map(|o| o.subexpressions.clone()) - .unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(&[]))]); + .unwrap_or_else(|| vec![(expression_b.clone(), ShapeTracker::new(()))]); let a_to_b_indexes = graph .edges_connecting(a, b) .map(|e| e.weight().as_data().unwrap().0 as usize) @@ -133,7 +135,7 @@ impl Compiler for ElementwiseFusionCompiler { let mut subexpressions_a = graph .try_get_op::>(a) .map(|o| o.subexpressions.clone()) - .unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(&[]))]); + .unwrap_or_else(|| vec![(expression_a.clone(), ShapeTracker::new(()))]); subexpressions_a.last_mut().unwrap().1 = connecting_shape; // Re-reference b intermediates for i in (0..subexpressions_b.len()).rev() { @@ -263,11 +265,25 @@ impl Compiler for ElementwiseFusionCompiler { remap(a, new_op, &mut ids, graph); } remap(b, new_op, &mut ids, graph); + n_fused_ops += 1; } } // Compile all the kernels we placed let type_name = T::type_name(); let intermediate_match = Regex::new(r"intermediate(\d+)([^0-9]|$)").unwrap(); + let mut bar = None; + if debug() { + println!("Fusing {n_fused_ops} ops into {} ops...", fused_ops.len()); + let b = ProgressBar::new(fused_ops.len() as u64); + b.set_style( + ProgressStyle::with_template( + "[{elapsed_precise}] [{bar:40.bright.blue/white}] {pos:>7}/{len:7}", + ) + .unwrap() + .progress_chars("##-"), + ); + bar = Some(b); + }; for fused_op in fused_ops { let inputs = graph .edges_directed(fused_op, Direction::Incoming) @@ -337,7 +353,7 @@ impl Compiler for ElementwiseFusionCompiler { s.iter() .rev() .take(s.len() - 1) - .fold(BigExpression::from('z'), |acc, inp| { + .fold(Expression::from('z'), |acc, inp| { inp.index_expression().substitute('z', acc) }) }) @@ -394,7 +410,7 @@ impl Compiler for ElementwiseFusionCompiler { .iter() .rev() .fold( - (BigExpression::from(true), BigExpression::from('z')), + (Expression::from(true), Expression::from('z')), |(_, ind_acc), inp| { ( inp.valid_expression().substitute('z', ind_acc.clone()), @@ -437,6 +453,13 @@ extern \"C\" __global__ void kernel({} {type_name}* out, const int n_elements{re ); op.kernel = Some(compile_and_load_kernel(kernel, &device)); op.dyn_chars = dyn_chars; + + if let Some(bar) = &bar { + bar.inc(1); + } + } + if let Some(bar) = bar { + bar.finish(); } } } @@ -448,7 +471,7 @@ pub struct FusedElementwiseOp { dyn_chars: Vec, subexpressions: Vec<(String, ShapeTracker)>, device: Arc, - output_buffer_sizes: Vec, + output_buffer_sizes: Vec, _phantom: PhantomData, } impl Debug for FusedElementwiseOp { @@ -501,7 +524,6 @@ mod tests { }; use luminal_nn::*; use rand::{rngs::StdRng, SeedableRng}; - use std::{marker::PhantomData, ops::Div}; use crate::CudaCompiler; @@ -509,7 +531,7 @@ mod tests { fn test_fusion_simple() { let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); - let inp = cx.tensor::>().set(random_vec_rng(10, &mut rng)); + let inp = cx.tensor(5).set(random_vec_rng(10, &mut rng)); let mut out = inp.exp2().cos().sqrt().retrieve(); cx.execute(); @@ -525,8 +547,8 @@ mod tests { fn test_fusion_binary() { let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); - let a = cx.tensor::>().set(random_vec_rng(10, &mut rng)); - let b = cx.tensor::>().set(random_vec_rng(10, &mut rng)); + let a = cx.tensor(5).set(random_vec_rng(10, &mut rng)); + let b = cx.tensor(5).set(random_vec_rng(10, &mut rng)); let mut out = (a.exp2() + b.cos()).retrieve(); cx.execute(); @@ -542,9 +564,9 @@ mod tests { #[test] fn test_fusion_subexpression_complex() { let mut cx = Graph::new(); - let a = cx.named_tensor::>("a").set(random_vec(10)).keep(); - let b = cx.named_tensor::>("b").set(random_vec(10)).keep(); - let d = cx.named_tensor::>("d").set(random_vec(10)).keep(); + let a = cx.named_tensor("a", 10).set(random_vec(10)).keep(); + let b = cx.named_tensor("b", 10).set(random_vec(10)).keep(); + let d = cx.named_tensor("d", 10).set(random_vec(10)).keep(); let mut out = ((a.exp2() - b.sin()).sin() * 3.4).less_than(d).retrieve(); cx.execute(); @@ -562,12 +584,11 @@ mod tests { let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); let inp = random_vec_rng(10, &mut rng); - let a = cx.named_tensor::>("a").set(inp); + let a = cx.named_tensor("a", (2, 5)).set(inp); let mut padded = a .slice((..Expression::from(1), ..)) - .realize::>() .cos() - .pad::>(((0, 1), (0, 0))) + .pad(((0, 1), (0, 0))) .exp2() .retrieve(); cx.execute(); @@ -588,7 +609,7 @@ mod tests { let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); let data = random_vec_rng(10, &mut rng); - let a = cx.tensor::>().set(data); + let a = cx.tensor((2, 5)).set(data); let mut out = (a.sqrt().exp() + a.sqrt().sin()).retrieve(); cx.execute(); let unopt_out = out.data(); @@ -605,14 +626,10 @@ mod tests { let mut cx = Graph::new(); const SEQ: usize = 2; const HEAD_DIM: usize = 4; - const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; - let freqs = (cx.arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32); let freqs = 1000000_f32.pow(freqs); - let pos = cx.arange::>() + BigExpression::from(0); - let mut emb = pos - .expand::<(_, Const<1>), _>() - .matmul(freqs.expand()) - .retrieve(); + let pos = cx.arange(SEQ) + Expression::from(0); + let mut emb = pos.expand(1, 1).matmul(freqs.expand(0, 1)).retrieve(); cx.execute(); let unopt_out = emb.data(); @@ -631,27 +648,25 @@ mod tests { const HEAD_DIM: usize = 4; const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; let a = cx - .tensor::>() + .tensor((SEQ, HEAD_DIM)) .set(random_vec_rng(SEQ * HEAD_DIM, &mut rng)) .keep(); let b = cx - .tensor::>() - .set(random_vec_rng(SEQ * HEAD_DIM_OVER_2, &mut rng)) + .tensor((SEQ, HEAD_DIM_OVER_2, 1)) + .set(random_vec_rng(SEQ * (HEAD_DIM) / 2, &mut rng)) .keep(); // Split input into evens and odds - let split = a.reshape::>(); - let x0: GraphTensor> = - split.slice((.., .., ..Expression::from(1))).realize(); - let x1: GraphTensor> = - split.slice((.., .., Expression::from(1)..)).realize(); + let split = a.reshape((SEQ, HEAD_DIM / 2, 2)); + let x0 = split.slice((.., .., ..Expression::from(1))); + let x1 = split.slice((.., .., Expression::from(1)..)); let x0_out = x0 * b - x1 * b.cos(); let x1_out = x0 + x1; // Combine back into output - let mut out: GraphTensor> = x0_out - .concat_along::, Axis<2>, _>(x1_out) - .reshape() + let mut out = x0_out + .concat_along(x1_out, 2) + .reshape((SEQ, HEAD_DIM)) .retrieve(); cx.execute(); @@ -671,34 +686,27 @@ mod tests { const N_HEADS: usize = 8; const SEQ: usize = 2; const HEAD_DIM: usize = 4; - const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; let a = cx - .named_tensor::>("a") + .named_tensor("a", (BATCH, N_HEADS, SEQ, HEAD_DIM)) .set(random_vec_rng(BATCH * N_HEADS * SEQ * HEAD_DIM, &mut rng)) .keep(); - let freqs = (cx.arange::>() * 2.0) / (HEAD_DIM as f32); + let freqs = (cx.arange(HEAD_DIM / 2) * 2.0) / (HEAD_DIM as f32); let freqs = 1000000_f32.pow(freqs); - let pos = cx.arange::>() + BigExpression::from(0); - let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + let pos = cx.arange(SEQ) + 0; + let emb = pos.expand(1, 1).matmul(freqs.expand(0, SEQ)); // Split input into evens and odds - let split = a.reshape::>(); - let x0: GraphTensor> = split - .slice((.., .., .., .., ..Expression::from(1))) - .contiguous() - .realize(); - let x1: GraphTensor> = split - .slice((.., .., .., .., Expression::from(1)..)) - .contiguous() - .realize(); + let split = a.reshape((BATCH, N_HEADS, SEQ, HEAD_DIM / 2, 2)); + let x0 = split.slice((.., .., .., .., ..1)).contiguous(); + let x1 = split.slice((.., .., .., .., 1..)).contiguous(); // Apply sin and cos embeddings - let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); - let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape); + let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape); // Combine back into output - let mut out: GraphTensor> = x0_out - .concat_along::, Axis<4>, _>(x1_out) - .reshape() + let mut out = x0_out + .concat_along(x1_out, 4) + .reshape((BATCH, N_HEADS, SEQ, HEAD_DIM)) .retrieve(); cx.execute(); let unopt_out = out.data(); @@ -720,176 +728,160 @@ mod tests { pub const SEQ_LEN: usize = 65; pub const N_ATTENTION_GROUPS: usize = N_HEADS / N_KV_HEADS; pub const HEAD_DIM: usize = HIDDEN_DIM / N_HEADS; - pub const HEAD_DIM_OVER_2: usize = HEAD_DIM / 2; pub const ATTN_PROJ_DIM: usize = HEAD_DIM * N_KV_HEADS; - pub struct Mlp { - pub gate_proj: PermutedLinear, - pub down_proj: PermutedLinear, - pub up_proj: PermutedLinear, - } + pub type KVCache = (GraphTensor, GraphTensor); - pub type KVCache = ( - GraphTensor<(Batch, Const, Seq, Const)>, - GraphTensor<(Batch, Const, Seq, Const)>, - ); + pub struct Mlp { + pub gate_proj: Linear, // hidden -> intermediate + pub down_proj: Linear, // intermediate -> hidden + pub up_proj: Linear, // hidden -> intermediate + } - impl Module> for Mlp - where - GraphTensor: Matmul, Output = GraphTensor>, - GraphTensor: Matmul, Output = GraphTensor>, - { - type Output = GraphTensor; + impl Module for Mlp { + type Output = GraphTensor; - fn forward(&self, input: GraphTensor) -> Self::Output { + fn forward(&self, input: GraphTensor) -> Self::Output { let gate = self.gate_proj.forward(input).swish(); let up = self.up_proj.forward(input) * gate; self.down_proj.forward(up) } } - impl InitModule for Mlp { - fn initialize(cx: &mut Graph) -> Self { + + impl Mlp { + pub fn new(hidden: usize, intermediate: usize, cx: &mut Graph) -> Self { Self { - gate_proj: InitModule::initialize(cx), - up_proj: InitModule::initialize(cx), - down_proj: InitModule::initialize(cx), + gate_proj: Linear::new_permuted(hidden, intermediate, false, cx), + down_proj: Linear::new_permuted(intermediate, hidden, false, cx), + up_proj: Linear::new_permuted(hidden, intermediate, false, cx), } } } - fn apply_rotary_embeddings_ggml( - input: GraphTensor<(Batch, Const, Seq, Const)>, - prev_seq: BigExpression, - ) -> GraphTensor<(Batch, Const, Seq, Const)> { + + impl SerializeModule for Mlp { + fn serialize(&self, s: &mut Serializer) { + s.module("ffn_gate", &self.gate_proj); + s.module("ffn_up", &self.up_proj); + s.module("ffn_down", &self.down_proj); + } + } + + fn apply_rotary_embeddings_ggml(input: GraphTensor, prev_seq: Expression) -> GraphTensor { + assert_eq!(input.shape.len(), 4); // batch, n_heads, seq, head_dim + let (batch, n_heads, seq, head_dim) = input.dims4(); // Get freqs let freqs = - (input.graph().arange::>() * 2.0) / (HEAD_DIM as f32); - let freqs = 1000000_f32.pow(freqs); - let pos = input.graph().arange::() + prev_seq; - let emb = pos.expand::<(_, Const<1>), _>().matmul(freqs.expand()); + (input.graph().arange(head_dim / 2) * 2.0) / (head_dim.to_usize().unwrap() as f32); + let freqs = 500_000_f32.pow(freqs); + let pos = input.graph().arange(seq) + prev_seq; + let emb = pos.expand(1, 1).matmul(freqs.expand(0, seq)); // Split input into evens and odds - let split = - input.reshape::<(Batch, Const, Seq, Const, Const<2>)>(); - let x0: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = - split - .slice((.., .., .., .., ..Expression::from(1))) - .contiguous() - .realize(); - let x1: GraphTensor<(Batch, Const, Seq, Const, Const<1>)> = - split - .slice((.., .., .., .., Expression::from(1)..)) - .contiguous() - .realize(); + let split = input.reshape((batch, n_heads, seq, head_dim / 2, 2)); + let x0 = split.slice((.., .., .., .., ..1)); + let x1 = split.slice((.., .., .., .., 1..)); // Apply sin and cos embeddings - let x0_out = x0 * emb.cos().expand() - x1 * emb.sin().expand(); - let x1_out = x0 * emb.sin().expand() + x1 * emb.cos().expand(); + let x0_out = x0 * emb.cos().expand_to(x0.shape) - x1 * emb.sin().expand_to(x1.shape); + let x1_out = x0 * emb.sin().expand_to(x0.shape) + x1 * emb.cos().expand_to(x1.shape); // Combine back into output - x0_out - .concat_along::<(Batch, Const, Seq, Const, Const<2>), Axis<4>, _>( - x1_out, - ) - .reshape() + x0_out.concat_along(x1_out, 4).reshape(input.shape) } + pub struct SelfAttention { - pub q_proj: GraphTensor>, - pub k_proj: GraphTensor>, - pub v_proj: GraphTensor>, - pub o_proj: GraphTensor>, + pub q_proj: GraphTensor, // Hidden -> hidden + pub k_proj: GraphTensor, // Proj dim -> hidden + pub v_proj: GraphTensor, // Proj dim -> hidden + pub o_proj: GraphTensor, // Hidden -> hidden } - impl - Module<( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - PhantomData, - )> for SelfAttention - { - type Output = ( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - ); - fn forward( - &self, - (x, (k_cache, v_cache), _): ( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - PhantomData, - ), - ) -> Self::Output { + impl Module<(GraphTensor, KVCache)> for SelfAttention { + type Output = (GraphTensor, KVCache); + fn forward(&self, (x, (k_cache, v_cache)): (GraphTensor, KVCache)) -> Self::Output { + // x: batch, seq, hidden + let (batch, seq, _) = x.dims3(); + let (_, _, prev_seq, _) = k_cache.dims4(); // Apply the Projections let queries = x - .matmul(self.q_proj.permute()) - .reshape::<(Batch, CurSeq, Const, Const)>() - .permute::<_, Axes4<0, 2, 1, 3>>(); + .matmul(self.q_proj.permute((1, 0))) + .reshape((batch, seq, N_HEADS, HEAD_DIM)) + .permute((0, 2, 1, 3)); let keys = x - .matmul(self.k_proj.permute()) - .reshape::<(Batch, CurSeq, Const, Const)>() - .permute::<_, Axes4<0, 2, 1, 3>>(); + .matmul(self.k_proj.permute((1, 0))) + .reshape((batch, seq, N_KV_HEADS, HEAD_DIM)) + .permute((0, 2, 1, 3)); let values = x - .matmul(self.v_proj.permute()) - .reshape::<(Batch, CurSeq, Const, Const)>() - .permute::<_, Axes4<0, 2, 1, 3>>(); + .matmul(self.v_proj.permute((1, 0))) + .reshape((batch, seq, N_KV_HEADS, HEAD_DIM)) + .permute((0, 2, 1, 3)); // Rotary embed queries and keys - let queries = apply_rotary_embeddings_ggml(queries, PrevSeq::size().big()); - let keys = apply_rotary_embeddings_ggml(keys, PrevSeq::size().big()); + let queries = apply_rotary_embeddings_ggml(queries, prev_seq); + let keys = apply_rotary_embeddings_ggml(keys, prev_seq); // Add KV cache - let (keys, values) = ( - k_cache.concat_along::<_, Axis<2>, _>(keys), - v_cache.concat_along::<_, Axis<2>, _>(values), - ); + let keys = k_cache.concat_along(keys, 2); + let values = v_cache.concat_along(values, 2); // Repeat the KV States for Grouped-Query Attention - let repeated_keys = keys.expand::<(_, _, Const, _, _), _>(); - let repeated_values = values.expand::<(_, _, Const, _, _), _>(); + let repeated_keys = keys.expand(2, N_ATTENTION_GROUPS); + let repeated_values = values.expand(2, N_ATTENTION_GROUPS); // Calculate attention weights let mut attention_weights = queries - .reshape::<(_, Const, Const, _, _)>() // Split query heads into groups - .matmul(repeated_keys.permute()) - .div((HEAD_DIM as f32).sqrt()); + .reshape((batch, N_KV_HEADS, N_ATTENTION_GROUPS, seq, HEAD_DIM)) // Split query heads into groups + .matmul(repeated_keys.permute((0, 1, 2, 4, 3))) + / (HEAD_DIM as f32).sqrt(); - let attention_mask = self.k_proj.graph().triu::(1) * f16::MIN.to_f32(); + let attention_mask = self.k_proj.graph().triu(seq, 1) * f16::MIN.to_f32(); attention_weights += attention_mask - .pad::<(CurSeq, TotSeq)>(((0, 0), (TotSeq::size() - CurSeq::size(), 0))) - .expand(); + .pad(((0, 0), (prev_seq, 0))) + .expand(0, batch) + .expand(1, N_KV_HEADS) + .expand(2, N_ATTENTION_GROUPS); // Calculate final outputs let output = attention_weights - .softmax::>() + .softmax(4) // Apply distribution to values .matmul(repeated_values) // Merge heads - .permute::<_, Axes5<0, 3, 1, 2, 4>>() - .reshape::<(Batch, CurSeq, Const)>(); + .permute((0, 3, 1, 2, 4)) + .reshape((batch, seq, HIDDEN_DIM)); let output = output // Apply output projection - .matmul(self.o_proj.permute()); + .matmul(self.o_proj.permute((1, 0))); (output, (keys.contiguous(), values.contiguous())) // Cache needs to be contiguous for transferring to another graph } } - impl InitModule for SelfAttention { - fn initialize(cx: &mut Graph) -> Self { + impl SelfAttention { + pub fn new(cx: &mut Graph) -> Self { Self { - q_proj: cx - .named_tensor("Q Proj") - .set(random_vec(HIDDEN_DIM * HIDDEN_DIM)), - k_proj: cx - .named_tensor("K Proj") - .set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)), - v_proj: cx - .named_tensor("V Proj") - .set(random_vec(ATTN_PROJ_DIM * HIDDEN_DIM)), - o_proj: cx - .named_tensor("O Proj") - .set(random_vec(HIDDEN_DIM * HIDDEN_DIM)), + q_proj: cx.named_tensor("Q Proj", (HIDDEN_DIM, HIDDEN_DIM)), + k_proj: cx.named_tensor("K Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)), + v_proj: cx.named_tensor("V Proj", (ATTN_PROJ_DIM, HIDDEN_DIM)), + o_proj: cx.named_tensor("O Proj", (HIDDEN_DIM, HIDDEN_DIM)), } } + + fn initialize(self) -> Self { + self.k_proj.set(random_vec( + self.k_proj.shape.n_elements().to_usize().unwrap(), + )); + self.o_proj.set(random_vec( + self.o_proj.shape.n_elements().to_usize().unwrap(), + )); + self.v_proj.set(random_vec( + self.v_proj.shape.n_elements().to_usize().unwrap(), + )); + self.q_proj.set(random_vec( + self.q_proj.shape.n_elements().to_usize().unwrap(), + )); + self + } } impl SerializeModule for SelfAttention { @@ -903,35 +895,18 @@ mod tests { pub struct TransformerBlock { pub attention: SelfAttention, - pub attention_norm: LayerNorm, - pub feed_forward: Mlp, - pub feed_forward_norm: LayerNorm, + pub attention_norm: LayerNorm, + pub feed_forward: Mlp, + pub feed_forward_norm: LayerNorm, } - impl - Module<( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - PhantomData, - )> for TransformerBlock - { - type Output = ( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - ); - fn forward( - &self, - (mut x, cache, _): ( - GraphTensor<(Batch, CurSeq, Const)>, - KVCache, - PhantomData, - ), - ) -> Self::Output { + impl Module<(GraphTensor, KVCache)> for TransformerBlock { + type Output = (GraphTensor, KVCache); + fn forward(&self, (mut x, cache): (GraphTensor, KVCache)) -> Self::Output { // Attention - let normed = self.attention_norm.forward(x); let (y, cache) = self .attention - .forward((normed, cache, PhantomData::)); + .forward((self.attention_norm.forward(x), cache)); // Residual Addition x += y; @@ -944,94 +919,85 @@ mod tests { } } - impl InitModule for TransformerBlock { - fn initialize(cx: &mut Graph) -> Self { + impl TransformerBlock { + pub fn new(cx: &mut Graph) -> Self { Self { - attention: InitModule::initialize(cx), - attention_norm: LayerNorm::init(true, false, false, 1e-5, cx), - feed_forward: InitModule::initialize(cx), - feed_forward_norm: LayerNorm::init(true, false, false, 1e-5, cx), + attention: SelfAttention::new(cx), + attention_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx), + feed_forward: Mlp::new(HIDDEN_DIM, MLP_DIM, cx), + feed_forward_norm: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx), } } + + fn initialize(mut self) -> Self { + self.attention_norm = self.attention_norm.initialize(); + self.feed_forward_norm = self.feed_forward_norm.initialize(); + self.attention = self.attention.initialize(); + self.feed_forward.down_proj = self.feed_forward.down_proj.initialize(); + self.feed_forward.up_proj = self.feed_forward.up_proj.initialize(); + self.feed_forward.gate_proj = self.feed_forward.gate_proj.initialize(); + self + } } - pub struct MistralLM { + pub struct Llama { // Transformer layers pub layers: Vec, - // Final Norm layer - pub norm: LayerNorm, + // Norm + LM head + pub head: LayerNorm, } - impl - Module<( - GraphTensor<(Batch, CurSeq, Const)>, - Vec>, - PhantomData, - )> for MistralLM - { - type Output = ( - GraphTensor<(Batch, CurSeq, Const)>, - Vec>, - ); - fn forward( - &self, - (input, cache, _): ( - GraphTensor<(Batch, CurSeq, Const)>, - Vec>, - PhantomData, - ), - ) -> Self::Output { - let mut x = input; - + impl Module<(GraphTensor, &[KVCache])> for Llama { + type Output = (GraphTensor, Vec); + fn forward(&self, (mut x, cache): (GraphTensor, &[KVCache])) -> Self::Output { // Run through layers and collect new caches let mut new_caches = vec![]; let mut new_cache; for (i, layer) in self.layers.iter().enumerate() { - (x, new_cache) = layer.forward((x, cache[i], PhantomData::)); + (x, new_cache) = layer.forward((x, cache[i])); new_caches.push(new_cache); } // Run through last norm and output projection - let normed = self.norm.forward(x); - (normed, new_caches) + (self.head.forward(x), new_caches) } } - impl InitModule for MistralLM { - fn initialize(cx: &mut Graph) -> Self { + impl Llama { + pub fn new(cx: &mut Graph) -> Self { Self { - norm: LayerNorm::init(true, false, false, 1e-5, cx), - layers: (0..NUM_LAYERS) - .map(|_| InitModule::initialize(cx)) - .collect(), + head: LayerNorm::new(HIDDEN_DIM, true, false, false, 1e-5, cx), + layers: (0..NUM_LAYERS).map(|_| TransformerBlock::new(cx)).collect(), } } + + fn initialize(mut self) -> Self { + self.head = self.head.initialize(); + self.layers = self.layers.into_iter().map(|l| l.initialize()).collect(); + self + } } let mut cx = Graph::new(); - let model = MistralLM::initialize(&mut cx); + let model = Llama::new(&mut cx).initialize(); let caches = (0..NUM_LAYERS) .map(|_| { ( - cx.tensor::<(Const<1>, Const, Dyn<'p'>, Const)>() - .set_dyn( - random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), - &[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM], - ), - cx.tensor::<(Const<1>, Const, Dyn<'p'>, Const)>() - .set_dyn( - random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), - &[1, N_KV_HEADS, SEQ_LEN, HEAD_DIM], - ), + cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn( + random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), + (1, N_KV_HEADS, SEQ_LEN, HEAD_DIM), + ), + cx.tensor((1, N_KV_HEADS, 'p', HEAD_DIM)).set_dyn( + random_vec(SEQ_LEN * N_KV_HEADS * HEAD_DIM), + (1, N_KV_HEADS, SEQ_LEN, HEAD_DIM), + ), ) }) - .collect(); + .collect::>(); let input = cx - .tensor::<(Const<1>, Dyn<'s'>, luminal::shape::Const)>() - .set_dyn(random_vec(2 * HIDDEN_DIM), &[1, 2, HIDDEN_DIM]); - let (mut out, _) = model.forward((input, caches, PhantomData::>)); + .tensor((1, 's', HIDDEN_DIM)) + .set_dyn(random_vec(2 * HIDDEN_DIM), (1, 2, HIDDEN_DIM)); + let (mut out, _) = model.forward((input, &caches)); out.retrieve(); - - cx.set_dyn_dim('t', SEQ_LEN + 2); cx.execute(); let unopt_out = out.data(); @@ -1040,6 +1006,6 @@ mod tests { cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out); cx.execute(); - assert_close_precision(&out.data(), &unopt_out, 1e-1); + assert_close_precision(&out.data(), &unopt_out, 1e-2); } } diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs index 80dbc5b0..cb83d19e 100644 --- a/crates/luminal_cuda/src/lib.rs +++ b/crates/luminal_cuda/src/lib.rs @@ -9,6 +9,9 @@ mod unary; pub use quantized::*; pub use cudarc::driver::CudaDevice; +pub use elementwise_fusion::ElementwiseFusionCompiler; +pub use other::*; +pub use prim::PrimitiveCompiler; #[cfg(test)] #[macro_use] @@ -126,13 +129,13 @@ impl CudaFloat for u8 { } } -fn expr_to_cuda_string(expr: &BigExpression) -> String { +fn expr_to_cuda_string(expr: &Expression) -> String { let mut symbols = vec![]; - for term in expr.clone().simplify().terms { + for term in expr.simplify().terms.read().iter() { let new_symbol = match term { Term::Num(n) => n.to_string(), Term::Var(c) => { - if c == 'z' { + if *c == 'z' { "(int)idx".to_string() } else { c.to_string() @@ -170,7 +173,7 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker]) -> (Vec, String) { let symbols: Vec = shapes .iter() .flat_map(|st| { - st.shape() + st.dims() .into_iter() .chain( st.padding diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs index 5675d117..a0a529c4 100644 --- a/crates/luminal_cuda/src/matmul.rs +++ b/crates/luminal_cuda/src/matmul.rs @@ -21,7 +21,7 @@ crate::debug_type!(Matmul); impl Operator for Matmul { fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape()); + let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims()); let (batch_size, m, k, n) = ( a_shape .iter() diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index 24ba6232..e148a52e 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -16,7 +16,7 @@ use crate::{ pub struct CudaARange { function: CudaFunction, device: Arc, - pub size: BigExpression, + pub size: Expression, dyn_map: *const FxHashMap, _phantom: PhantomData, } @@ -25,7 +25,7 @@ crate::debug_type!(CudaARange); impl CudaARange { pub fn new( device: Arc, - size: BigExpression, + size: Expression, dyn_map: *const FxHashMap, ) -> Self { let type_name = T::type_name(); diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index ee26dc9d..b610ff3b 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -558,19 +558,19 @@ where let inp = get_buffer_from_tensor::(&tensors[0].0); let front_size: usize = tensors[0] .1 - .shape() + .dims() .iter() .take(self.dim) .map(|i| i.to_usize().unwrap()) .product(); let back_size: usize = tensors[0] .1 - .shape() + .dims() .iter() .skip(self.dim + 1) .map(|i| i.to_usize().unwrap()) .product(); - let dim_size = tensors[0].1.shape()[self.dim].to_usize().unwrap(); + let dim_size = tensors[0].1.dims()[self.dim].to_usize().unwrap(); let out = self.device.alloc_zeros::(inp_size).unwrap(); let mut params = vec![ @@ -648,19 +648,19 @@ impl Operator for CudaMaxReduce { let inp = get_buffer_from_tensor::(&tensors[0].0); let front_size: usize = tensors[0] .1 - .shape() + .dims() .iter() .take(self.dim) .map(|i| i.to_usize().unwrap()) .product(); let back_size: usize = tensors[0] .1 - .shape() + .dims() .iter() .skip(self.dim + 1) .map(|i| i.to_usize().unwrap()) .product(); - let dim_size = tensors[0].1.shape()[self.dim].to_usize().unwrap(); + let dim_size = tensors[0].1.dims()[self.dim].to_usize().unwrap(); let out = self.device.alloc_zeros::(inp_size).unwrap(); let mut params = vec![ @@ -703,7 +703,7 @@ impl Compiler for PrimitiveCompiler { // Create copy node let copy_node = graph .add_op(CudaCopyToDevice::::new(dev.clone())) - .input(function_node, 0, ShapeTracker::new(&[])) + .input(function_node, 0, ShapeTracker::new(())) .finish(); // Switch outgoing edges from input to copy_node @@ -732,7 +732,7 @@ impl Compiler for PrimitiveCompiler { { let copy_from_node = graph .add_op(CudaCopyFromDevice::::new(dev.clone())) - .input(source, 0, ShapeTracker::new(&[])) + .input(source, 0, ShapeTracker::new(())) .finish(); graph.add_edge(copy_from_node, function_node, edge_weight); graph.remove_edge(edge); diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index 306928e0..decfb6e6 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -122,13 +122,13 @@ impl Operator for QuantizedMatmul { let (a_shape, b_shape) = ( inp[0] .1 - .shape() + .dims() .into_iter() .map(|i| i.to_usize().unwrap()) .collect::>(), inp[1] .1 - .shape() + .dims() .into_iter() .map(|i| i.to_usize().unwrap()) .collect::>(), @@ -256,15 +256,9 @@ impl CudaQuantizedCompiler { impl Compiler for CudaQuantizedCompiler { type Output = (); - fn compile(&self, graph: &mut Graph, mut remap: To) { + fn compile(&self, graph: &mut Graph, _: To) { let device = CudaDevice::new(0).unwrap(); - let mut weight_ids = self.0.clone(); - let mut local_remap = remap.to_ids_mut(); - for w in &mut weight_ids { - local_remap.push(w); - } - // Normal compilation - graph.compile(crate::CudaCompiler::::default(), &mut local_remap); + let weight_ids = self.0.clone(); // Modify ops directly downstream of weights for weight in downstream(&weight_ids, graph) { for (target, (inp_ind, _, _)) in graph diff --git a/crates/luminal_cuda/src/tests/fp16.rs b/crates/luminal_cuda/src/tests/fp16.rs index dbed9a6b..c0c5a84a 100644 --- a/crates/luminal_cuda/src/tests/fp16.rs +++ b/crates/luminal_cuda/src/tests/fp16.rs @@ -1,34 +1,64 @@ use dfdx::prelude::{Module as DfdxModule, *}; -use itertools::Itertools; -use num_traits::Float; use rand::{rngs::StdRng, SeedableRng}; use luminal::{module::Module, prelude::*}; - -#[allow(unused_imports)] -use dfdx::prelude::{ - Axes as DAxes, Axes2 as DAxes2, Axes3 as DAxes3, Axes4 as DAxes4, Axes5 as DAxes5, - Axis as DAxis, Const as DConst, *, -}; -#[allow(unused_imports)] -use luminal::{ - prelude::{ - Axes as LAxes, Axes2 as LAxes2, Axes3 as LAxes3, Axes4 as LAxes4, Axes5 as LAxes5, - Axis as LAxis, Const as LConst, *, - }, - tests::{ - assert_close, assert_close_precision, assert_exact, random_vec, random_vec_rng, test_graphs, - }, -}; - -use crate::CudaCompiler; +use luminal_nn::{Conv1D, LayerNorm, Linear, ReLU}; + +use crate::{binary_test, unary_test, CudaCompiler}; +luminal::test_imports!(); + +unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f16); +unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f16); +unary_test!(|a| a.recip(), |a| a.recip(), test_recip, f16); +unary_test!(|a| a * a, |a| a.clone() * a, test_square, f16); +unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f16); +unary_test!( + |a| a.log2(), + |a| (a.to_dtype::().ln() / 2_f32.ln()).to_dtype::(), + test_log2, + f16 +); +unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f16); +unary_test!( + |a| a.softmax(0), + |a| a.softmax::>(), + test_softmax, + f16 +); +unary_test!( + |a| a.layer_norm(0, 1e-5), + |a| a + .to_dtype::() + .normalize::>(1e-5) + .to_dtype::(), + test_norm, + f16 +); + +binary_test!(|a, b| a + b, |a, b| a + b, test_add, f16); +binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f16); +binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f16); +binary_test!(|a, b| a / b, |a, b| a / b, test_div, f16); +binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16); +binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f16); +binary_test!( + |a, b| a % b, + |a, b| (a.clone().to_dtype::() + - ((a.to_dtype::() / b.clone().to_dtype::()) + .to_dtype::() + .to_dtype::() + * b.to_dtype::())) + .to_dtype::(), + test_mod, + f16 +); #[test] fn test_contiguous() { let mut cx = Graph::new(); let data = random_vec(12); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.permute::, _>().reshape::>().retrieve(); + let a = cx.tensor((3, 4)).set(data.clone()); + let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve(); cx.compile(CudaCompiler::::default(), &mut b); cx.execute(); @@ -44,26 +74,20 @@ fn test_contiguous() { #[test] fn test_rotate() { let mut cx = Graph::new(); - const D: usize = 2; - const S: usize = 2; - const H: usize = 2; - let data = random_vec(D * S * H); - let a = cx - .tensor::>() - .set(data) - .keep() - .permute::<_, LAxes4<0, 2, 1, 3>>(); - let x1 = a.slice((.., .., .., ..Expression::from(H / 2))); - let x2 = a.slice((.., .., .., Expression::from(H / 2)..)); - let mut rotated_a = (-x2) - .concat_along::, LAxis<3>, _>(x1) - .retrieve(); + const B: usize = 2; + const F: usize = 3; + const D: usize = 4; + let data = random_vec(D * B * F); + let a = cx.tensor((F, B, D)).set(data.clone()).permute((1, 0, 2)); + let x1 = a.slice((.., .., ..D / 2)); + let x2 = a.slice((.., .., D / 2..)); + let mut rotated_a = (-x2).concat_along(x1, 1).retrieve(); cx.execute(); let unopt = rotated_a.data(); + rotated_a.drop(); cx.compile(CudaCompiler::::default(), &mut rotated_a); cx.execute(); - assert_close(&unopt, &rotated_a.data()); } @@ -83,140 +107,68 @@ fn test_constant() { assert_exact(&a.data(), &[625.0]); } -#[test] -fn test_log2() { - let mut cx = Graph::new(); - let data = random_vec(3); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.log2().retrieve(); - - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - assert_close( - &b.data(), - &data - .into_iter() - .map(|i| f16::from_f32(i).log2().to_f32()) - .collect::>(), - ); -} - -#[test] -fn test_exp2() { - let mut cx = Graph::new(); - let data = random_vec(3); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.exp2().retrieve(); - - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - assert_close( - &b.data(), - &data.into_iter().map(|i: f32| i.exp2()).collect::>(), - ); -} - -#[test] -fn test_mod() { - let mut cx = Graph::new(); - let a_data = random_vec(3); - let b_data = random_vec(3); - let a = cx.tensor::>().set(a_data.clone()); - let b = cx.tensor::>().set(b_data.clone()); - let mut c = a % b; - c.retrieve(); - - cx.compile(CudaCompiler::::default(), &mut c); - cx.execute(); - - // No dfdx equivalent - - assert_close( - &c.data(), - &a_data - .into_iter() - .zip(b_data) - .map(|(a, b)| a % b) - .collect_vec(), - ); -} - // Reduction op tests #[test] fn test_sum_reduce() { let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.sum_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.sum_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.sum_reduce::<_, LAxis<0>>().retrieve(); + let a = cx.tensor((1, 10, 4096)).set(data.clone()); + let mut b = a.sum_reduce(2).retrieve(); + let mut c = a.sum_reduce(1).retrieve(); + let mut d = a.sum_reduce(0).retrieve(); cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>)); + let d_a = d_dev + .tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>)) + .to_dtype::(); let d_b = d_a.clone().sum::<_, DAxis<2>>(); let d_c = d_a.clone().sum::<_, DAxis<1>>(); let d_d = d_a.sum::<_, DAxis<0>>(); - assert_close_precision( - &b.data(), - &d_b.to_dtype::().to_dtype::().as_vec(), - 0.1, - ); - assert_close_precision( - &c.data(), - &d_c.to_dtype::().to_dtype::().as_vec(), - 0.1, - ); - assert_close_precision( - &d.data(), - &d_d.to_dtype::().to_dtype::().as_vec(), - 0.1, - ); + assert_close(&b.data(), &d_b.to_dtype::().as_vec()); + assert_close(&c.data(), &d_c.to_dtype::().as_vec()); + assert_close(&d.data(), &d_d.to_dtype::().as_vec()); } #[test] fn test_sum_reduce2() { let mut cx = Graph::new(); let data = random_vec(32 * 10 * 10 * 128); - let a = cx.tensor::>().set(data.clone()); - let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve(); + let a = cx.tensor((1, 32, 10, 10, 128)).set(data.clone()); + let mut d = a.sum_reduce(2).retrieve(); cx.compile(CudaCompiler::::default(), &mut d); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec( - data, - ( - DConst::<1>, - DConst::<32>, - DConst::<10>, - DConst::<10>, - DConst::<128>, - ), - ); + let d_a = d_dev + .tensor_from_vec( + data, + ( + DConst::<1>, + DConst::<32>, + DConst::<10>, + DConst::<10>, + DConst::<128>, + ), + ) + .to_dtype::(); let d_d = d_a.sum::<_, DAxis<2>>(); - assert_close_precision( - &d.data(), - &d_d.to_dtype::().to_dtype::().as_vec(), - 0.1, - ); + assert_exact(&d.data(), &d_d.to_dtype::().as_vec()); } #[test] fn test_max_reduce() { let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.max_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.max_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.max_reduce::<_, LAxis<0>>().retrieve(); + let a = cx.tensor((1, 10, 4096)).set(data.clone()); + let mut b = a.max_reduce(2).retrieve(); + let mut c = a.max_reduce(1).retrieve(); + let mut d = a.max_reduce(0).retrieve(); cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); cx.execute(); @@ -237,10 +189,10 @@ fn test_max_reduce() { fn test_mean_reduce() { let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.mean_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.mean_reduce::<_, LAxis<0>>().retrieve(); + let a = cx.tensor((1, 10, 4096)).set(data.clone()); + let mut b = a.mean_reduce(2).retrieve(); + let mut c = a.mean_reduce(1).retrieve(); + let mut d = a.mean_reduce(0).retrieve(); cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); cx.execute(); @@ -262,8 +214,8 @@ fn test_matmul_simple() { let mut cx = Graph::new(); let a_data = random_vec(256 * 256); let b_data = random_vec(256 * 256); - let a = cx.tensor::>().set(a_data.clone()); - let b = cx.tensor::>().set(b_data.clone()); + let a = cx.tensor((256, 256)).set(a_data.clone()); + let b = cx.tensor((256, 256)).set(b_data.clone()); let mut c = a.matmul(b).retrieve(); cx.compile(CudaCompiler::::default(), &mut c); @@ -274,33 +226,33 @@ fn test_matmul_simple() { let d_b = d_dev.tensor_from_vec(b_data, (DConst::<256>, DConst::<256>)); let d_c = d_a.to_dtype::().matmul(d_b.to_dtype::()); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1.); // Why is this imprecise? + assert_close(&c.data(), &d_c.to_dtype::().as_vec()); } #[test] fn test_matmul() { let d_dev = Cpu::default(); let mut cx = Graph::new(); - let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>(); - let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>(); + let mut rng = StdRng::seed_from_u64(0); + let a = cx.tensor(('M', 'K')); + let b = cx.tensor(('K', 'N')); let mut c = a.matmul(b).retrieve(); - cx.compile(CudaCompiler::::default(), &mut c); - let mut rng = StdRng::seed_from_u64(0); + cx.compile(CudaCompiler::::default(), &mut c); for m in (1..23).step_by(4) { for k in (1..35).step_by(3) { for n in (1..70).step_by(7) { let a_data = random_vec_rng(m * k, &mut rng); let b_data = random_vec_rng(k * n, &mut rng); - a.set_dyn(a_data.clone(), &[m, k]); - b.set_dyn(b_data.clone(), &[k, n]); + a.set_dyn(a_data.clone(), (m, k)); + b.set_dyn(b_data.clone(), (k, n)); + cx.execute(); let d_a = d_dev.tensor_from_vec(a_data, (m, k)); let d_b = d_dev.tensor_from_vec(b_data, (k, n)); let d_c = d_a.matmul(d_b); - - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1e-2); c.drop(); } } @@ -314,11 +266,11 @@ fn test_attn_matmul() { let a_data = random_vec_rng(32 * 11 * 128, &mut rng); let b_data = random_vec_rng(32 * 11 * 128, &mut rng); let a = cx - .named_tensor::>("Input") + .named_tensor("Input", (1, 32, 11, 128)) .set(a_data.clone()) .keep(); let b = cx - .named_tensor::>("Input") + .named_tensor("Input", (1, 32, 128, 11)) .set(b_data.clone()) .keep(); let mut c = a.matmul(b).retrieve(); @@ -340,7 +292,7 @@ fn test_attn_matmul() { ) .to_dtype::(); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); + assert_exact(&c.data(), &d_c.to_dtype::().as_vec()); } #[test] @@ -348,19 +300,18 @@ fn test_batch_matmul() { let m = 12; let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); - let a = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>(); - let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>(); + let a = cx.tensor(('B', 'M', 'K')); + let b = cx.tensor(('K', 'N')); let mut c = a.matmul(b).retrieve(); cx.compile(CudaCompiler::::default(), &mut c); - for batch in (1..23).step_by(4) { + for batch in (2..23).step_by(4) { for k in (1..35).step_by(3) { for n in (1..48).step_by(7) { let a_data = random_vec_rng(batch * m * k, &mut rng); let b_data = random_vec_rng(k * n, &mut rng); - a.set_dyn(a_data.clone(), &[batch, m, k]); - b.set_dyn(b_data.clone(), &[k, n]); - + a.set_dyn(a_data.clone(), (batch, m, k)); + b.set_dyn(b_data.clone(), (k, n)); cx.execute(); let d_dev = Cpu::default(); @@ -368,7 +319,7 @@ fn test_batch_matmul() { let d_b = d_dev.tensor_from_vec(b_data, (k, n)); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1e-2); c.drop(); } } @@ -379,27 +330,24 @@ fn test_batch_matmul() { fn test_batch_matmul_transpose() { const B: usize = 1; const M: usize = 48; // Any - const K: usize = 256; // >= 16, multiple of 16 - const N: usize = 256; // >= 256, multiple of 256 + const K: usize = 4096; // >= 16, multiple of 16 + const N: usize = 4096; // >= 256, multiple of 256 let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); let a_data = random_vec_rng(B * M * K, &mut rng); - let a = cx.named_tensor::>("A").set(a_data.clone()); + let a = cx.named_tensor("A", (B, M, K)).set(a_data.clone()); let b_data = random_vec_rng(K * N, &mut rng); - let b = cx.named_tensor::>("B").set(b_data.clone()); + let b = cx.named_tensor("B", (N, K)).set(b_data.clone()); let a_t_data = random_vec_rng(B * K * M, &mut rng); - let a_t = cx.named_tensor::>("A_T").set(a_t_data.clone()); + let a_t = cx.named_tensor("A_T", (B, K, M)).set(a_t_data.clone()); let b_t_data = random_vec_rng(K * N, &mut rng); - let b_t = cx.named_tensor::>("B_T").set(b_t_data.clone()); + let b_t = cx.named_tensor("B_T", (K, N)).set(b_t_data.clone()); - let mut a_b = a.matmul(b.permute::<_, LAxes2<1, 0>>()).retrieve(); + let mut a_b = a.matmul(b.permute((1, 0))).retrieve(); let mut a_b_t = a.matmul(b_t).retrieve(); - let mut a_t_b = a_t - .permute::<_, LAxes3<0, 2, 1>>() - .matmul(b.permute::<_, LAxes2<1, 0>>()) - .retrieve(); - let mut a_t_b_t = a_t.permute::<_, LAxes3<0, 2, 1>>().matmul(b_t).retrieve(); + let mut a_t_b = a_t.permute((0, 2, 1)).matmul(b.permute((1, 0))).retrieve(); + let mut a_t_b_t = a_t.permute((0, 2, 1)).matmul(b_t).retrieve(); cx.compile( <(GenericCompiler, CudaCompiler)>::default(), @@ -412,18 +360,22 @@ fn test_batch_matmul_transpose() { let d_b = d_dev.tensor_from_vec(b_data, (DConst::, DConst::)); let d_a_t = d_dev.tensor_from_vec(a_t_data, (DConst::, DConst::, DConst::)); let d_b_t = d_dev.tensor_from_vec(b_t_data, (DConst::, DConst::)); - let d_a_b = d_a.clone().matmul(d_b.clone().permute::<_, DAxes2<1, 0>>()); + let d_a_b = d_a + .clone() + .matmul(d_b.clone().permute::<_, dfdx::shapes::Axes2<1, 0>>()); let d_a_b_t = d_a.matmul(d_b_t.clone()); let d_a_t_b = d_a_t .clone() - .permute::<_, DAxes3<0, 2, 1>>() - .matmul(d_b.permute::<_, DAxes2<1, 0>>()); - let d_a_t_b_t = d_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(d_b_t); - - assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 0.1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 0.1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 0.1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 0.1); + .permute::<_, dfdx::shapes::Axes3<0, 2, 1>>() + .matmul(d_b.permute::<_, dfdx::shapes::Axes2<1, 0>>()); + let d_a_t_b_t = d_a_t + .permute::<_, dfdx::shapes::Axes3<0, 2, 1>>() + .matmul(d_b_t); + + assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 1e-1); + assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 1e-1); + assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 1e-1); + assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 1e-1); } #[test] @@ -435,21 +387,18 @@ fn test_matmul_transpose() { let mut rng = StdRng::seed_from_u64(0); let a_data = random_vec_rng(M * K, &mut rng); - let a = cx.tensor::>().set(a_data.clone()); + let a = cx.tensor((M, K)).set(a_data.clone()); let b_data = random_vec_rng(K * N, &mut rng); - let b = cx.tensor::>().set(b_data.clone()); + let b = cx.tensor((N, K)).set(b_data.clone()); let a_t_data = random_vec_rng(K * M, &mut rng); - let a_t = cx.tensor::>().set(a_t_data.clone()); + let a_t = cx.tensor((K, M)).set(a_t_data.clone()); let b_t_data = random_vec_rng(K * N, &mut rng); - let b_t = cx.tensor::>().set(b_t_data.clone()); + let b_t = cx.tensor((K, N)).set(b_t_data.clone()); - let mut a_b = a.matmul(b.permute()).retrieve(); + let mut a_b = a.matmul(b.permute((1, 0))).retrieve(); let mut a_b_t = a.matmul(b_t).retrieve(); - let mut a_t_b = a_t - .permute::<_, LAxes2<1, 0>>() - .matmul(b.permute()) - .retrieve(); - let mut a_t_b_t = a_t.permute::<_, LAxes2<1, 0>>().matmul(b_t).retrieve(); + let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve(); + let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve(); cx.compile( <(GenericCompiler, CudaCompiler)>::default(), @@ -474,14 +423,16 @@ fn test_matmul_transpose() { let d_a_b_t = d_a.matmul(d_b_t.clone()); let d_a_t_b = d_a_t .clone() - .permute::<_, DAxes2<1, 0>>() + .permute::<_, dfdx::shapes::Axes2<1, 0>>() .matmul(d_b.permute()); - let d_a_t_b_t = d_a_t.permute::<_, DAxes2<1, 0>>().matmul(d_b_t); - - assert_close_precision(&a_b.data(), &d_a_b.to_dtype::().as_vec(), 0.1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.to_dtype::().as_vec(), 0.1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.to_dtype::().as_vec(), 0.1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.to_dtype::().as_vec(), 0.1); + let d_a_t_b_t = d_a_t + .permute::<_, dfdx::shapes::Axes2<1, 0>>() + .matmul(d_b_t); + + assert_close(&a_b.data(), &d_a_b.to_dtype::().as_vec()); + assert_close(&a_b_t.data(), &d_a_b_t.to_dtype::().as_vec()); + assert_close(&a_t_b.data(), &d_a_t_b.to_dtype::().as_vec()); + assert_close(&a_t_b_t.data(), &d_a_t_b_t.to_dtype::().as_vec()); } #[test] @@ -491,16 +442,14 @@ fn test_relu_and_linear() { let input_data = random_vec(32); let w1 = random_vec(32 * 64); let w2 = random_vec(32 * 64); - let batch = cx - .named_tensor::>("Batch") - .set(random_vec(32 * 2)); - let a = cx.named_tensor::>("Single").set(input_data.clone()); - - let model: ( - luminal_nn::Linear<32, 64>, - luminal_nn::ReLU, - luminal_nn::Linear<64, 32>, - ) = InitModule::initialize(&mut cx); + let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2)); + let a = cx.named_tensor("Single", 32).set(input_data.clone()); + + let model = ( + Linear::new(32, 64, false, &mut cx), + ReLU, + Linear::new(64, 32, false, &mut cx), + ); model.0.weight.set(w1.clone()); model.2.weight.set(w2.clone()); let mut b = model.forward(a).retrieve(); @@ -517,8 +466,8 @@ fn test_relu_and_linear() { ); cx.execute(); - assert_close_precision(&unoptimized_b, &b.data(), 0.01); - assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 0.01); + assert_close_precision(&unoptimized_b, &b.data(), 1e-2); + assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 1e-2); // Test against dfdx let dev = Cpu::default(); @@ -529,19 +478,19 @@ fn test_relu_and_linear() { )>::build_on_device(&dev); // Set weights model.0.weight = dev - .tensor_from_vec(w1, (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>)) + .tensor_from_vec(w1, (DConst::<32>, DConst::<64>)) .permute() .to_dtype::(); model.2.weight = dev - .tensor_from_vec(w2, (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>)) + .tensor_from_vec(w2, (DConst::<64>, DConst::<32>)) .permute() .to_dtype::(); let a = dev - .tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,)) + .tensor_from_vec(input_data, (DConst::<32>,)) .to_dtype::(); let out = model.forward(a); - assert_close_precision(&unoptimized_b, &out.to_dtype::().as_vec(), 0.01); + assert_close_precision(&unoptimized_b, &out.to_dtype::().as_vec(), 1e-2); } #[test] @@ -551,9 +500,9 @@ fn test_rms_norm() { let inp_data = random_vec_rng(15 * 32, &mut rng); let weight_data = random_vec_rng(32, &mut rng); let mut cx = Graph::new(); - let a = cx.tensor::>().set(inp_data.clone()); + let a = cx.tensor((15, 32)).set(inp_data.clone()); - let model = luminal_nn::LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx); + let model = LayerNorm::new(32, true, false, false, 1e-5, &mut cx); model.weight.unwrap().set(weight_data.clone()); let mut b = model.forward(a).retrieve(); @@ -576,91 +525,154 @@ fn test_rms_norm() { assert_close(&b.data(), &out.to_dtype::().as_vec()); } +#[test] +fn test_layer_norm() { + let mut cx = Graph::new(); + let a_data = random_vec(15 * 16 * 32); + let a = cx.tensor((15, 16, 32)).set(a_data.clone()); + let mut b = a.layer_norm(0, 1e-5).retrieve(); + let mut c = a.layer_norm(2, 1e-5).retrieve(); + cx.compile( + <(GenericCompiler, CudaCompiler)>::default(), + (&mut b, &mut c), + ); + cx.execute(); + + let d_dev = Cpu::default(); + let d_a = d_dev.tensor_from_vec(a_data, (DConst::<15>, DConst::<16>, DConst::<32>)); + let d_b = d_a.clone().normalize::>(1e-5); + let d_c = d_a.normalize::>(1e-5); + + assert_close_precision(&b.data(), &d_b.as_vec(), 1e-2); + assert_close_precision(&c.data(), &d_c.as_vec(), 1e-2); +} + #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); - let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx); - let w_k_weight = random_vec(32 * 32); - model.attention.w_k.weight.set(w_k_weight.clone()); - let w_q_weight = random_vec(32 * 32); - model.attention.w_q.weight.set(w_q_weight.clone()); - let w_v_weight = random_vec(32 * 32); - model.attention.w_v.weight.set(w_v_weight.clone()); - let w_o_weight = random_vec(32 * 32); - model.attention.w_o.weight.set(w_o_weight.clone()); - let ff_0_weight = random_vec(32 * 64); - model.ff.0.weight.set(ff_0_weight.clone()); - let ff_1_weight = random_vec(64 * 32); - model.ff.2.weight.set(ff_1_weight.clone()); - - let a_data = random_vec(2 * 32); + let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx); + model + .attention + .w_k + .weight + .set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]); + model + .attention + .w_q + .weight + .set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]); + model + .attention + .w_v + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]); + model + .attention + .w_o + .weight + .set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]); + model + .ff + .0 + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]); + model + .ff + .2 + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]); + let a = cx - .tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>() - .set_dyn(a_data.clone(), &[1, 2, 3]) - .keep(); - cx.keep_tensors(params(&model)); + .tensor(('a', 3)) + .set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3)); let mut b = model.forward(a).retrieve(); - cx.execute(); - let unopt_b = b.data(); - b.drop(); cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut b); cx.execute(); - assert_close_precision(&unopt_b, &b.data(), 0.01); let d_dev = Cpu::default(); - let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> = - d_dev - .build_module::, f32>(); - d_model.self_attn.w_k.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_v.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_q.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_o.bias.copy_from(&[0.; 32]); + let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> = + d_dev.build_module::, f32>(); + d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]); d_model.self_attn.w_o.weight = d_dev - .tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![1., 22., 3., 1., 2., 3., 1., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_k.weight = d_dev - .tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![1., 22., 3., 1., 2., 3., 1., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_q.weight = d_dev - .tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_v.weight = d_dev - .tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.ff.0 .0.weight = d_dev - .tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.], + (DConst::<3>, DConst::<4>), + ) .permute(); - d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,)); + d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (DConst::<4>,)); d_model.ff.0 .2.weight = d_dev - .tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.], + (DConst::<4>, DConst::<3>), + ) .permute(); - d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); - d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,)); - d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,)); + d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); + d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,)); + d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,)); d_model.norm1.epsilon = 1e-5; - d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); - d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); + d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); + d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); d_model.norm2.epsilon = 1e-5; - let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>)); + let d_a = d_dev.tensor_from_vec(vec![-1., 2., 3., 3., 3., -1.], (DConst::<2>, DConst::<3>)); let d_b = d_model.forward(d_a); - assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); + assert_close(&b.data(), &d_b.as_vec()); +} + +#[test] +fn test_common_buffer() { + let data = random_vec(32); + let mut cx = Graph::new(); + let a = cx.tensor(32); + a.set(data.clone()); + let a1 = cx.tensor(32); + a1.set(data.clone()); + let exped = a * a1; + let mut b = exped.log2().retrieve(); + let mut c = exped.sin().retrieve(); + + cx.compile(CudaCompiler::::default(), (&mut b, &mut c)); + cx.execute(); } #[test] fn test_embedding() { let mut cx = Graph::new(); let batch = cx - .named_tensor::>("Batch") + .named_tensor("Batch", (2, 3)) .set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]) .keep(); - let a = cx - .named_tensor::>("Single") - .set(vec![1.0, 0.0, 1.0]) - .keep(); + let a = cx.named_tensor("Single", 3).set(vec![1.0, 0.0, 1.0]).keep(); - let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx); + let model = luminal_nn::Embedding::new(3, 4, &mut cx); model .weight .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]); @@ -691,12 +703,8 @@ fn test_embedding() { fn test_slice() { let data = random_vec(256); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut c: GraphTensor> = a - .slice((..Expression::from(20),)) - .realize() - .contiguous() - .retrieve(); + let a = cx.tensor(256).set(data.clone()); + let mut c = a.slice(..20).contiguous().retrieve(); cx.compile(CudaCompiler::::default(), &mut c); cx.execute(); @@ -715,8 +723,8 @@ fn test_pad() { // Pad a 8x2 mat to 10x4 let data = random_vec(8 * 2); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut c = a.pad::>(((0, 2), (0, 2))).contiguous().retrieve(); + let a = cx.tensor((8, 2)).set(data.clone()); + let mut c = a.pad(((0, 2), (0, 2))).contiguous().retrieve(); cx.compile(CudaCompiler::::default(), &mut c); cx.execute(); @@ -737,16 +745,12 @@ fn test_pad_contig() { let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); let a_data = random_vec_rng(m * k, &mut rng); - let mut a = cx - .tensor::<(Dyn<'M'>, Dyn<'K'>)>() - .set_dyn(a_data, &[m, k]) - .retrieve(); - let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a - .pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')]) + let mut a = cx.tensor(('M', 'K')).set_dyn(a_data, (m, k)).retrieve(); + let mut b = a + .pad(((0, 0), (0, Expression::from(24) - 'K'))) .contiguous() .retrieve(); - let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = - (a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve(); + let mut c = (a.slice((.., ..k)) / 1.0).retrieve(); cx.compile(CudaCompiler::::default(), (&mut a, &mut b, &mut c)); cx.execute(); @@ -760,13 +764,9 @@ fn test_pad_contig() { fn test_movement() { let data = random_vec(32); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let b: GraphTensor> = a.pad(&[(0, 10)]).contiguous().retrieve(); - let mut c: GraphTensor> = b - .slice((..Expression::from(25),)) - .realize() - .contiguous() - .retrieve(); + let a = cx.tensor(32).set(data.clone()); + let b = a.pad((0, 10)).contiguous().retrieve(); + let mut c = b.slice((..25,)).contiguous().retrieve(); cx.compile(CudaCompiler::::default(), &mut c); cx.execute(); @@ -779,3 +779,147 @@ fn test_movement() { assert_exact(&c.data(), &d_c.as_vec()); } + +#[test] +fn test_slice_add() { + let mut cx = Graph::new(); + let a = cx.tensor(256).set(random_array::<256>()); + let mut b = (a.slice(0..64) + a.slice(64..128) + a.slice(128..192) + a.slice(192..256)) + .expand(0, 4) + .retrieve(); + + cx.compile(CudaCompiler::::default(), &mut b); + cx.execute(); +} + +#[test] +fn test_conv2d() { + let mut cx = Graph::new(); + + const CH_IN: usize = 5; + const CH_OUT: usize = 2; + const KERNELX: usize = 2; + const KERNELY: usize = 2; + const STRIDEX: usize = KERNELX; + const STRIDEY: usize = KERNELY; + const DILATIONX: usize = 1; + const DILATIONY: usize = 1; + const DIMX_IN: usize = 16; + const DIMY_IN: usize = 9; + + let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![ + 8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8., + 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7., + 1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5., + 8., 5., 9., 0., 9., 5., 6., 8., 9., 5., 4., 1., 9., 7., 2., 2., 7., 9., 3., 1., 2., 8., 4., + 0., 8., 0., 5., 6., 7., 7., 4., 3., 4., 6., 8., 3., 7., 8., 8., 7., 1., 5., 1., 8., 0., 1., + 1., 7., 3., 2., 1., 0., 4., 5., 4., 3., 2., 5., 4., 2., 4., 1., 9., 4., 1., 9., 7., 7., 1., + 2., 6., 3., 4., 1., 1., 6., 6., 8., 2., 7., 7., 9., 0., 9., 0., 1., 4., 2., 4., 9., 6., 8., + 6., 1., 6., 3., 8., 3., 4., 5., 0., 2., 1., 8., 2., 2., 8., 7., 0., 7., 7., 3., 4., 5., 0., + 7., 2., 1., 1., 4., 2., 9., 9., 6., 1., 5., 4., 6., 9., 5., 4., 1., 9., 1., 5., 5., 5., 8., + 8., 0., 1., 3., 0., 8., 8., 5., 1., 6., 1., 5., 6., 4., 4., 4., 0., 1., 1., 5., 1., 7., 2., + 3., 5., 5., 4., 9., 1., 3., 7., 6., 7., 1., 5., 3., 8., 6., 6., 6., 7., 3., 2., 2., 8., 1., + 3., 0., 2., 7., 6., 5., 7., 5., 7., 8., 1., 2., 2., 5., 0., 2., 9., 1., 5., 3., 8., 7., 9., + 7., 2., 8., 8., 8., 6., 3., 2., 7., 7., 0., 3., 7., 8., 3., 7., 2., 3., 2., 7., 5., 5., 6., + 0., 9., 0., 9., 9., 1., 8., 7., 9., 6., 8., 7., 5., 4., 9., 5., 6., 3., 2., 8., 3., 0., 6., + 3., 8., 3., 1., 8., 7., 2., 0., 7., 7., 7., 7., 8., 0., 4., 9., 8., 2., 0., 4., 4., 3., 5., + 5., 3., 0., 3., 6., 3., 1., 2., 9., 9., 6., 8., 1., 2., 6., 8., 6., 0., 0., 2., 8., 8., 5., + 0., 5., 9., 0., 8., 1., 1., 3., 5., 9., 3., 5., 8., 6., 3., 2., 9., 4., 8., 3., 9., 5., 2., + 9., 0., 1., 6., 8., 0., 3., 0., 1., 2., 1., 0., 1., 4., 1., 1., 0., 6., 9., 2., 7., 2., 6., + 0., 4., 8., 2., 6., 7., 2., 2., 7., 4., 5., 8., 1., 4., 7., 5., 9., 7., 2., 5., 9., 1., 6., + 1., 7., 9., 5., 6., 9., 3., 5., 1., 6., 1., 3., 3., 9., 3., 9., 0., 1., 8., 1., 9., 8., 5., + 3., 4., 4., 1., 5., 5., 4., 4., 5., 8., 7., 1., 1., 7., 3., 9., 0., 1., 3., 4., 8., 4., 0., + 5., 6., 2., 0., 7., 8., 2., 6., 2., 9., 6., 2., 0., 3., 7., 5., 7., 1., 8., 5., 5., 9., 1., + 0., 3., 5., 7., 5., 3., 2., 8., 6., 3., 0., 5., 8., 5., 7., 8., 8., 2., 9., 0., 1., 8., 6., + 0., 3., 2., 5., 2., 9., 8., 9., 6., 2., 0., 3., 2., 5., 9., 1., 3., 6., 5., 2., 8., 2., 2., + 1., 8., 6., 4., 1., 6., 0., 7., 3., 0., 9., 6., 5., 5., 5., 2., 4., 2., 8., 3., 0., 6., 3., + 8., 8., 4., 9., 4., 7., 0., 3., 5., 1., 4., 6., 0., 0., 5., 9., 7., 8., 6., 7., 0., 6., 7., + 0., 5., 8., 8., 6., 4., 6., 0., 2., 3., 2., 8., 7., 5., 9., 6., 6., 2., 0., 4., 4., 4., 4., + 2., 7., 5., 3., 2., 6., 3., 7., 0., 7., 2., 5., 1., 4., 4., 5., 1., 6., 7., 5., 7., 0., 7., + 8., 4., 7., 3., 9., 1., 7., 5., 6., 1., 0., 2., 0., 0., 5., 5., 8., 8., 7., 3., 7., 2., 9., + 3., 8., 4., 5., 3., 8., 5., 2., 0., 2., 0., 5., 9., 0., 3., 8., 0., 4., 1., 8., 4., 8., 9., + 1., 1., 4., 5., 0., 2., 0., 9., 4., 2., 3., 9., 0., 7., 3., 1., 5., 9., 1., 6., 5., 4., 2., + 1., 2., 1., 1., 4., 7., 2., + ]); + + let model = luminal_nn::Conv2D::new( + CH_IN, + CH_OUT, + (KERNELX, KERNELY), + (STRIDEX, STRIDEY), + (DILATIONX, DILATIONY), + false, + &mut cx, + ); + model.weight.set(vec![ + 0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300, + 0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500, + 0.0700, -0.0800, 0.1700, 0.1000, -0.0700, 0.1600, -0.1600, -0.1900, -0.0500, -0.2100, + 0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400, + ]); + + let mut out1 = model.forward(inp1).retrieve(); + + cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out1); + cx.execute(); + + assert_close_precision( + &out1.data(), + &[ + 3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700, + 4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200, + -0.7100, -0.6500, 4.3900, 0.4000, 1.0300, 0.9800, 3.1200, 2.7400, 2.5100, 0.1200, + 1.8500, 2.0000, -0.7900, 1.0700, -0.3900, -0.8100, -2.5100, -2.9700, 0.2100, 1.8400, + -0.7700, -0.3900, 1.2200, 0.1900, 4.1700, -4.3600, -1.8600, 0.4800, -2.4400, 2.6300, + 1.5000, -1.9700, 1.2800, -2.8200, -2.3200, 0.2200, -0.3800, 2.1800, -0.8200, -1.5700, + 1.2000, -3.4200, -1.6700, 0.9000, + ], + 1e-2, + ); +} + +#[test] +fn test_conv1d_pad_stride() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + + const CH_IN: usize = 80; + const CH_OUT: usize = 384; + const KERNEL: usize = 3; + const STRIDE: usize = 1; + const PADDING: usize = 1; + const DILATION: usize = 1; + const DIM_IN: usize = 10; + let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng); + let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng); + + let model = Conv1D::new( + CH_IN, CH_OUT, KERNEL, STRIDE, DILATION, PADDING, false, &mut cx, + ); + model.weight.set(kernel_data.clone()); + + let inp1 = cx + .tensor((CH_IN, 's')) + .set_dyn(input_data.clone(), (CH_IN, DIM_IN)); + + let mut out1 = model.forward(inp1).retrieve(); + cx.compile(crate::CudaCompiler::::default(), &mut out1); + cx.execute(); + + let input = + candle_core::Tensor::from_vec(input_data, (1, CH_IN, DIM_IN), &candle_core::Device::Cpu) + .unwrap(); + let kernel = candle_core::Tensor::from_vec( + kernel_data, + (CH_OUT, CH_IN, KERNEL), + &candle_core::Device::Cpu, + ) + .unwrap(); + let output = input.conv1d(&kernel, PADDING, STRIDE, 1, 1).unwrap(); + + assert_close_precision( + &out1.data(), + &output.flatten_all().unwrap().to_vec1::().unwrap(), + 1e-2, + ); +} diff --git a/crates/luminal_cuda/src/tests/fp32.rs b/crates/luminal_cuda/src/tests/fp32.rs index 271c14b3..7d72d960 100644 --- a/crates/luminal_cuda/src/tests/fp32.rs +++ b/crates/luminal_cuda/src/tests/fp32.rs @@ -1,34 +1,51 @@ use dfdx::prelude::{Module as DfdxModule, *}; -use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; use luminal::{module::Module, prelude::*}; - -#[allow(unused_imports)] -use dfdx::prelude::{ - Axes as DAxes, Axes2 as DAxes2, Axes3 as DAxes3, Axes4 as DAxes4, Axes5 as DAxes5, - Axis as DAxis, Const as DConst, *, -}; -#[allow(unused_imports)] -use luminal::{ - prelude::{ - Axes as LAxes, Axes2 as LAxes2, Axes3 as LAxes3, Axes4 as LAxes4, Axes5 as LAxes5, - Axis as LAxis, Const as LConst, *, - }, - tests::{ - assert_close, assert_close_precision, assert_exact, random_vec, random_vec_rng, test_graphs, - }, -}; - -use crate::{single_unary_test, CudaCompiler}; -single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32, 3); // For some reason ln fails on larger tensors +use luminal_nn::{Conv1D, Linear, ReLU}; + +use crate::{binary_test, unary_test, CudaCompiler}; +luminal::test_imports!(); + +unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f32); +unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f32); +unary_test!(|a| a.recip(), |a| a.recip(), test_recip, f32); +unary_test!(|a| a * a, |a| a.clone() * a, test_square, f32); +unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32); +unary_test!(|a| a.log2(), |a| a.ln() / 2_f32.ln(), test_log2, f32); +unary_test!(|a| a.exp2(), |a| (a * 2_f32.ln()).exp(), test_exp2, f32); +unary_test!( + |a| a.softmax(0), + |a| a.softmax::>(), + test_softmax, + f32 +); +unary_test!( + |a| a.mean_norm(0).std_norm(0, 1e-5), + |a| a.normalize::>(1e-5), + test_norm, + f32 +); + +binary_test!(|a, b| a + b, |a, b| a + b, test_add, f32); +binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f32); +binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f32); +binary_test!(|a, b| a / b, |a, b| a / b, test_div, f32); +binary_test!( + |a, b| a % b, + |a, b| a.clone() - ((a / b.clone()).to_dtype::().to_dtype::() * b), + test_mod, + f32 +); +binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f32); +binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f32); #[test] fn test_contiguous() { let mut cx = Graph::new(); let data = random_vec(12); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.permute::, _>().reshape::>().retrieve(); + let a = cx.tensor((3, 4)).set(data.clone()); + let mut b = a.permute((1, 0)).reshape((12, 1)).retrieve(); cx.compile(CudaCompiler::::default(), &mut b); cx.execute(); @@ -39,172 +56,51 @@ fn test_contiguous() { assert_close(&b.data(), &d_b.as_vec()); } -#[test] -fn test_rotate() { - let mut cx = Graph::new(); - const D: usize = 2; - const S: usize = 2; - const H: usize = 2; - let data = random_vec(D * S * H); - let a = cx - .tensor::>() - .set(data) - .keep() - .permute::<_, LAxes4<0, 2, 1, 3>>(); - let x1 = a.slice((.., .., .., ..Expression::from(H / 2))); - let x2 = a.slice((.., .., .., Expression::from(H / 2)..)); - let mut rotated_a = (-x2) - .concat_along::, LAxis<3>, _>(x1) - .retrieve(); - cx.execute(); - let unopt = rotated_a.data(); - - cx.compile(CudaCompiler::::default(), &mut rotated_a); - cx.execute(); - - assert_close(&unopt, &rotated_a.data()); -} - -#[test] -fn test_constant() { - let mut cx = Graph::new(); - let a = cx.constant_expr('a'); - let mut a = (a * a).retrieve(); - cx.compile(CudaCompiler::::default(), &mut a); - - cx.set_dyn_dim('a', 10); - cx.execute(); - assert_exact(&a.data(), &[100.0]); - a.drop(); - cx.set_dyn_dim('a', 25); - cx.execute(); - assert_exact(&a.data(), &[625.0]); -} - -#[test] -fn test_log2() { - let mut cx = Graph::new(); - let data = random_vec(3); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.log2().retrieve(); - - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - assert_close( - &b.data(), - &data.into_iter().map(|i| i.log2()).collect::>(), - ); -} - -#[test] -fn test_exp2() { - let mut cx = Graph::new(); - let data = random_vec(3); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.exp2().retrieve(); - - cx.compile(CudaCompiler::::default(), &mut b); - cx.execute(); - - assert_close( - &b.data(), - &data.into_iter().map(|i: f32| i.exp2()).collect::>(), - ); -} - -#[test] -fn test_mod() { - let mut cx = Graph::new(); - let a_data = random_vec(3); - let b_data = random_vec(3); - let a = cx.tensor::>().set(a_data.clone()); - let b = cx.tensor::>().set(b_data.clone()); - let mut c = a % b; - c.retrieve(); - - cx.compile(CudaCompiler::::default(), &mut c); - cx.execute(); - - // No dfdx equivalent - - assert_close( - &c.data(), - &a_data - .into_iter() - .zip(b_data) - .map(|(a, b)| a % b) - .collect_vec(), - ); -} - // Reduction op tests #[test] fn test_sum_reduce() { - let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.sum_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.sum_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.sum_reduce::<_, LAxis<0>>().retrieve(); + let data = random_vec(4 * 4096); + let a = cx.tensor((1, 4, 4096)); + a.set(data.clone()); + let mut b = a.sum_reduce(1).retrieve(); + let mut c = a.sum_reduce(0).retrieve(); + let mut d = a.sum_reduce(2).retrieve(); cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>)); - let d_b = d_a.clone().sum::<_, DAxis<2>>(); - let d_c = d_a.clone().sum::<_, DAxis<1>>(); - let d_d = d_a.sum::<_, DAxis<0>>(); + let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<4>, DConst::<4096>)); + let d_b = d_a.clone().sum::<_, DAxis<1>>(); + let d_c = d_a.clone().sum::<_, DAxis<0>>(); + let d_d = d_a.sum::<_, DAxis<2>>(); + assert_close(&b.data(), &d_b.as_vec()); assert_close(&c.data(), &d_c.as_vec()); assert_close(&d.data(), &d_d.as_vec()); } -#[test] -fn test_sum_reduce2() { - let mut cx = Graph::new(); - let data = random_vec(32 * 10 * 10 * 128); - let a = cx.tensor::>().set(data.clone()); - let mut d = a.sum_reduce::<_, LAxis<2>>().retrieve(); - - cx.compile(CudaCompiler::::default(), &mut d); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec( - data, - ( - DConst::<1>, - DConst::<32>, - DConst::<10>, - DConst::<10>, - DConst::<128>, - ), - ); - let d_d = d_a.sum::<_, DAxis<2>>(); - - assert_exact(&d.data(), &d_d.as_vec()); -} - #[test] fn test_max_reduce() { - let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.max_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.max_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.max_reduce::<_, LAxis<0>>().retrieve(); + let data = random_vec(12); + let a = cx.tensor((2, 2, 3)); + a.set(data.clone()); + let mut b = a.max_reduce(1).retrieve(); + let mut c = a.max_reduce(0).retrieve(); + let mut d = a.max_reduce(2).retrieve(); cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>)); - let d_b = d_a.clone().max::<_, DAxis<2>>(); - let d_c = d_a.clone().max::<_, DAxis<1>>(); - let d_d = d_a.max::<_, DAxis<0>>(); + let d_a = d_dev.tensor_from_vec(data, (DConst::<2>, DConst::<2>, DConst::<3>)); + let d_b = d_a.clone().max::<_, DAxis<1>>(); + let d_c = d_a.clone().max::<_, DAxis<0>>(); + let d_d = d_a.max::<_, DAxis<2>>(); + assert_close(&b.data(), &d_b.as_vec()); assert_close(&c.data(), &d_c.as_vec()); assert_close(&d.data(), &d_d.as_vec()); @@ -214,22 +110,16 @@ fn test_max_reduce() { fn test_mean_reduce() { let data = random_vec(40960); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut b = a.mean_reduce::<_, LAxis<2>>().retrieve(); - let mut c = a.mean_reduce::<_, LAxis<1>>().retrieve(); - let mut d = a.mean_reduce::<_, LAxis<0>>().retrieve(); + let a = cx.tensor((1, 10, 4096)).set(data.clone()); + let mut b = a.mean_reduce(2).retrieve(); - cx.compile(CudaCompiler::::default(), (&mut b, &mut c, &mut d)); + cx.compile(CudaCompiler::::default(), &mut b); cx.execute(); let d_dev = Cpu::default(); let d_a = d_dev.tensor_from_vec(data, (DConst::<1>, DConst::<10>, DConst::<4096>)); - let d_b = d_a.clone().mean::<_, DAxis<2>>(); - let d_c = d_a.clone().mean::<_, DAxis<1>>(); - let d_d = d_a.mean::<_, DAxis<0>>(); + let d_b = d_a.mean::<_, DAxis<2>>(); assert_close(&b.data(), &d_b.as_vec()); - assert_close(&c.data(), &d_c.as_vec()); - assert_close(&d.data(), &d_d.as_vec()); } #[test] @@ -237,8 +127,8 @@ fn test_matmul_simple() { let mut cx = Graph::new(); let a_data = random_vec(256 * 256); let b_data = random_vec(256 * 256); - let a = cx.tensor::>().set(a_data.clone()); - let b = cx.tensor::>().set(b_data.clone()); + let a = cx.tensor((256, 256)).set(a_data.clone()); + let b = cx.tensor((256, 256)).set(b_data.clone()); let mut c = a.matmul(b).retrieve(); cx.compile(CudaCompiler::::default(), &mut c); @@ -254,175 +144,70 @@ fn test_matmul_simple() { #[test] fn test_matmul() { - let d_dev = Cpu::default(); let mut cx = Graph::new(); - let mut rng = StdRng::seed_from_u64(0); - let a = cx.tensor::<(Dyn<'M'>, Dyn<'K'>)>(); - let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>(); - let mut c = a.matmul(b).retrieve(); - - cx.compile(CudaCompiler::::default(), &mut c); - for m in (1..23).step_by(4) { - for k in (1..35).step_by(3) { - for n in (1..70).step_by(7) { - let a_data = random_vec_rng(m * k, &mut rng); - let b_data = random_vec_rng(k * n, &mut rng); - a.set_dyn(a_data.clone(), &[m, k]); - b.set_dyn(b_data.clone(), &[k, n]); - cx.execute(); - - let d_a = d_dev.tensor_from_vec(a_data, (m, k)); - let d_b = d_dev.tensor_from_vec(b_data, (k, n)); - let d_c = d_a.matmul(d_b); - - assert_close(&c.data(), &d_c.as_vec()); - c.drop(); - } - } - } -} - -#[test] -fn test_attn_matmul() { - let mut cx = Graph::new(); - let mut rng = StdRng::seed_from_u64(0); - let a_data = random_vec_rng(32 * 11 * 128, &mut rng); - let b_data = random_vec_rng(32 * 11 * 128, &mut rng); - let a = cx - .named_tensor::>("Input") - .set(a_data.clone()) - .keep(); - let b = cx - .named_tensor::>("Input") - .set(b_data.clone()) - .keep(); + let a_data = random_vec(512 * 512); + let b_data = random_vec(512 * 512); + let a = cx.tensor((512, 512)).set(a_data.clone()); + let b = cx.tensor((512, 512)).set(b_data.clone()); let mut c = a.matmul(b).retrieve(); cx.compile(CudaCompiler::::default(), &mut c); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec( - a_data, - (DConst::<1>, DConst::<32>, DConst::<11>, DConst::<128>), - ); - let d_b = d_dev.tensor_from_vec( - b_data, - (DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>), - ); + let d_a = d_dev.tensor_from_vec(a_data, (DConst::<512>, DConst::<512>)); + let d_b = d_dev.tensor_from_vec(b_data, (DConst::<512>, DConst::<512>)); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); + + assert_close(&c.data(), &d_c.as_vec()); } #[test] fn test_batch_matmul() { - let m = 12; let mut cx = Graph::new(); - let mut rng = StdRng::seed_from_u64(0); - let a = cx.tensor::<(Dyn<'B'>, Dyn<'M'>, Dyn<'K'>)>(); - let b = cx.tensor::<(Dyn<'K'>, Dyn<'N'>)>(); + let a = cx + .tensor((2, 2, 3)) + .set(vec![1., 2., 3., 1., 2., 1., 1., 2., 3., 1., 2., 1.]); + let b = cx + .tensor((3, 4)) + .set(vec![1., 2., 3., 1., 1., 2., 1., 2., -1., -2., 1., 2.]); let mut c = a.matmul(b).retrieve(); cx.compile(CudaCompiler::::default(), &mut c); - for batch in (1..23).step_by(4) { - for k in (1..35).step_by(3) { - for n in (1..48).step_by(7) { - let a_data = random_vec_rng(batch * m * k, &mut rng); - let b_data = random_vec_rng(k * n, &mut rng); - a.set_dyn(a_data.clone(), &[batch, m, k]); - b.set_dyn(b_data.clone(), &[k, n]); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(a_data, (batch, m, k)); - let d_b = d_dev.tensor_from_vec(b_data, (k, n)); - let d_c = d_a.matmul(d_b); - - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.01); - c.drop(); - } - } - } -} - -#[test] -fn test_batch_matmul_transpose() { - const B: usize = 1; - const M: usize = 48; // Any - const K: usize = 4096; // >= 16, multiple of 16 - const N: usize = 4096; // >= 256, multiple of 256 - let mut cx = Graph::new(); - let mut rng = StdRng::seed_from_u64(0); - - let a_data = random_vec_rng(B * M * K, &mut rng); - let a = cx.named_tensor::>("A").set(a_data.clone()); - let b_data = random_vec_rng(K * N, &mut rng); - let b = cx.named_tensor::>("B").set(b_data.clone()); - let a_t_data = random_vec_rng(B * K * M, &mut rng); - let a_t = cx.named_tensor::>("A_T").set(a_t_data.clone()); - let b_t_data = random_vec_rng(K * N, &mut rng); - let b_t = cx.named_tensor::>("B_T").set(b_t_data.clone()); - - let mut a_b = a.matmul(b.permute::<_, LAxes2<1, 0>>()).retrieve(); - let mut a_b_t = a.matmul(b_t).retrieve(); - let mut a_t_b = a_t - .permute::<_, LAxes3<0, 2, 1>>() - .matmul(b.permute::<_, LAxes2<1, 0>>()) - .retrieve(); - let mut a_t_b_t = a_t.permute::<_, LAxes3<0, 2, 1>>().matmul(b_t).retrieve(); - - cx.compile( - <(GenericCompiler, CudaCompiler)>::default(), - (&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t), - ); cx.execute(); let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(a_data, (DConst::, DConst::, DConst::)); - let d_b = d_dev.tensor_from_vec(b_data, (DConst::, DConst::)); - let d_a_t = d_dev.tensor_from_vec(a_t_data, (DConst::, DConst::, DConst::)); - let d_b_t = d_dev.tensor_from_vec(b_t_data, (DConst::, DConst::)); - let d_a_b = d_a.clone().matmul(d_b.clone().permute::<_, DAxes2<1, 0>>()); - let d_a_b_t = d_a.matmul(d_b_t.clone()); - let d_a_t_b = d_a_t - .clone() - .permute::<_, DAxes3<0, 2, 1>>() - .matmul(d_b.permute::<_, DAxes2<1, 0>>()); - let d_a_t_b_t = d_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(d_b_t); - - assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 0.1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 0.1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 0.1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 0.1); + let d_a = d_dev.tensor([[[1., 2., 3.], [1., 2., 1.]], [[1., 2., 3.], [1., 2., 1.]]]); + let d_b = d_dev.tensor([[1., 2., 3., 1.], [1., 2., 1., 2.], [-1., -2., 1., 2.]]); + let d_c = d_a.matmul(d_b); + + assert_close(&c.data(), &d_c.as_vec()); } #[test] fn test_matmul_transpose() { const M: usize = 1024; // Any const K: usize = 16; // >= 16 - const N: usize = 767; // >= 256, multiple of 256 + const N: usize = 256; // >= 256, power of 2 let mut cx = Graph::new(); let mut rng = StdRng::seed_from_u64(0); let a_data = random_vec_rng(M * K, &mut rng); - let a = cx.tensor::>().set(a_data.clone()); + let a = cx.tensor((M, K)).set(a_data.clone()); let b_data = random_vec_rng(K * N, &mut rng); - let b = cx.tensor::>().set(b_data.clone()); + let b = cx.tensor((N, K)).set(b_data.clone()); let a_t_data = random_vec_rng(K * M, &mut rng); - let a_t = cx.tensor::>().set(a_t_data.clone()); + let a_t = cx.tensor((K, M)).set(a_t_data.clone()); let b_t_data = random_vec_rng(K * N, &mut rng); - let b_t = cx.tensor::>().set(b_t_data.clone()); + let b_t = cx.tensor((K, N)).set(b_t_data.clone()); - let mut a_b = a.matmul(b.permute()).retrieve(); + let mut a_b = a.matmul(b.permute((1, 0))).retrieve(); let mut a_b_t = a.matmul(b_t).retrieve(); - let mut a_t_b = a_t - .permute::<_, LAxes2<1, 0>>() - .matmul(b.permute()) - .retrieve(); - let mut a_t_b_t = a_t.permute::<_, LAxes2<1, 0>>().matmul(b_t).retrieve(); + let mut a_t_b = a_t.permute((1, 0)).matmul(b.permute((1, 0))).retrieve(); + let mut a_t_b_t = a_t.permute((1, 0)).matmul(b_t).retrieve(); cx.compile( - <(GenericCompiler, CudaCompiler)>::default(), + CudaCompiler::::default(), (&mut a_b, &mut a_b_t, &mut a_t_b, &mut a_t_b_t), ); cx.execute(); @@ -436,9 +221,11 @@ fn test_matmul_transpose() { let d_a_b_t = d_a.matmul(d_b_t.clone()); let d_a_t_b = d_a_t .clone() - .permute::<_, DAxes2<1, 0>>() + .permute::<_, dfdx::shapes::Axes2<1, 0>>() .matmul(d_b.permute()); - let d_a_t_b_t = d_a_t.permute::<_, DAxes2<1, 0>>().matmul(d_b_t); + let d_a_t_b_t = d_a_t + .permute::<_, dfdx::shapes::Axes2<1, 0>>() + .matmul(d_b_t); assert_close(&a_b.data(), &d_a_b.as_vec()); assert_close(&a_b_t.data(), &d_a_b_t.as_vec()); @@ -453,16 +240,14 @@ fn test_relu_and_linear() { let input_data = random_vec(32); let w1 = random_vec(32 * 64); let w2 = random_vec(32 * 64); - let batch = cx - .named_tensor::>("Batch") - .set(random_vec(32 * 2)); - let a = cx.named_tensor::>("Single").set(input_data.clone()); - - let model: ( - luminal_nn::Linear<32, 64>, - luminal_nn::ReLU, - luminal_nn::Linear<64, 32>, - ) = InitModule::initialize(&mut cx); + let batch = cx.named_tensor("Batch", (2, 32)).set(random_vec(32 * 2)); + let a = cx.named_tensor("Single", 32).set(input_data.clone()); + + let model = ( + Linear::new(32, 64, false, &mut cx), + ReLU, + Linear::new(64, 32, false, &mut cx), + ); model.0.weight.set(w1.clone()); model.2.weight.set(w2.clone()); let mut b = model.forward(a).retrieve(); @@ -479,8 +264,8 @@ fn test_relu_and_linear() { ); cx.execute(); - assert_close_precision(&unoptimized_b, &b.data(), 0.01); - assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 0.01); + assert_close_precision(&unoptimized_b, &b.data(), 1e-2); + assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 1e-2); // Test against dfdx let dev = Cpu::default(); @@ -491,245 +276,188 @@ fn test_relu_and_linear() { )>::build_on_device(&dev); // Set weights model.0.weight = dev - .tensor_from_vec(w1, (dfdx::shapes::Const::<32>, dfdx::shapes::Const::<64>)) + .tensor_from_vec(w1, (DConst::<32>, DConst::<64>)) .permute(); model.2.weight = dev - .tensor_from_vec(w2, (dfdx::shapes::Const::<64>, dfdx::shapes::Const::<32>)) + .tensor_from_vec(w2, (DConst::<64>, DConst::<32>)) .permute(); - let a = dev.tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,)); + let a = dev.tensor_from_vec(input_data, (DConst::<32>,)); let out = model.forward(a); - assert_close_precision(&unoptimized_b, &out.as_vec(), 0.01); -} - -#[test] -fn test_rms_norm() { - // Test single and batch, unoptimized and optimized - let inp_data = random_vec(15 * 32); - let weight_data = random_vec(32); - let mut cx = Graph::new(); - let a = cx.tensor::>().set(inp_data.clone()); - - let model = luminal_nn::LayerNorm::<32>::new(true, false, false, 1e-5, &mut cx); - model.weight.unwrap().set(weight_data.clone()); - let mut b = model.forward(a).retrieve(); - - cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut b); - cx.execute(); - - // Test against dfdx - let dev = Cpu::default(); - let weight = dev.tensor_from_vec(weight_data, (DConst::<32>,)); - let a = dev.tensor_from_vec(inp_data, (DConst::<15>, DConst::<32>)); - let var_f32 = a.clone().square().mean::<_, DAxis<1>>(); - let std_f32 = (var_f32 + 1e-6).sqrt(); - let x_f32 = a / std_f32.broadcast(); - let out = weight.broadcast() * x_f32; - - assert_close(&b.data(), &out.as_vec()); + assert_close_precision(&unoptimized_b, &out.as_vec(), 1e-2); } #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); - let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx); - let w_k_weight = random_vec(32 * 32); - model.attention.w_k.weight.set(w_k_weight.clone()); - let w_q_weight = random_vec(32 * 32); - model.attention.w_q.weight.set(w_q_weight.clone()); - let w_v_weight = random_vec(32 * 32); - model.attention.w_v.weight.set(w_v_weight.clone()); - let w_o_weight = random_vec(32 * 32); - model.attention.w_o.weight.set(w_o_weight.clone()); - let ff_0_weight = random_vec(32 * 64); - model.ff.0.weight.set(ff_0_weight.clone()); - let ff_1_weight = random_vec(64 * 32); - model.ff.2.weight.set(ff_1_weight.clone()); - - let a_data = random_vec(2 * 32); + let model = luminal_nn::TransformerEncoderBlock::new(3, 4, 1, &mut cx); + model + .attention + .w_k + .weight + .set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]); + model + .attention + .w_q + .weight + .set(vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.]); + model + .attention + .w_v + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.]); + model + .attention + .w_o + .weight + .set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]); + model + .ff + .0 + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.]); + model + .ff + .2 + .weight + .set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]); + let a = cx - .tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>() - .set_dyn(a_data.clone(), &[1, 2, 3]) - .keep(); - cx.keep_tensors(params(&model)); + .tensor(('a', 3)) + .set_dyn(vec![-1., 2., 3., 3., 3., -1.], (2, 3)); let mut b = model.forward(a).retrieve(); - cx.execute(); - let unopt_b = b.data(); - b.drop(); cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut b); cx.execute(); - assert_close_precision(&unopt_b, &b.data(), 0.01); let d_dev = Cpu::default(); - let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> = - d_dev - .build_module::, f32>(); - d_model.self_attn.w_k.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_v.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_q.bias.copy_from(&[0.; 32]); - d_model.self_attn.w_o.bias.copy_from(&[0.; 32]); + let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<3, 1, 4, f32, Cpu> = + d_dev.build_module::, f32>(); + d_model.self_attn.w_k.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_v.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_q.bias.copy_from(&[0.0, 0.0, 0.0]); + d_model.self_attn.w_o.bias.copy_from(&[0., 0., 0.]); d_model.self_attn.w_o.weight = d_dev - .tensor_from_vec(w_o_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![1., 22., 3., 1., 2., 3., 1., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_k.weight = d_dev - .tensor_from_vec(w_k_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![1., 22., 3., 1., 2., 3., 1., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_q.weight = d_dev - .tensor_from_vec(w_q_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![3., 2., 3., 1.3, 2., 3., 3., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.self_attn.w_v.weight = d_dev - .tensor_from_vec(w_v_weight, (DConst::<32>, DConst::<32>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3.], + (DConst::<3>, DConst::<3>), + ) .permute(); d_model.ff.0 .0.weight = d_dev - .tensor_from_vec(ff_0_weight, (DConst::<32>, DConst::<64>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 11., 2., 3.], + (DConst::<3>, DConst::<4>), + ) .permute(); - d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0.; 64], (DConst::<64>,)); + d_model.ff.0 .0.bias = d_dev.tensor_from_vec(vec![0., 0., 0., 0.], (DConst::<4>,)); d_model.ff.0 .2.weight = d_dev - .tensor_from_vec(ff_1_weight, (DConst::<64>, DConst::<32>)) + .tensor_from_vec( + vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.], + (DConst::<4>, DConst::<3>), + ) .permute(); - d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); - d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,)); - d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1.; 32], (DConst::<32>,)); + d_model.ff.0 .2.bias = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); + d_model.norm1.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,)); + d_model.norm2.gamma = d_dev.tensor_from_vec(vec![1., 1., 1.], (DConst::<3>,)); d_model.norm1.epsilon = 1e-5; - d_model.norm2.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); - d_model.norm1.beta = d_dev.tensor_from_vec(vec![0.; 32], (DConst::<32>,)); + d_model.norm2.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); + d_model.norm1.beta = d_dev.tensor_from_vec(vec![0., 0., 0.], (DConst::<3>,)); d_model.norm2.epsilon = 1e-5; - let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>)); + let d_a = d_dev.tensor_from_vec(vec![-1., 2., 3., 3., 3., -1.], (DConst::<2>, DConst::<3>)); let d_b = d_model.forward(d_a); - assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); + assert_close(&b.data(), &d_b.as_vec()); } #[test] -fn test_embedding() { +fn test_pool_1d_dims() { let mut cx = Graph::new(); - let batch = cx - .named_tensor::>("Batch") - .set(vec![1.0, 0.0, 2.0, 1.0, 0.0, 1.0]) - .keep(); - let a = cx - .named_tensor::>("Single") - .set(vec![1.0, 0.0, 1.0]) - .keep(); - let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx); - model - .weight - .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]); - let mut b = model.forward(a).retrieve(); - let mut batch_out = model.forward(batch).retrieve(); + let inp1 = cx.tensor((4, 4)).set(vec![ + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., + ]); + // Stride 1 + let out1 = inp1.pool_last_dim(3, 1, 1).retrieve(); - cx.compile(CudaCompiler::::default(), (&mut b, &mut batch_out)); cx.execute(); - let d_dev = Cpu::default(); - let mut d_model: modules::Embedding<3, 4, f32, Cpu> = - >::build_on_device(&d_dev); - d_model.weight = d_dev.tensor_from_vec( - vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.], - (DConst::<3>, DConst::<4>), + assert_exact( + &out1.data(), + &[ + 1., 2., 3., 2., 3., 4., 5., 6., 7., 6., 7., 8., 9., 10., 11., 10., 11., 12., 13., 14., + 15., 14., 15., 16., + ], ); - let d_a = d_dev.tensor_from_vec(vec![1, 0, 1], (DConst::<3>,)); - let d_batch = d_dev.tensor_from_vec(vec![1, 0, 2, 1, 0, 1], (DConst::<2>, DConst::<3>)); - - let d_b = d_model.forward(d_a); - let d_batch_out = d_model.forward(d_batch); - - assert_close(&b.data(), &d_b.as_vec()); - assert_close(&batch_out.data(), &d_batch_out.as_vec()); -} - -#[test] -fn test_slice() { - let data = random_vec(256); - let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut c: GraphTensor> = a - .slice((..Expression::from(20),)) - .realize() - .contiguous() - .retrieve(); - - cx.compile(CudaCompiler::::default(), &mut c); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<256>,)); - let d_c = d_a.slice((..20,)); - - assert_exact(&c.data(), &d_c.as_vec()); } #[test] -fn test_pad() { - // Pad a 8x2 mat to 10x4 - let data = random_vec(8 * 2); +fn test_pool_2d() { let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let mut c = a.pad::>(((0, 2), (0, 2))).contiguous().retrieve(); - cx.compile(CudaCompiler::::default(), &mut c); - cx.execute(); - - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (8, 2)); - // There is no pad function in dfdx, so we concat with zero tensors - let d_b = (d_a, d_dev.zeros_like(&(2, 2))).concat_along(DAxis::<0>); - let d_c = (d_b, d_dev.zeros_like(&(10, 2))).concat_along(DAxis::<1>); - - assert_exact(&c.data(), &d_c.as_vec()); -} - -#[test] -fn test_pad_contig() { - let m = 13; - let k = 24; - let mut cx = Graph::new(); - let mut rng = StdRng::seed_from_u64(0); - let a_data = random_vec_rng(m * k, &mut rng); - let mut a = cx - .tensor::<(Dyn<'M'>, Dyn<'K'>)>() - .set_dyn(a_data, &[m, k]) - .retrieve(); - let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a - .pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')]) - .contiguous() + let inp1 = cx.tensor((4, 4)).set(vec![ + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., + ]); + // 3x3 kernel + let out1 = inp1 + // Pool first dim first by moving it to end + .permute((1, 0)) + .pool_last_dim(3, 1, 1) + // Now move other dim to end + .permute((1, 2, 0)) + .pool_last_dim(3, 1, 1) + // Now swap middle two dims + .permute((0, 2, 1, 3)) + // Now merge both pooled dimensions + .reshape((4, 3, 3)) .retrieve(); - let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = - (a.slice((.., ..Expression::from(k))).realize() / 1.0).retrieve(); - cx.compile( - <(GenericCompiler, CudaCompiler)>::default(), - (&mut a, &mut b, &mut c), - ); cx.execute(); - // Close because b and c are going through 16 bits, while a is not - assert_close(&a.data(), &b.data()); - assert_close(&a.data(), &c.data()); + assert_exact( + &out1.data(), + &[ + 1.00, 2.00, 3.00, 5.00, 6.00, 7.00, 9.00, 10.00, 11.00, 2.00, 3.00, 4.00, 6.00, 7.00, + 8.00, 10.00, 11.00, 12.00, 5.00, 6.00, 7.00, 9.00, 10.00, 11.00, 13.00, 14.00, 15.00, + 6.00, 7.00, 8.00, 10.00, 11.00, 12.00, 14.00, 15.00, 16.00, + ], + ); } #[test] -fn test_movement() { - let data = random_vec(32); +fn test_pool_1d_dilation() { let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let b: GraphTensor> = a.pad(&[(0, 10)]).contiguous().retrieve(); - let mut c: GraphTensor> = b - .slice((..Expression::from(25),)) - .realize() - .contiguous() - .retrieve(); - cx.compile(CudaCompiler::::default(), &mut c); - cx.execute(); + let inp1 = cx.tensor(5).set(vec![1., 2., 3., 4., 5.]); + // Stride 1 + let out1 = inp1.pool_last_dim(2, 1, 2).retrieve(); + // Stride 2 + let out2 = inp1.pool_last_dim(2, 2, 2).retrieve(); + // Stride 3 + let out3 = inp1.pool_last_dim(2, 3, 2).retrieve(); - let d_dev = Cpu::default(); - let d_a = d_dev.tensor_from_vec(data, (DConst::<32>,)); - let d_c = d_a.slice((..25,)); + cx.execute(); - assert_exact(&c.data(), &d_c.as_vec()); + assert_exact(&out1.data(), &[1., 3., 2., 4., 3., 5.]); + assert_exact(&out2.data(), &[1., 3., 3., 5.]); + assert_exact(&out3.data(), &[1., 3.]); } #[test] @@ -742,14 +470,12 @@ fn test_conv2d() { const KERNELY: usize = 2; const STRIDEX: usize = KERNELX; const STRIDEY: usize = KERNELY; - const DILATIONX: usize = 0; - const DILATIONY: usize = 0; + const DILATIONX: usize = 1; + const DILATIONY: usize = 1; const DIMX_IN: usize = 16; - const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1; const DIMY_IN: usize = 9; - const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1; - let inp1 = cx.tensor::>().set(vec![ + let inp1 = cx.tensor((CH_IN, DIMX_IN, DIMY_IN)).set(vec![ 8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8., 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., 4., 7., 1., 0., 7., 7., 4., 9., 5., 0., 4., 7., 4., 7., 8., 8., 4., 8., 4., 7., 9., 3., 7., 9., 5., @@ -784,7 +510,15 @@ fn test_conv2d() { 1., 2., 1., 1., 4., 7., 2., ]); - let model = luminal_nn::Conv2D::::initialize(&mut cx); + let model = luminal_nn::Conv2D::new( + CH_IN, + CH_OUT, + (KERNELX, KERNELY), + (STRIDEX, STRIDEY), + (DILATIONX, DILATIONY), + false, + &mut cx, + ); model.weight.set(vec![ 0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300, 0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, 0.1500, @@ -792,9 +526,7 @@ fn test_conv2d() { 0.0100, -0.2000, 0.2100, -0.0400, -0.1400, 0.1500, 0.0500, -0.1700, 0.1400, ]); - let mut out1 = model - .forward::(inp1) - .retrieve(); + let mut out1 = model.forward(inp1).retrieve(); cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut out1); cx.execute(); @@ -812,3 +544,45 @@ fn test_conv2d() { ], ); } + +#[test] +fn test_conv1d_pad_stride() { + let mut cx = Graph::new(); + let mut rng = StdRng::seed_from_u64(0); + + const CH_IN: usize = 80; + const CH_OUT: usize = 384; + const KERNEL: usize = 3; + const STRIDE: usize = 1; + const PADDING: usize = 1; + const DIM_IN: usize = 10; + let kernel_data = random_vec_rng(KERNEL * CH_IN * CH_OUT, &mut rng); + let input_data = random_vec_rng(CH_IN * DIM_IN, &mut rng); + + let model = Conv1D::new(CH_IN, CH_OUT, KERNEL, STRIDE, 1, PADDING, false, &mut cx); + model.weight.set(kernel_data.clone()); + + let inp1 = cx + .tensor((1, CH_IN, 's')) + .set_dyn(input_data.clone(), (1, CH_IN, DIM_IN)); + + let mut out1 = model.forward(inp1).retrieve(); + cx.compile(crate::CudaCompiler::::default(), &mut out1); + cx.execute(); + + let input = + candle_core::Tensor::from_vec(input_data, (1, CH_IN, DIM_IN), &candle_core::Device::Cpu) + .unwrap(); + let kernel = candle_core::Tensor::from_vec( + kernel_data, + (CH_OUT, CH_IN, KERNEL), + &candle_core::Device::Cpu, + ) + .unwrap(); + let output = input.conv1d(&kernel, PADDING, STRIDE, 1, 1).unwrap(); + + assert_close( + &out1.data(), + &output.flatten_all().unwrap().to_vec1::().unwrap(), + ); +} diff --git a/crates/luminal_cuda/src/tests/mod.rs b/crates/luminal_cuda/src/tests/mod.rs index 6c36075d..9f7b7be6 100644 --- a/crates/luminal_cuda/src/tests/mod.rs +++ b/crates/luminal_cuda/src/tests/mod.rs @@ -1,7 +1,4 @@ -use dfdx::prelude::*; -use luminal::prelude::*; -use luminal::tests::random_vec_rng; -use rand::{rngs::StdRng, SeedableRng}; +use luminal::{graph::Graph, op::Operator}; mod fp16; mod fp32; @@ -11,14 +8,14 @@ macro_rules! single_unary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty, $size: expr) => { paste::paste! { #[test] - fn [<$name _ $type _ $size>]() { + fn [<$name _ $size>]() { let mut rng = StdRng::seed_from_u64(1); let data = random_vec_rng($size, &mut rng); let mut cx = Graph::new(); - let a = cx.tensor::>().set(data.clone()); - let f: fn(GraphTensor>) -> GraphTensor> = $luminal_func; + let a = cx.tensor($size).set(data.clone()); + let f: fn(GraphTensor) -> GraphTensor = $luminal_func; let mut b = f(a).retrieve(); - cx.compile($crate::CudaCompiler::<$type>::default(), &mut b); + cx.compile(CudaCompiler::<$type>::default(), &mut b); cx.execute(); let d_dev = Cpu::default(); @@ -30,14 +27,14 @@ macro_rules! single_unary_test { ) -> dfdx::prelude::Tensor, $type, Cpu, NoneTape> = $dfdx_func; let d_b = f(d_a); - luminal::tests::assert_close(&b.data(), &d_b.to_dtype::().as_vec()); + assert_close(&b.data(), &d_b.to_dtype::().as_vec()); } } }; } #[macro_export] -macro_rules! unary_test_type { +macro_rules! unary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty) => { $crate::single_unary_test!($luminal_func, $dfdx_func, $name, $type, 3); $crate::single_unary_test!($luminal_func, $dfdx_func, $name, $type, 50); @@ -46,30 +43,22 @@ macro_rules! unary_test_type { }; } -#[macro_export] -macro_rules! unary_test { - ($luminal_func: expr , $dfdx_func: expr , $name: ident) => { - $crate::unary_test_type!($luminal_func, $dfdx_func, $name, f32); - $crate::unary_test_type!($luminal_func, $dfdx_func, $name, f16); - }; -} - #[macro_export] macro_rules! single_binary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty, $size: expr) => { paste::paste! { #[test] - fn [<$name _ $type _ $size>]() { + fn [<$name _ $size>]() { let mut rng = StdRng::seed_from_u64(2); let a_data = random_vec_rng($size, &mut rng); let b_data = random_vec_rng($size, &mut rng); let mut cx = Graph::new(); - let a = cx.tensor::>().set(a_data.clone()); - let b = cx.tensor::>().set(b_data.clone()); - let f: fn(GraphTensor>, GraphTensor>) -> GraphTensor> = + let a = cx.tensor($size).set(a_data.clone()); + let b = cx.tensor($size).set(b_data.clone()); + let f: fn(GraphTensor, GraphTensor) -> GraphTensor = $luminal_func; let mut c = f(a, b).retrieve(); - cx.compile($crate::CudaCompiler::<$type>::default(), &mut c); + cx.compile(CudaCompiler::<$type>::default(), &mut c); cx.execute(); let d_dev = Cpu::default(); @@ -85,14 +74,14 @@ macro_rules! single_binary_test { ) -> dfdx::prelude::Tensor, $type, Cpu, NoneTape> = $dfdx_func; let d_c = f(d_a, d_b); - luminal::tests::assert_close(&c.data(), &d_c.to_dtype::().as_vec()); + assert_close(&c.data(), &d_c.to_dtype::().as_vec()); } } }; } #[macro_export] -macro_rules! binary_test_type { +macro_rules! binary_test { ($luminal_func: expr , $dfdx_func: expr , $name: ident, $type: ty) => { $crate::single_binary_test!($luminal_func, $dfdx_func, $name, $type, 3); $crate::single_binary_test!($luminal_func, $dfdx_func, $name, $type, 50); @@ -101,50 +90,9 @@ macro_rules! binary_test_type { }; } -#[macro_export] -macro_rules! binary_test { - ($luminal_func: expr , $dfdx_func: expr , $name: ident) => { - $crate::binary_test_type!($luminal_func, $dfdx_func, $name, f32); - $crate::binary_test_type!($luminal_func, $dfdx_func, $name, f16); - }; -} - pub fn assert_op_in_graph(graph: &Graph) { assert!( graph.node_indices().any(|i| graph.check_node_type::(i)), "Node not found in the graph!" ); } - -unary_test!(|a| a.sin(), |a| a.sin(), test_sin); -unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt); -unary_test!(|a| a.recip(), |a| a.recip(), test_recip); -unary_test!(|a| a * a, |a| a.clone() * a, test_square); -unary_test!(|a| a.exp(), |a| a.exp(), test_exp); -unary_test!(|a| a.cos(), |a| a.cos(), test_cos); -unary_test!(|a| a.softmax(), |a| a.softmax(), test_softmax); -unary_test!( - |a| a.mean_norm::>(), - |a| a.clone() - a.mean::<_, dfdx::prelude::Axis<0>>().broadcast(), - test_mean_norm -); -unary_test!( - |a| a.std_norm::, _>(1e-5), - |a| a.clone() / a.stddev::<_, dfdx::prelude::Axis<0>>(1e-5).broadcast(), - test_std_norm -); -unary_test!( - |a| a.layer_norm::, _>(1e-5), - |a| a.normalize::>(1e-5), - test_norm -); - -binary_test!(|a, b| a + b, |a, b| a + b, test_add); -binary_test!(|a, b| a - b, |a, b| a - b, test_sub); -binary_test!(|a, b| a * b, |a, b| a * b, test_mul); -binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div); -binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max); -binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min); - -single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f16, 3); // For some reason ln fails on larger tensors -single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f32, 3); // For some reason ln fails on larger tensors diff --git a/crates/luminal_cuda/src/unary.rs b/crates/luminal_cuda/src/unary.rs index 0b45f547..eb7af78c 100644 --- a/crates/luminal_cuda/src/unary.rs +++ b/crates/luminal_cuda/src/unary.rs @@ -89,19 +89,19 @@ impl Operator for CudaMeanReduce { let out = self.device.alloc_zeros::(inp_size).unwrap(); let front_size = tensors[0] .1 - .shape() + .dims() .iter() .take(self.dim) .map(|i| i.to_usize().unwrap()) .product::() as i32; let back_size = tensors[0] .1 - .shape() + .dims() .iter() .skip(self.dim + 1) .map(|i| i.to_usize().unwrap()) .product::() as i32; - let dim_size = tensors[0].1.shape()[self.dim].to_usize().unwrap() as i32; + let dim_size = tensors[0].1.dims()[self.dim].to_usize().unwrap() as i32; let mut params = vec![ get_buffer_from_tensor::(&tensors[0].0).as_kernel_param(), (&out).as_kernel_param(), @@ -266,7 +266,7 @@ extern \"C\" __global__ void kernel(const {type_name} * src0, {type_name} * dst impl Operator for CudaStdNorm { fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let row_size = tensors[0].1.shape().last().unwrap().to_usize().unwrap(); + let row_size = tensors[0].1.dims().last().unwrap().to_usize().unwrap(); let row_size_int = row_size as i32; let out = self .device @@ -280,7 +280,7 @@ impl Operator for CudaStdNorm { ]; let batch_size = tensors[0] .1 - .shape() + .dims() .into_iter() .take(tensors[0].1.len() - 1) .map(|i| i.to_usize().unwrap()) @@ -354,7 +354,7 @@ impl Compiler for StdNormCompiler { } } if sh - .shape() + .dims() .last() .unwrap() .to_usize() @@ -569,13 +569,13 @@ impl Operator for CudaSoftmax { let inp_size = tensors[0].1.n_elements().to_usize().unwrap(); let batch_size = tensors[0] .1 - .shape() + .dims() .iter() .take(tensors[0].1.len() - 1) .map(|i| i.to_usize().unwrap()) .product::() .max(1); - let axis_size = tensors[0].1.shape().last().unwrap().to_usize().unwrap(); + let axis_size = tensors[0].1.dims().last().unwrap().to_usize().unwrap(); let axis_size_int = axis_size as i32; let out = self.device.alloc_zeros::(inp_size).unwrap(); @@ -655,8 +655,8 @@ mod tests { #[test] fn test_norms() { let mut cx = Graph::new(); - let a = cx.tensor().set([0.; 32]); - let mut b = a.layer_norm::, _>(1e-5).retrieve(); + let a = cx.tensor(32).set([0.; 32]); + let mut b = a.layer_norm(0, 1e-5).retrieve(); cx.compile( <( diff --git a/examples/llama/src/main.rs b/examples/llama/src/main.rs index 364f5fe3..f62fa8cc 100644 --- a/examples/llama/src/main.rs +++ b/examples/llama/src/main.rs @@ -79,7 +79,10 @@ fn main() { luminal_metal::BufferCompilers::default(), ), #[cfg(feature = "cuda")] - luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ( + luminal_cuda::CudaCompiler::::default(), + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ), #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] luminal_cpu::CPUCompiler::default(), ), diff --git a/examples/llama_server/src/llama/setup.rs b/examples/llama_server/src/llama/setup.rs index a4d61d30..fd6f99f8 100644 --- a/examples/llama_server/src/llama/setup.rs +++ b/examples/llama_server/src/llama/setup.rs @@ -86,7 +86,10 @@ impl Model { luminal_metal::BufferCompilers::default(), ), #[cfg(feature = "cuda")] - luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ( + luminal_cuda::CudaCompiler::::default(), + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ), #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] luminal_cpu::CPUCompiler::default(), ), diff --git a/examples/phi/src/main.rs b/examples/phi/src/main.rs index 1cff55e5..f7b2bd0c 100644 --- a/examples/phi/src/main.rs +++ b/examples/phi/src/main.rs @@ -67,6 +67,9 @@ fn main() { print!("Compiling graph"); io::stdout().flush().unwrap(); + if debug() { + println!(); + } let now = Instant::now(); cx.compile( ( @@ -78,7 +81,10 @@ fn main() { luminal_metal::BufferCompilers::default(), ), #[cfg(feature = "cuda")] - luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ( + luminal_cuda::CudaCompiler::::default(), + luminal_cuda::CudaQuantizedCompiler::::new(q_weights), + ), #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] luminal_cpu::CPUCompiler::default(), ), diff --git a/src/compiler_utils.rs b/src/compiler_utils.rs index 8388dfcf..7d9e06c7 100644 --- a/src/compiler_utils.rs +++ b/src/compiler_utils.rs @@ -949,3 +949,12 @@ pub fn unary(node: SelectGraph) -> SelectGraph { pub fn binary(a: SelectGraph, b: SelectGraph) -> SelectGraph { b.connect(a.connect(op::())) } + +/// Whether or not to do debug prints (env var DEBUG=1) +pub fn debug() -> bool { + std::env::var("DEBUG") + .unwrap_or_default() + .parse::() + .map(|i| i == 1) + .unwrap_or_default() +} diff --git a/src/graph_tensor.rs b/src/graph_tensor.rs index 34f3fcc2..b02364f5 100644 --- a/src/graph_tensor.rs +++ b/src/graph_tensor.rs @@ -82,8 +82,13 @@ impl GraphTensor { /// Get the contiguous data of the tensor pub fn data(&self) -> Vec { - let tensor = self.graph().get_tensor_ref(self.id, 0).unwrap(); - let orig_data = tensor.downcast_ref::>().unwrap(); + let tensor = self + .graph() + .get_tensor_ref(self.id, 0) + .expect("Tensor not found in the graph!"); + let orig_data = tensor + .downcast_ref::>() + .expect("Data for tensor is not Vec!"); let mut st = self.shape; if !st.is_reshaped() { return orig_data.clone(); diff --git a/src/hl_ops/reduction.rs b/src/hl_ops/reduction.rs index 95c903e1..4738f8f5 100644 --- a/src/hl_ops/reduction.rs +++ b/src/hl_ops/reduction.rs @@ -4,6 +4,7 @@ use crate::{ }; impl GraphTensor { + /// Reduce a dimension of the tensor by summing all elements along that axis. pub fn sum_reduce(self, axes: impl ToAxes) -> GraphTensor { let mut shape = self.shape; @@ -20,6 +21,7 @@ impl GraphTensor { GraphTensor::from_id(new_id, shape, self.graph_ref) } + /// Reduce a dimension of the tensor by taking the maximum of all elements along that axis. pub fn max_reduce(self, axes: impl ToAxes) -> GraphTensor { let mut shape = self.shape; @@ -36,6 +38,7 @@ impl GraphTensor { GraphTensor::from_id(new_id, shape, self.graph_ref) } + /// Reduce a dimension of the tensor by taking the mean of all elements along that axis. pub fn mean_reduce(self, axes: impl ToAxes) -> GraphTensor { let mut shape = self.shape; let mut node_id = self.id; @@ -63,6 +66,11 @@ impl GraphTensor { } GraphTensor::from_id(node_id, shape, self.graph_ref) } + + /// Reduce a dimension of the tensor by multiplying all elements along that axis. + pub fn prod_reduce(self, axes: impl ToAxes) -> GraphTensor { + self.ln().sum_reduce(axes).exp() + } } #[cfg(test)] diff --git a/src/shape/symbolic.rs b/src/shape/symbolic.rs index 0eb7a99f..02f8b118 100644 --- a/src/shape/symbolic.rs +++ b/src/shape/symbolic.rs @@ -951,6 +951,7 @@ mod tests { .unwrap(), 768 ); + expression_cleanup(); } #[test] @@ -958,6 +959,7 @@ mod tests { let expr = ((Expression::from('a') * 1) + 0) / 1 + (1 - 1); let reduced_expr = expr.simplify(); assert_eq!(reduced_expr, 'a'); + expression_cleanup(); } #[test] @@ -966,6 +968,7 @@ mod tests { let sub = Expression::from('x') / 2; let new = main.substitute('x', sub).simplify(); assert_eq!(new, (Expression::from('x') / 2) + -255); + expression_cleanup(); } #[test] @@ -973,5 +976,6 @@ mod tests { let s = Expression::from('s'); let expr = (s * ((s - 4) + 1)) + (((s + 1) * ((s - 4) + 1)) - (s * ((s - 4) + 1))); assert_eq!(expr.simplify().terms.read().len(), 7); + expression_cleanup(); } } diff --git a/src/shape/tracker.rs b/src/shape/tracker.rs index 30893d22..a45f7cfb 100644 --- a/src/shape/tracker.rs +++ b/src/shape/tracker.rs @@ -370,6 +370,7 @@ mod tests { println!("Strides: {:?}", tracker.strides()); println!("Ind: {:?}", tracker.index_expression()); println!("Val: {:?}", tracker.valid_expression()); + expression_cleanup(); } #[test]