diff --git a/crates/luminal_cuda/src/lib.rs b/crates/luminal_cuda/src/lib.rs index a58b8d23..4b21af8c 100644 --- a/crates/luminal_cuda/src/lib.rs +++ b/crates/luminal_cuda/src/lib.rs @@ -4,6 +4,7 @@ mod matmul; mod other; mod prim; mod quantized; +#[macro_use] mod unary; pub use quantized::*; diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index 457b0908..ee26dc9d 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -123,360 +123,87 @@ impl Operator for CudaConstant { } } -#[derive(Clone)] -pub struct CudaContiguous { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, - dyn_symbols: Vec, - dyn_map: *const FxHashMap, -} -crate::debug_type!(CudaContiguous); - -impl CudaContiguous { - pub fn new( - shape: ShapeTracker, - device: Arc, - dyn_map: *const FxHashMap, - ) -> Self { - let (idx, valid) = get_idx_valid_exps(shape); - let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape]); - let type_name = T::type_name(); - let code = format!( - " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp_a, int numel{rendered}) {{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx < numel && ({valid}) != 0) {{ - out[idx] = inp_a[{idx}]; - }} -}}"); - Self { - function: compile_and_load_kernel(code, &device), - device, - _phantom: Default::default(), - dyn_symbols, - dyn_map, - } - } -} -impl Operator for CudaContiguous { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let res_shape = tensors[0].1.contiguous(); - let inp_size = res_shape.n_elements().to_usize().unwrap(); - let a = get_buffer_from_tensor::(&tensors[0].0); - let out = self.device.alloc_zeros::(inp_size).unwrap(); - let mut params = vec![ - (&out).as_kernel_param(), - a.as_kernel_param(), - inp_size.as_kernel_param(), - ]; - input_dyn_dims(&mut params, &self.dyn_symbols, self.dyn_map); - unsafe { - self.function - .clone() - .launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("input0".to_string())); - } - None - } -} - -#[derive(Clone)] -pub struct CudaLog2 { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaLog2); - -impl CudaLog2 { - pub fn new(device: Arc) -> Self { - let type_name = T::type_name(); - let code = format!( - " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = log2(inp[i]); - }} -}}" - ); - Self { - function: compile_and_load_kernel(code, &device), - device, - _phantom: Default::default(), - } - } -} - -impl Operator for CudaLog2 { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("log2(input0)".to_string())); - } - - None - } -} - -#[derive(Clone)] -pub struct CudaExp2 { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaExp2); - -impl CudaExp2 { - pub fn new(device: Arc) -> Self { - let type_name = T::type_name(); - let code = format!( - " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = exp2(inp[i]); - }} -}}" - ); - Self { - function: compile_and_load_kernel(code, &device), - device, - _phantom: Default::default(), - } - } -} -impl Operator for CudaExp2 { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("exp2(input0)".to_string())); - } - - None - } -} - -#[derive(Clone)] -pub struct CudaSqrt { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaSqrt); - -impl CudaSqrt { - pub fn new(device: Arc) -> Self { - let type_name = T::type_name(); - let code = format!( - " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = {}(inp[i]); - }} -}}", - if T::is_f32() { "sqrt" } else { "hsqrt" } - ); - Self { - function: compile_and_load_kernel(code, &device), - device, - _phantom: Default::default(), - } - } -} -impl Operator for CudaSqrt { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new(format!( - "{}(input0)", - if T::is_f32() { "sqrt" } else { "hsqrt" } - ))); - } - - None - } -} - -#[derive(Clone)] -pub struct CudaSin { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaSin); - -impl CudaSin { - pub fn new(device: Arc) -> Self { - let type_name = T::type_name(); - Self { - function: compile_and_load_kernel( - format!( +#[macro_export] +macro_rules! cuda_unary_op { + ($op: expr, $op_name: ident) => { + #[derive(Clone)] + pub struct $op_name { + function: CudaFunction, + device: Arc, + dyn_symbols: Vec, + dyn_map: *const FxHashMap, + _phantom: PhantomData, + } + + impl $op_name { + pub fn new( + shape: ShapeTracker, + device: Arc, + dyn_map: *const FxHashMap, + ) -> Self { + let (idx_exp, valid_exp) = get_idx_valid_exps(shape); + let (dyn_symbols, rendered) = render_dyn_dim_inputs(&[shape]); + let type_name = T::type_name(); + let code = format!( " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = sin(inp[i]); - }} -}}" - ), - &device, - ), - device, - _phantom: Default::default(), + #include \"cuda_fp16.h\" + extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel{rendered}) {{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < numel && {valid_exp} != 0) {{ + out[idx] = {}(inp[{idx_exp}]); + }} + }}", $op + ); + Self { + function: compile_and_load_kernel(code, &device), + device, + dyn_symbols, + dyn_map, + _phantom: Default::default(), + } + } } - } -} -impl Operator for CudaSin { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } + impl Operator for $op_name { + fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { + let inp = get_buffer_from_tensor::(&tensors[0].0); + let inp_size = tensors[0].1.n_elements().to_usize().unwrap(); + let out = self.device.alloc_zeros::(inp_size).unwrap(); + let mut params = vec![ + (&out).as_kernel_param(), + inp.as_kernel_param(), + inp_size.as_kernel_param(), + ]; + input_dyn_dims(&mut params, &self.dyn_symbols, self.dyn_map); + unsafe { + self.function + .clone() + .launch(LaunchConfig::for_num_elems(inp_size as u32), &mut params) + .unwrap(); + } + + vec![Tensor::new(CudaData(out))] + } - vec![Tensor::new(CudaData(out))] - } + fn custom(&mut self, key: &str, _: Box) -> Option> { + if key == "elementwise" { + return Some(Box::new(format!("{}(input0)", $op))); + } - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("sin(input0)".to_string())); + None + } } - None - } + $crate::debug_type!($op_name); + }; } -#[derive(Clone)] -pub struct CudaRecip { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaRecip); - -impl CudaRecip { - pub fn new(device: Arc) -> Self { - let type_name = T::type_name(); - let code = format!( - " -#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = {}(inp[i]); - }} -}}", - if T::is_f32() { "__frcp_rn" } else { "hrcp" } - ); - Self { - function: compile_and_load_kernel(code, &device), - device, - _phantom: Default::default(), - } - } -} - -impl Operator for CudaRecip { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new(format!( - "{}(input0)", - if T::is_f32() { "__frcp_rn" } else { "hrcp" } - ))); - } - - None - } -} +cuda_unary_op!("", CudaContiguous); +cuda_unary_op!("log2", CudaLog2); +cuda_unary_op!("exp2", CudaExp2); +cuda_unary_op!(if T::is_f32() { "sqrt" } else { "hsqrt" }, CudaSqrt); +cuda_unary_op!("sin", CudaSin); +cuda_unary_op!(if T::is_f32() { "__frcp_rn" } else { "hrcp" }, CudaRecip); #[derive(Clone)] pub struct CudaAdd { @@ -1062,11 +789,11 @@ impl Compiler for PrimitiveCompiler { let op = graph.node_weight(id).unwrap().as_any().type_id(); let op_ref = graph.graph.node_weight_mut(id).unwrap(); if is::(op) { - *op_ref = Box::new(CudaLog2::::new(dev.clone())); + *op_ref = Box::new(CudaLog2::::new(shapes[0], dev.clone(), &graph.dyn_map)); } else if is::(op) { - *op_ref = Box::new(CudaExp2::::new(dev.clone())); + *op_ref = Box::new(CudaExp2::::new(shapes[0], dev.clone(), &graph.dyn_map)); } else if is::(op) { - *op_ref = Box::new(CudaSin::::new(dev.clone())); + *op_ref = Box::new(CudaSin::::new(shapes[0], dev.clone(), &graph.dyn_map)); } else if let Some(c) = op_ref.as_any().downcast_ref::() { *op_ref = Box::new(CudaConstant::::new( dev.clone(), @@ -1074,9 +801,9 @@ impl Compiler for PrimitiveCompiler { &graph.dyn_map, )); } else if is::(op) { - *op_ref = Box::new(CudaRecip::::new(dev.clone())); + *op_ref = Box::new(CudaRecip::::new(shapes[0], dev.clone(), &graph.dyn_map)); } else if is::(op) { - *op_ref = Box::new(CudaSqrt::::new(dev.clone())); + *op_ref = Box::new(CudaSqrt::::new(shapes[0], dev.clone(), &graph.dyn_map)); } else if is::(op) { *op_ref = Box::new(CudaAdd::::new( shapes[0], diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index 6d05fce7..306928e0 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -290,182 +290,182 @@ impl Compiler for CudaQuantizedCompiler { } } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use cudarc::driver::CudaDevice; - use dfdx::{ - tensor::TensorFromVec, - tensor_ops::{PermuteTo, TryMatMul}, - }; - use luminal::{ - prelude::*, - tests::{assert_close, assert_close_precision, random_vec_rng}, - }; - use rand::{thread_rng, Rng}; - - use crate::{CudaData, CudaQuantizedCompiler}; - - #[repr(C, packed)] - struct BlockQ8_0 { - _d: f16, - _qs: [i8; 32], - } - - fn quantized_buffer(weights: &[BlockQ8_0], dev: &Arc) -> Tensor { - let n_bytes = std::mem::size_of_val(weights); - let buffer = dev - .htod_copy(unsafe { - Vec::::from_raw_parts(weights.as_ptr() as *mut u8, n_bytes, n_bytes) - }) - .unwrap(); - Tensor::new(CudaData(buffer)) - } - - #[test] - fn test_quantized_matvec() { - let mut rng = thread_rng(); - let mat_data: Vec = (0..1024 * 512).map(|_| rng.gen_range(0..5)).collect(); - let vec_data = random_vec_rng(1024, &mut rng); - let mut cx = Graph::new(); - let weights = cx.tensor::>().keep(); - let vec = cx.tensor::>().set(vec_data.clone()); - let mut out = vec.matmul(weights.permute()).retrieve(); - - // "Load" weights in 8bit - let blocks = mat_data - .chunks_exact(32) - .map(|chunk| { - let mut array = [0; 32]; - for (i, n) in chunk.iter().enumerate() { - array[i] = *n; - } - BlockQ8_0 { - _d: f16::from_f32(1.0), - _qs: array, - } - }) - .collect::>(); - cx.tensors.insert( - (weights.id, 0), - quantized_buffer(&blocks, &CudaDevice::new(0).unwrap()), - ); - - cx.compile( - CudaQuantizedCompiler::::new(vec![weights.id]), - &mut out, - ); - cx.execute(); - - let mut cx1 = Graph::new(); - let weights = cx1 - .tensor::>() - .set(mat_data.into_iter().map(|i| i as f32).collect::>()) - .keep(); - let vec = cx1.tensor::>().set(vec_data); - let out_32 = vec.matmul(weights.permute()).retrieve(); - cx1.execute(); - - assert_close(&out.data(), &out_32.data()); - blocks.leak(); // Segfaults without this - } - - #[test] - fn test_quantized_matmul() { - let mut rng = thread_rng(); - let mat_data: Vec = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect(); - let inp_mat_data = random_vec_rng(1024 * 16, &mut rng); - let mut cx = Graph::new(); - let weights = cx.tensor::>().keep(); - let inp_mat = cx.tensor::>().set(inp_mat_data.clone()); - let mut out = inp_mat.matmul(weights.permute()).retrieve(); - - // "Load" weights in 8bit - let blocks = mat_data - .chunks_exact(32) - .map(|chunk| { - let mut array = [0; 32]; - for (i, n) in chunk.iter().enumerate() { - array[i] = *n; - } - BlockQ8_0 { - _d: f16::from_f32(1.0), - _qs: array, - } - }) - .collect::>(); - let dev = CudaDevice::new(0).unwrap(); - cx.tensors - .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); - - cx.compile( - CudaQuantizedCompiler::::new(vec![weights.id]), - &mut out, - ); - cx.execute(); - - let cpu = dfdx::tensor::Cpu::default(); - let d_a = cpu.tensor_from_vec( - mat_data.into_iter().map(|i| i as f32).collect::>(), - (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>), - ); - let d_b = cpu.tensor_from_vec( - inp_mat_data, - (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>), - ); - let d_c = d_b.matmul(d_a.permute()); - assert_close(&out.data(), &d_c.as_vec()); - blocks.leak(); // Segfaults without this - } - - #[test] - fn test_quantized_matmul_fp16() { - let mut rng = thread_rng(); - let mat_data: Vec = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect(); - let inp_mat_data = random_vec_rng(1024 * 16, &mut rng); - let mut cx = Graph::new(); - let weights = cx.tensor::>().keep(); - let inp_mat = cx.tensor::>().set(inp_mat_data.clone()); - let mut out = inp_mat.matmul(weights.permute()).retrieve(); - - // "Load" weights in 8bit - let blocks = mat_data - .chunks_exact(32) - .map(|chunk| { - let mut array = [0; 32]; - for (i, n) in chunk.iter().enumerate() { - array[i] = *n; - } - BlockQ8_0 { - _d: f16::from_f32(1.0), - _qs: array, - } - }) - .collect::>(); - let dev = CudaDevice::new(0).unwrap(); - cx.tensors - .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); - - cx.compile( - CudaQuantizedCompiler::::new(vec![weights.id]), - &mut out, - ); - cx.execute(); - - let cpu = dfdx::tensor::Cpu::default(); - let d_a = cpu.tensor_from_vec( - mat_data.into_iter().map(|i| i as f32).collect::>(), - (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>), - ); - let d_b = cpu.tensor_from_vec( - inp_mat_data, - (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>), - ); - let d_c = d_b.matmul(d_a.permute()); - assert_close_precision(&out.data(), &d_c.as_vec(), 1.0); - // This is imprecise currently because we accumulate in fp16 in the matmul. TODO: accumulate in fp32 and convert before saving to dest - - blocks.leak(); // Segfaults without this - } -} +// #[cfg(test)] +// mod tests { +// use std::sync::Arc; + +// use cudarc::driver::CudaDevice; +// use dfdx::{ +// tensor::TensorFromVec, +// tensor_ops::{PermuteTo, TryMatMul}, +// }; +// use luminal::{ +// prelude::*, +// tests::{assert_close, assert_close_precision, random_vec_rng}, +// }; +// use rand::{thread_rng, Rng}; + +// use crate::{CudaData, CudaQuantizedCompiler}; + +// #[repr(C, packed)] +// struct BlockQ8_0 { +// _d: f16, +// _qs: [i8; 32], +// } + +// fn quantized_buffer(weights: &[BlockQ8_0], dev: &Arc) -> Tensor { +// let n_bytes = std::mem::size_of_val(weights); +// let buffer = dev +// .htod_copy(unsafe { +// Vec::::from_raw_parts(weights.as_ptr() as *mut u8, n_bytes, n_bytes) +// }) +// .unwrap(); +// Tensor::new(CudaData(buffer)) +// } + +// #[test] +// fn test_quantized_matvec() { +// let mut rng = thread_rng(); +// let mat_data: Vec = (0..1024 * 512).map(|_| rng.gen_range(0..5)).collect(); +// let vec_data = random_vec_rng(1024, &mut rng); +// let mut cx = Graph::new(); +// let weights = cx.tensor::>().keep(); +// let vec = cx.tensor::>().set(vec_data.clone()); +// let mut out = vec.matmul(weights.permute()).retrieve(); + +// // "Load" weights in 8bit +// let blocks = mat_data +// .chunks_exact(32) +// .map(|chunk| { +// let mut array = [0; 32]; +// for (i, n) in chunk.iter().enumerate() { +// array[i] = *n; +// } +// BlockQ8_0 { +// _d: f16::from_f32(1.0), +// _qs: array, +// } +// }) +// .collect::>(); +// cx.tensors.insert( +// (weights.id, 0), +// quantized_buffer(&blocks, &CudaDevice::new(0).unwrap()), +// ); + +// cx.compile( +// CudaQuantizedCompiler::::new(vec![weights.id]), +// &mut out, +// ); +// cx.execute(); + +// let mut cx1 = Graph::new(); +// let weights = cx1 +// .tensor::>() +// .set(mat_data.into_iter().map(|i| i as f32).collect::>()) +// .keep(); +// let vec = cx1.tensor::>().set(vec_data); +// let out_32 = vec.matmul(weights.permute()).retrieve(); +// cx1.execute(); + +// assert_close(&out.data(), &out_32.data()); +// blocks.leak(); // Segfaults without this +// } + +// #[test] +// fn test_quantized_matmul() { +// let mut rng = thread_rng(); +// let mat_data: Vec = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect(); +// let inp_mat_data = random_vec_rng(1024 * 16, &mut rng); +// let mut cx = Graph::new(); +// let weights = cx.tensor::>().keep(); +// let inp_mat = cx.tensor::>().set(inp_mat_data.clone()); +// let mut out = inp_mat.matmul(weights.permute()).retrieve(); + +// // "Load" weights in 8bit +// let blocks = mat_data +// .chunks_exact(32) +// .map(|chunk| { +// let mut array = [0; 32]; +// for (i, n) in chunk.iter().enumerate() { +// array[i] = *n; +// } +// BlockQ8_0 { +// _d: f16::from_f32(1.0), +// _qs: array, +// } +// }) +// .collect::>(); +// let dev = CudaDevice::new(0).unwrap(); +// cx.tensors +// .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); + +// cx.compile( +// CudaQuantizedCompiler::::new(vec![weights.id]), +// &mut out, +// ); +// cx.execute(); + +// let cpu = dfdx::tensor::Cpu::default(); +// let d_a = cpu.tensor_from_vec( +// mat_data.into_iter().map(|i| i as f32).collect::>(), +// (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>), +// ); +// let d_b = cpu.tensor_from_vec( +// inp_mat_data, +// (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>), +// ); +// let d_c = d_b.matmul(d_a.permute()); +// assert_close(&out.data(), &d_c.as_vec()); +// blocks.leak(); // Segfaults without this +// } + +// #[test] +// fn test_quantized_matmul_fp16() { +// let mut rng = thread_rng(); +// let mat_data: Vec = (0..(1024 * 512)).map(|_| rng.gen_range(0..5)).collect(); +// let inp_mat_data = random_vec_rng(1024 * 16, &mut rng); +// let mut cx = Graph::new(); +// let weights = cx.tensor::>().keep(); +// let inp_mat = cx.tensor::>().set(inp_mat_data.clone()); +// let mut out = inp_mat.matmul(weights.permute()).retrieve(); + +// // "Load" weights in 8bit +// let blocks = mat_data +// .chunks_exact(32) +// .map(|chunk| { +// let mut array = [0; 32]; +// for (i, n) in chunk.iter().enumerate() { +// array[i] = *n; +// } +// BlockQ8_0 { +// _d: f16::from_f32(1.0), +// _qs: array, +// } +// }) +// .collect::>(); +// let dev = CudaDevice::new(0).unwrap(); +// cx.tensors +// .insert((weights.id, 0), quantized_buffer(&blocks, &dev)); + +// cx.compile( +// CudaQuantizedCompiler::::new(vec![weights.id]), +// &mut out, +// ); +// cx.execute(); + +// let cpu = dfdx::tensor::Cpu::default(); +// let d_a = cpu.tensor_from_vec( +// mat_data.into_iter().map(|i| i as f32).collect::>(), +// (dfdx::shapes::Const::<512>, dfdx::shapes::Const::<1024>), +// ); +// let d_b = cpu.tensor_from_vec( +// inp_mat_data, +// (dfdx::shapes::Const::<16>, dfdx::shapes::Const::<1024>), +// ); +// let d_c = d_b.matmul(d_a.permute()); +// assert_close_precision(&out.data(), &d_c.as_vec(), 1.0); +// // This is imprecise currently because we accumulate in fp16 in the matmul. TODO: accumulate in fp32 and convert before saving to dest + +// blocks.leak(); // Segfaults without this +// } +// } diff --git a/crates/luminal_cuda/src/unary.rs b/crates/luminal_cuda/src/unary.rs index bd416f46..0b45f547 100644 --- a/crates/luminal_cuda/src/unary.rs +++ b/crates/luminal_cuda/src/unary.rs @@ -12,7 +12,8 @@ use luminal::{ use crate::{ binary::CudaSub, - compile_and_load_kernel, constant, get_buffer_from_tensor, get_idx_valid_exps, input_dyn_dims, + compile_and_load_kernel, constant, cuda_unary_op, get_buffer_from_tensor, get_idx_valid_exps, + input_dyn_dims, prim::{ CudaAdd, CudaConstant, CudaContiguous, CudaExp2, CudaMaxReduce, CudaMul, CudaRecip, CudaSin, CudaSqrt, CudaSumReduce, @@ -397,62 +398,7 @@ impl Compiler for StdNormCompiler { } } -#[derive(Clone)] -pub struct CudaExp { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaExp); - -impl CudaExp { - fn new(device: Arc) -> Self { - let type_name = T::type_name(); - Self { - function: compile_and_load_kernel( - format!( - "#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = exp(inp[i]); - }} -}}" - ), - &device, - ), - device, - _phantom: Default::default(), - } - } -} - -impl Operator for CudaExp { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("exp(input0)".to_string())); - } - - None - } -} +cuda_unary_op!("exp", CudaExp); #[derive(Default, Debug)] pub struct CudaExpCompiler(PhantomData); @@ -483,7 +429,7 @@ impl Compiler for CudaExpCompiler { .as_data() .unwrap(); let exp = graph - .add_op(CudaExp::::new(dev.clone())) + .add_op(CudaExp::::new(src_shape, dev.clone(), &graph.dyn_map)) .input(s.get(&inp), 0, src_shape) .finish(); @@ -499,62 +445,8 @@ impl Compiler for CudaExpCompiler { } } -/// Special kernel for cos -#[derive(Clone)] -pub struct CudaCos { - function: CudaFunction, - device: Arc, - _phantom: PhantomData, -} -crate::debug_type!(CudaCos); - -impl CudaCos { - fn new(device: Arc) -> Self { - let type_name = T::type_name(); - Self { - function: compile_and_load_kernel( - format!( - "#include \"cuda_fp16.h\" -extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{ - int i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < numel) {{ - out[i] = cos(inp[i]); - }} -}}" - ), - &device, - ), - device, - _phantom: Default::default(), - } - } -} -impl Operator for CudaCos { - fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec { - let inp = get_buffer_from_tensor::(&tensors[0].0); - let inp_size = tensors[0].1.n_physical_elements().to_usize().unwrap(); - let mut out = self.device.alloc_zeros::(inp_size).unwrap(); - unsafe { - self.function - .clone() - .launch( - LaunchConfig::for_num_elems(inp_size as u32), - (&mut out, inp, inp_size), - ) - .unwrap(); - } - - vec![Tensor::new(CudaData(out))] - } - - fn custom(&mut self, key: &str, _: Box) -> Option> { - if key == "elementwise" { - return Some(Box::new("cos(input0)".to_string())); - } - - None - } -} +// Special kernel for cos +cuda_unary_op!("cos", CudaCos); #[derive(Default, Debug)] pub struct CudaCosCompiler(PhantomData); @@ -588,7 +480,7 @@ impl Compiler for CudaCosCompiler { .unwrap() .2; let cos = graph - .add_op(CudaCos::::new(dev.clone())) + .add_op(CudaCos::::new(shape, dev.clone(), &graph.dyn_map)) .input(s.get(&inp), 0, shape) .finish();