From 1a43b7485ee93f7399570a2b8a095758ad76e416 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 19 Apr 2024 22:59:33 -0500 Subject: [PATCH] Updated cuda --- crates/luminal_cuda/Cargo.toml | 4 +- crates/luminal_cuda/src/binary.rs | 25 +++----- crates/luminal_cuda/src/lib.rs | 37 ++++++----- crates/luminal_cuda/src/matmul.rs | 3 +- crates/luminal_cuda/src/other.rs | 11 ++-- crates/luminal_cuda/src/prim.rs | 88 +++++++-------------------- crates/luminal_cuda/src/quantized.rs | 17 ++---- crates/luminal_cuda/src/tests/fp16.rs | 72 +++++++++++----------- crates/luminal_cuda/src/tests/fp32.rs | 52 ++++++++-------- 9 files changed, 125 insertions(+), 184 deletions(-) diff --git a/crates/luminal_cuda/Cargo.toml b/crates/luminal_cuda/Cargo.toml index da746d33..fe53041e 100644 --- a/crates/luminal_cuda/Cargo.toml +++ b/crates/luminal_cuda/Cargo.toml @@ -16,8 +16,10 @@ luminal_cudarc = { version="0.10.0", features = [ itertools = "0.12.1" rustc-hash = "1.1.0" num-traits = "0.2.18" +fmt-derive = "0.1.1" [dev-dependencies] dfdx = { version = "0.13", features = ["f16"] } rand = "0.8.5" -paste = "1.0.14" \ No newline at end of file +paste = "1.0.14" +luminal_nn = {path="../../crates/luminal_nn"} diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs index 9ab6d056..a47e2aad 100644 --- a/crates/luminal_cuda/src/binary.rs +++ b/crates/luminal_cuda/src/binary.rs @@ -1,5 +1,6 @@ use std::{marker::PhantomData, sync::Arc}; +use fmt_derive::Debug; use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; use luminal::{ @@ -15,7 +16,7 @@ use crate::{ render_dyn_dim_inputs, CudaData, CudaFloat, }; -#[derive(LuminalEqTrue, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaSub { function: CudaFunction, device: Arc, @@ -81,7 +82,7 @@ impl Operator for CudaSub { } } -#[derive(LuminalPrint, Default)] +#[derive(Debug, Default)] pub struct SubtractionCompiler(PhantomData); impl Compiler for SubtractionCompiler { @@ -139,7 +140,7 @@ impl Compiler for SubtractionCompiler { } } -#[derive(LuminalEqTrue, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaEqual { function: CudaFunction, device: Arc, @@ -205,7 +206,7 @@ impl Operator for CudaEqual { } } -#[derive(LuminalPrint, Default)] +#[derive(Debug, Default)] pub struct EqualCompiler(PhantomData); impl Compiler for EqualCompiler { @@ -262,7 +263,7 @@ impl Compiler for EqualCompiler { } } -#[derive(LuminalPrint, Clone, LuminalEqFalse)] +#[derive(Clone, Debug)] pub struct CudaGather { function: CudaFunction, device: Arc, @@ -294,13 +295,7 @@ extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *weights impl Operator for CudaGather { fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec { // Inp 1 should be Vec and inp 2 should be a CudaSlice - let indexes = inputs[0] - .0 - .borrowed() - .data - .as_any() - .downcast_ref::>() - .unwrap(); + let indexes = inputs[0].0.borrowed().downcast_ref::>().unwrap(); let weights = get_buffer_from_tensor::(&inputs[1].0); let mut indexes_buffer = unsafe { self.device.alloc::(indexes.len()).unwrap() }; @@ -335,13 +330,11 @@ impl Operator for CudaGather { .unwrap(); } - vec![Tensor { - data: Box::new(CudaData(out)), - }] + vec![Tensor::new(CudaData(out))] } } -#[derive(LuminalPrint, Default)] +#[derive(Debug, Default)] pub struct GatherCompiler(PhantomData); impl Compiler for GatherCompiler { diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs index ec1a010d..d1b83d8e 100644 --- a/crates/luminal_cuda/src/lib.rs +++ b/crates/luminal_cuda/src/lib.rs @@ -20,8 +20,6 @@ use std::{collections::hash_map::DefaultHasher, ffi::c_void, fmt::Write, hash::H use luminal::{op::InputTensor, prelude::*}; -use self::symbolic::{BigExpression, Term}; - pub type CudaCompiler = ( prim::PrimitiveCompiler, binary::SubtractionCompiler, @@ -80,16 +78,6 @@ impl Data for CudaData { } } -impl Data for CudaData { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - impl CudaFloat for f16 { fn from_f32(a: f32) -> Self { f16::from_f32(a) @@ -105,6 +93,21 @@ impl CudaFloat for f16 { } } +impl CudaFloat for u8 { + fn from_f32(a: f32) -> Self { + a as u8 + } + fn to_f32(self) -> f32 { + self as f32 + } + fn is_f32() -> bool { + false + } + fn type_name() -> &'static str { + "uint8_t" + } +} + fn expr_to_cuda_string(expr: BigExpression) -> String { let mut symbols = vec![]; for term in expr.terms { @@ -195,14 +198,8 @@ fn hash(obj: T) -> u64 { hasher.finish() } -fn get_buffer_from_tensor<'a, T: 'static>(tensor: &'a InputTensor) -> &'a CudaSlice { - &tensor - .borrowed() - .data - .as_any() - .downcast_ref::>() - .unwrap() - .0 +fn get_buffer_from_tensor<'a, T: CudaFloat>(tensor: &'a InputTensor) -> &'a CudaSlice { + &tensor.borrowed().downcast_ref::>().unwrap().0 } fn input_dyn_dims( diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs index 4e93e160..c550e792 100644 --- a/crates/luminal_cuda/src/matmul.rs +++ b/crates/luminal_cuda/src/matmul.rs @@ -1,5 +1,6 @@ use std::{marker::PhantomData, sync::Arc}; +use fmt_derive::Debug; use luminal_cudarc::{ cublas::{sys::cublasOperation_t::*, CudaBlas}, driver::{CudaDevice, DevicePtr, DevicePtrMut}, @@ -15,7 +16,7 @@ use luminal::{ prelude::*, }; -#[derive(LuminalPrint, LuminalEqFalse, Clone)] +#[derive(Clone, Debug)] pub struct Matmul(Arc, Arc, PhantomData); impl Operator for Matmul { diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index f656cfea..bb71feb1 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -2,7 +2,8 @@ use std::{marker::PhantomData, sync::Arc}; use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig}; -use luminal::{op::*, prelude::*, shape::symbolic::BigExpression}; +use fmt_derive::Debug; +use luminal::prelude::*; use rustc_hash::FxHashMap; use crate::{ @@ -12,7 +13,7 @@ use crate::{ CudaData, CudaFloat, }; -#[derive(LuminalPrint, Clone, LuminalEqFalse)] +#[derive(Clone, Debug)] pub struct CudaARange { function: CudaFunction, device: Arc, @@ -65,13 +66,11 @@ impl Operator for CudaARange { .unwrap(); } - vec![Tensor { - data: Box::new(CudaData(out)), - }] + vec![Tensor::new(CudaData(out))] } } -#[derive(LuminalPrint, Default)] +#[derive(Debug, Default)] pub struct ARangeCompiler(PhantomData); impl Compiler for ARangeCompiler { diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index 3fe6b820..7a0426d3 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -1,12 +1,12 @@ use crate::{compile_and_load_kernel, get_buffer_from_tensor, input_dyn_dims, CudaData, CudaFloat}; use super::{get_idx_valid_exps, render_dyn_dim_inputs}; +use fmt_derive::Debug; use itertools::Itertools; use rustc_hash::FxHashMap; use std::{ any::{Any, TypeId}, - fmt::Debug, marker::PhantomData, sync::Arc, }; @@ -19,7 +19,7 @@ use luminal::{ }; /// Copy a tensor to the GPU -#[derive(Clone, LuminalEqFalse, LuminalPrint)] +#[derive(Clone, Debug)] pub struct CudaCopyToDevice(Arc, PhantomData); impl CudaCopyToDevice { @@ -30,19 +30,11 @@ impl CudaCopyToDevice { impl Operator for CudaCopyToDevice { fn process(&mut self, mut inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { - if inp[0].0.borrowed().data.as_any().is::>() - || inp[0].0.borrowed().data.as_any().is::>() - { + if inp[0].0.borrowed().is::>() || inp[0].0.borrowed().is::>() { // Already on device return vec![inp.pop().unwrap().0.cloned()]; } - let cpu_data = inp[0] - .0 - .borrowed() - .data - .as_any() - .downcast_ref::>() - .unwrap(); + let cpu_data = inp[0].0.borrowed().downcast_ref::>().unwrap(); let vec = cpu_data .iter() .copied() @@ -53,7 +45,7 @@ impl Operator for CudaCopyToDevice { } /// Copy a tensor from the GPU -#[derive(Clone, LuminalEqFalse, LuminalPrint)] +#[derive(Clone, Debug)] pub struct CudaCopyFromDevice(Arc, PhantomData); impl CudaCopyFromDevice { @@ -64,7 +56,7 @@ impl CudaCopyFromDevice { impl Operator for CudaCopyFromDevice { fn process(&mut self, mut inp: Vec<(InputTensor, ShapeTracker)>) -> Vec { - if inp[0].0.borrowed().data.as_any().is::>() { + if inp[0].0.borrowed().is::>() { // Already off device return vec![inp.pop().unwrap().0.cloned()]; } @@ -79,14 +71,14 @@ impl Operator for CudaCopyFromDevice { } /// Constant value on device -#[derive(Clone, LuminalEqFalse)] +#[derive(Clone)] pub struct CudaConstant { pub value: ConstantValue, device: Arc, dyn_map: *const FxHashMap, _phantom: PhantomData, } -impl Debug for CudaConstant { +impl core::fmt::Debug for CudaConstant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "CudaConstant({:?})", self.value) } @@ -121,7 +113,7 @@ impl Operator for CudaConstant { } } -#[derive(LuminalPrint, Clone, LuminalEqFalse)] +#[derive(Clone, Debug)] pub struct CudaContiguous { function: CudaFunction, device: Arc, @@ -180,7 +172,7 @@ impl Operator for CudaContiguous { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaLog2 { function: CudaFunction, device: Arc, @@ -235,7 +227,7 @@ impl Operator for CudaLog2 { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaExp2 { function: CudaFunction, device: Arc, @@ -289,7 +281,7 @@ impl Operator for CudaExp2 { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaSqrt { function: CudaFunction, device: Arc, @@ -347,7 +339,7 @@ impl Operator for CudaSqrt { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaSin { function: CudaFunction, device: Arc, @@ -402,7 +394,7 @@ impl Operator for CudaSin { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaRecip { function: CudaFunction, device: Arc, @@ -461,7 +453,7 @@ impl Operator for CudaRecip { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaAdd { function: CudaFunction, device: Arc, @@ -534,7 +526,7 @@ impl Operator for CudaAdd { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaMul { function: CudaFunction, device: Arc, @@ -604,7 +596,7 @@ impl Operator for CudaMul { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaMod { function: CudaFunction, device: Arc, @@ -674,7 +666,7 @@ impl Operator for CudaMod { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaLessThan { function: CudaFunction, device: Arc, @@ -750,7 +742,7 @@ impl Operator for CudaLessThan { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaSumReduce { function: CudaFunction, pub device: Arc, @@ -843,7 +835,7 @@ where } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct CudaMaxReduce { function: CudaFunction, pub device: Arc, @@ -933,7 +925,7 @@ impl Operator for CudaMaxReduce { } /// Convert all primitive ops to cuda primitive ops, and insert copy to and from device ops -#[derive(LuminalPrint, Default)] +#[derive(Debug, Default)] pub struct PrimitiveCompiler(PhantomData); impl Compiler for PrimitiveCompiler { @@ -1021,44 +1013,6 @@ impl Compiler for PrimitiveCompiler { } } - // Copy prints from device - for (output_node, edge) in graph - .node_indices() - // Filter non-functions - .filter(|n| graph.node_weight(*n).unwrap().as_any().is::()) - .map(|n| { - ( - n, - graph - .edges_directed(n, petgraph::Direction::Incoming) - .find(|e| !e.weight().is_schedule()) - .unwrap() - .id(), - ) - }) - .collect::>() - { - // Create copy node - let (source, shape) = ( - graph.edge_endpoints(edge).unwrap().0, - graph.edge_weight(edge).unwrap().as_data().unwrap().2, - ); - let copy_node = graph - .add_op(CudaCopyFromDevice::::new(dev.clone())) - .input(source, 0, shape) - .finish(); - graph.add_edge( - copy_node, - output_node, - Dependency::Data { - shape, - input_order: 0, - output_order: 0, - }, - ); - graph.remove_edge(edge); - } - fn is(type_id: TypeId) -> bool { type_id == TypeId::of::() } diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index 714b9ab1..08d30982 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -1,5 +1,6 @@ use std::{marker::PhantomData, sync::Arc}; +use fmt_derive::Debug; use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig}; use petgraph::visit::EdgeRef; @@ -13,7 +14,7 @@ use crate::{ }; /// Multiplies a BxMxK matrix with a KxN matrix, resulting in a BxMxN matrix. This expects the first input to be a quantized 2D matrix -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct QuantizedMatmul { matvec_function: CudaFunction, device: Arc, @@ -169,7 +170,7 @@ impl Operator for QuantizedMatmul { } } -#[derive(LuminalEqFalse, LuminalPrint, Clone)] +#[derive(Clone, Debug)] pub struct QuantizedGather { pipeline: CudaFunction, device: Arc, @@ -203,13 +204,7 @@ extern \"C\" __global__ void kernel(const float* inp, const block_q8_0* weights, impl Operator for QuantizedGather { fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { // Setup buffers - let indexes = tensors[0] - .0 - .borrowed() - .data - .as_any() - .downcast_ref::>() - .unwrap(); + let indexes = tensors[0].0.borrowed().downcast_ref::>().unwrap(); let mut index_buffer = unsafe { self.device.alloc::(indexes.len()).unwrap() }; self.device .htod_copy_into(indexes.clone(), &mut index_buffer) @@ -249,7 +244,7 @@ impl Operator for QuantizedGather { } } -#[derive(Default)] +#[derive(Default, Debug)] pub struct CudaQuantizedCompiler(Vec, PhantomData); impl CudaQuantizedCompiler { @@ -467,7 +462,7 @@ mod tests { (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>), ); let d_c = d_b.matmul(d_a.permute()); - assert_close_precision(&out.data(), &d_c.as_vec(), 0); + assert_close_precision(&out.data(), &d_c.as_vec(), 1.0); // This is imprecise currently because we accumulate in fp16 in the matmul. TODO: accumulate in fp32 and convert before saving to dest blocks.leak(); // Segfaults without this diff --git a/crates/luminal_cuda/src/tests/fp16.rs b/crates/luminal_cuda/src/tests/fp16.rs index 832f15b4..50b24c8c 100644 --- a/crates/luminal_cuda/src/tests/fp16.rs +++ b/crates/luminal_cuda/src/tests/fp16.rs @@ -3,10 +3,7 @@ use itertools::Itertools; use num_traits::Float; use rand::{rngs::StdRng, SeedableRng}; -use luminal::{ - nn::{activation::ReLU, linear::Linear, norm::RMSNorm}, - prelude::{symbolic::Expression, Module, *}, -}; +use luminal::{module::Module, prelude::*}; #[allow(unused_imports)] use dfdx::prelude::{ @@ -63,7 +60,7 @@ fn test_softmax() { let mut cx = Graph::new(); let data = random_vec(12); let a = cx.tensor::>().set(data.clone()); - let mut b = a.softmax::<1>().retrieve(); + let mut b = a.softmax::>().retrieve(); cx.compile(CudaCompiler::::default(), &mut b); cx.execute(); @@ -200,17 +197,17 @@ fn test_sum_reduce() { assert_close_precision( &b.data(), &d_b.to_dtype::().to_dtype::().as_vec(), - 1, + 0.1, ); assert_close_precision( &c.data(), &d_c.to_dtype::().to_dtype::().as_vec(), - 1, + 0.1, ); assert_close_precision( &d.data(), &d_d.to_dtype::().to_dtype::().as_vec(), - 1, + 0.1, ); } @@ -240,7 +237,7 @@ fn test_sum_reduce2() { assert_close_precision( &d.data(), &d_d.to_dtype::().to_dtype::().as_vec(), - 1, + 0.1, ); } @@ -309,7 +306,7 @@ fn test_matmul_simple() { let d_b = d_dev.tensor_from_vec(b_data, (DConst::<256>, DConst::<256>)); let d_c = d_a.to_dtype::().matmul(d_b.to_dtype::()); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1); // Why is this imprecise? + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1.); // Why is this imprecise? } #[test] @@ -335,7 +332,7 @@ fn test_matmul() { let d_b = d_dev.tensor_from_vec(b_data, (k, n)); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); c.drop(); } } @@ -375,7 +372,7 @@ fn test_attn_matmul() { ) .to_dtype::(); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); } #[test] @@ -403,7 +400,7 @@ fn test_batch_matmul() { let d_b = d_dev.tensor_from_vec(b_data, (k, n)); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 1); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.1); c.drop(); } } @@ -455,10 +452,10 @@ fn test_batch_matmul_transpose() { .matmul(d_b.permute::<_, DAxes2<1, 0>>()); let d_a_t_b_t = d_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(d_b_t); - assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 1); + assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 0.1); + assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 0.1); + assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 0.1); + assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 0.1); } #[test] @@ -513,10 +510,10 @@ fn test_matmul_transpose() { .matmul(d_b.permute()); let d_a_t_b_t = d_a_t.permute::<_, DAxes2<1, 0>>().matmul(d_b_t); - assert_close_precision(&a_b.data(), &d_a_b.to_dtype::().as_vec(), 1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.to_dtype::().as_vec(), 1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.to_dtype::().as_vec(), 1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.to_dtype::().as_vec(), 1); + assert_close_precision(&a_b.data(), &d_a_b.to_dtype::().as_vec(), 0.1); + assert_close_precision(&a_b_t.data(), &d_a_b_t.to_dtype::().as_vec(), 0.1); + assert_close_precision(&a_t_b.data(), &d_a_t_b.to_dtype::().as_vec(), 0.1); + assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.to_dtype::().as_vec(), 0.1); } #[test] @@ -531,7 +528,11 @@ fn test_relu_and_linear() { .set(random_vec(32 * 2)); let a = cx.named_tensor::>("Single").set(input_data.clone()); - let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx); + let model: ( + luminal_nn::Linear<32, 64>, + luminal_nn::ReLU, + luminal_nn::Linear<64, 32>, + ) = InitModule::initialize(&mut cx); model.0.weight.set(w1.clone()); model.2.weight.set(w2.clone()); let mut b = model.forward(a).retrieve(); @@ -548,8 +549,8 @@ fn test_relu_and_linear() { ); cx.execute(); - assert_close_precision(&unoptimized_b, &b.data(), 2); - assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 2); + assert_close_precision(&unoptimized_b, &b.data(), 0.01); + assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 0.01); // Test against dfdx let dev = Cpu::default(); @@ -572,7 +573,7 @@ fn test_relu_and_linear() { .to_dtype::(); let out = model.forward(a); - assert_close_precision(&unoptimized_b, &out.to_dtype::().as_vec(), 2); + assert_close_precision(&unoptimized_b, &out.to_dtype::().as_vec(), 0.01); } #[test] @@ -584,7 +585,7 @@ fn test_rms_norm() { let mut cx = Graph::new(); let a = cx.tensor::>().set(inp_data.clone()); - let model = RMSNorm::<32>::initialize(&mut cx); + let model = luminal_nn::RMSNorm::<32>::initialize(&mut cx); model.weight.set(weight_data.clone()); let mut b = model.forward(a).retrieve(); @@ -612,8 +613,8 @@ fn test_layer_norm() { let mut cx = Graph::new(); let a_data = random_vec(15 * 16 * 32); let a = cx.tensor::>().set(a_data.clone()); - let mut b = a.layer_norm::<0, _>(1e-5).retrieve(); - let mut c = a.layer_norm::<2, _>(1e-5).retrieve(); + let mut b = a.layer_norm::, _>(1e-5).retrieve(); + let mut c = a.layer_norm::, _>(1e-5).retrieve(); cx.compile( <(GenericCompiler, CudaCompiler)>::default(), (&mut b, &mut c), @@ -625,15 +626,14 @@ fn test_layer_norm() { let d_b = d_a.clone().normalize::>(1e-5); let d_c = d_a.normalize::>(1e-5); - assert_close_precision(&b.data(), &d_b.as_vec(), 2); - assert_close_precision(&c.data(), &d_c.as_vec(), 2); + assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); + assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); } #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); - let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<32, 64, 1> = - InitModule::initialize(&mut cx); + let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx); let w_k_weight = random_vec(32 * 32); model.attention.w_k.weight.set(w_k_weight.clone()); let w_q_weight = random_vec(32 * 32); @@ -652,7 +652,7 @@ fn test_transformer_encoder_block() { .tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>() .set_dyn(a_data.clone(), &[1, 2, 3]) .keep(); - cx.keep_tensors(state_dict(&model)); + cx.keep_tensors(params(&model)); let mut b = model.forward(a).retrieve(); cx.execute(); let unopt_b = b.data(); @@ -660,7 +660,7 @@ fn test_transformer_encoder_block() { cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut b); cx.execute(); - assert_close_precision(&unopt_b, &b.data(), 2); + assert_close_precision(&unopt_b, &b.data(), 0.01); let d_dev = Cpu::default(); let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> = @@ -699,7 +699,7 @@ fn test_transformer_encoder_block() { let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>)); let d_b = d_model.forward(d_a); - assert_close_precision(&b.data(), &d_b.as_vec(), 2); + assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); } #[test] @@ -730,7 +730,7 @@ fn test_embedding() { .set(vec![1.0, 0.0, 1.0]) .keep(); - let model: luminal::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx); + let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx); model .weight .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]); diff --git a/crates/luminal_cuda/src/tests/fp32.rs b/crates/luminal_cuda/src/tests/fp32.rs index ba521592..78aab5b8 100644 --- a/crates/luminal_cuda/src/tests/fp32.rs +++ b/crates/luminal_cuda/src/tests/fp32.rs @@ -2,10 +2,7 @@ use dfdx::prelude::{Module as DfdxModule, *}; use itertools::Itertools; use rand::{rngs::StdRng, SeedableRng}; -use luminal::{ - nn::{activation::ReLU, linear::Linear, norm::RMSNorm}, - prelude::{symbolic::Expression, Module, *}, -}; +use luminal::{module::Module, prelude::*}; #[allow(unused_imports)] use dfdx::prelude::{ @@ -60,7 +57,7 @@ fn test_softmax() { let mut cx = Graph::new(); let data = random_vec(12); let a = cx.tensor::>().set(data.clone()); - let mut b = a.softmax::<1>().retrieve(); + let mut b = a.softmax::>().retrieve(); cx.compile(CudaCompiler::::default(), &mut b); cx.execute(); @@ -343,7 +340,7 @@ fn test_attn_matmul() { (DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>), ); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.as_vec(), 2); + assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); } #[test] @@ -370,7 +367,7 @@ fn test_batch_matmul() { let d_b = d_dev.tensor_from_vec(b_data, (k, n)); let d_c = d_a.matmul(d_b); - assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 2); + assert_close_precision(&c.data(), &d_c.to_dtype::().as_vec(), 0.01); c.drop(); } } @@ -422,10 +419,10 @@ fn test_batch_matmul_transpose() { .matmul(d_b.permute::<_, DAxes2<1, 0>>()); let d_a_t_b_t = d_a_t.permute::<_, DAxes3<0, 2, 1>>().matmul(d_b_t); - assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 1); - assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 1); - assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 1); - assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 1); + assert_close_precision(&a_b.data(), &d_a_b.as_vec(), 0.1); + assert_close_precision(&a_b_t.data(), &d_a_b_t.as_vec(), 0.1); + assert_close_precision(&a_t_b.data(), &d_a_t_b.as_vec(), 0.1); + assert_close_precision(&a_t_b_t.data(), &d_a_t_b_t.as_vec(), 0.1); } #[test] @@ -490,7 +487,11 @@ fn test_relu_and_linear() { .set(random_vec(32 * 2)); let a = cx.named_tensor::>("Single").set(input_data.clone()); - let model: (Linear<32, 64>, ReLU, Linear<64, 32>) = InitModule::initialize(&mut cx); + let model: ( + luminal_nn::Linear<32, 64>, + luminal_nn::ReLU, + luminal_nn::Linear<64, 32>, + ) = InitModule::initialize(&mut cx); model.0.weight.set(w1.clone()); model.2.weight.set(w2.clone()); let mut b = model.forward(a).retrieve(); @@ -507,8 +508,8 @@ fn test_relu_and_linear() { ); cx.execute(); - assert_close_precision(&unoptimized_b, &b.data(), 2); - assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 2); + assert_close_precision(&unoptimized_b, &b.data(), 0.01); + assert_close_precision(&unoptimized_batch_out, &batch_out.data(), 0.01); // Test against dfdx let dev = Cpu::default(); @@ -527,7 +528,7 @@ fn test_relu_and_linear() { let a = dev.tensor_from_vec(input_data, (dfdx::shapes::Const::<32>,)); let out = model.forward(a); - assert_close_precision(&unoptimized_b, &out.as_vec(), 2); + assert_close_precision(&unoptimized_b, &out.as_vec(), 0.01); } #[test] @@ -538,7 +539,7 @@ fn test_rms_norm() { let mut cx = Graph::new(); let a = cx.tensor::>().set(inp_data.clone()); - let model = RMSNorm::<32>::initialize(&mut cx); + let model = luminal_nn::RMSNorm::<32>::initialize(&mut cx); model.weight.set(weight_data.clone()); let mut b = model.forward(a).retrieve(); @@ -562,8 +563,8 @@ fn test_layer_norm() { let mut cx = Graph::new(); let a_data = random_vec(15 * 16 * 32); let a = cx.tensor::>().set(a_data.clone()); - let mut b = a.layer_norm::<0, _>(1e-5).retrieve(); - let mut c = a.layer_norm::<2, _>(1e-5).retrieve(); + let mut b = a.layer_norm::, _>(1e-5).retrieve(); + let mut c = a.layer_norm::, _>(1e-5).retrieve(); cx.compile( <(GenericCompiler, CudaCompiler)>::default(), (&mut b, &mut c), @@ -575,15 +576,14 @@ fn test_layer_norm() { let d_b = d_a.clone().normalize::>(1e-5); let d_c = d_a.normalize::>(1e-5); - assert_close_precision(&b.data(), &d_b.as_vec(), 2); - assert_close_precision(&c.data(), &d_c.as_vec(), 2); + assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); + assert_close_precision(&c.data(), &d_c.as_vec(), 0.01); } #[test] fn test_transformer_encoder_block() { let mut cx = Graph::new(); - let model: luminal::nn::transformer::encoder::TransformerEncoderBlock<32, 64, 1> = - InitModule::initialize(&mut cx); + let model: luminal_nn::TransformerEncoderBlock<32, 64, 1> = InitModule::initialize(&mut cx); let w_k_weight = random_vec(32 * 32); model.attention.w_k.weight.set(w_k_weight.clone()); let w_q_weight = random_vec(32 * 32); @@ -602,7 +602,7 @@ fn test_transformer_encoder_block() { .tensor::<(Dyn<'b'>, Dyn<'a'>, LConst<32>)>() .set_dyn(a_data.clone(), &[1, 2, 3]) .keep(); - cx.keep_tensors(state_dict(&model)); + cx.keep_tensors(params(&model)); let mut b = model.forward(a).retrieve(); cx.execute(); let unopt_b = b.data(); @@ -610,7 +610,7 @@ fn test_transformer_encoder_block() { cx.compile(<(GenericCompiler, CudaCompiler)>::default(), &mut b); cx.execute(); - assert_close_precision(&unopt_b, &b.data(), 2); + assert_close_precision(&unopt_b, &b.data(), 0.01); let d_dev = Cpu::default(); let mut d_model: dfdx::nn::modules::TransformerEncoderBlock<32, 1, 64, f32, Cpu> = @@ -649,7 +649,7 @@ fn test_transformer_encoder_block() { let d_a = d_dev.tensor_from_vec(a_data, (DConst::<2>, DConst::<32>)); let d_b = d_model.forward(d_a); - assert_close_precision(&b.data(), &d_b.as_vec(), 2); + assert_close_precision(&b.data(), &d_b.as_vec(), 0.01); } #[test] @@ -664,7 +664,7 @@ fn test_embedding() { .set(vec![1.0, 0.0, 1.0]) .keep(); - let model: luminal::nn::embedding::Embedding<3, 4> = InitModule::initialize(&mut cx); + let model: luminal_nn::Embedding<3, 4> = InitModule::initialize(&mut cx); model .weight .set(vec![1.1, 2., 3., 1., 2., 3., 14., 2., 33., 1., 2., 3.]);