Skip to content

Commit

Permalink
Fixed many cuda bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Apr 28, 2024
1 parent 8bf379b commit fa2b7ac
Show file tree
Hide file tree
Showing 9 changed files with 968 additions and 164 deletions.
15 changes: 14 additions & 1 deletion crates/luminal_cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ mod matmul;
mod other;
mod prim;
mod quantized;
mod unary;
pub use quantized::*;

#[cfg(test)]
#[macro_use]
mod tests;

use itertools::Itertools;
Expand All @@ -20,14 +22,25 @@ use std::{collections::hash_map::DefaultHasher, ffi::c_void, fmt::Write, hash::H

use luminal::{op::InputTensor, prelude::*};

/// Compile graphs to run on CUDA devices in supported data formats.
///
/// This alias bundles three compiler passes into a single tuple: the primitive
/// op compiler, the specialized-op passes ([`SpecialOpsCompiler`]), and the
/// redundant-copy cleanup pass (`other::CopyCompiler`).
/// `T` is the element type forwarded to every pass; the operators in this crate
/// bound it by `CudaFloat`.
// NOTE(review): presumably the passes run in tuple order (primitives first,
// copy cleanup last) — confirm against luminal's tuple `Compiler` impl.
pub type CudaCompiler<T> = (
prim::PrimitiveCompiler<T>,
SpecialOpsCompiler<T>,
other::CopyCompiler<T>,
);

/// Compiler to replace CUDA ops with specialized variants.
///
/// A tuple of the individual passes from the `binary`, `other`, `unary`,
/// `matmul` and `prim` modules; `T` is the element type forwarded to each pass.
// NOTE(review): presumably the passes are applied in the order listed here —
// confirm against luminal's tuple `Compiler` impl before reordering entries.
pub type SpecialOpsCompiler<T> = (
binary::SubtractionCompiler<T>,
binary::EqualCompiler<T>,
other::ARangeCompiler<T>,
binary::GatherCompiler<T>,
unary::CudaExpCompiler<T>,
unary::CudaCosCompiler<T>,
unary::MeanReduceCompiler<T>,
unary::StdNormCompiler<T>,
unary::SoftmaxCompiler<T>,
matmul::MatMulCompiler<T>,
prim::CopyCompiler<T>,
);

pub trait CudaFloat:
Expand Down
13 changes: 8 additions & 5 deletions crates/luminal_cuda/src/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ crate::debug_type!(Matmul<T>);
impl<T: CudaFloat> Operator for Matmul<T> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let a_strides = inp[0].1.strides();
let (batch_size, m, k, n) = (
a_shape
.iter()
Expand All @@ -33,6 +32,7 @@ impl<T: CudaFloat> Operator for Matmul<T> {
a_shape[a_shape.len() - 1].to_usize().unwrap() as i32,
b_shape[b_shape.len() - 1].to_usize().unwrap() as i32,
);
println!("{:?}", (batch_size, m, k, n));
let a = get_buffer_from_tensor::<T>(&inp[0].0);
let b = get_buffer_from_tensor::<T>(&inp[1].0);
let mut out = self
Expand All @@ -49,6 +49,9 @@ impl<T: CudaFloat> Operator for Matmul<T> {
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
};

let a_dims = inp[0].1.fake.iter().filter(|f| !**f).count();
let b_dims = inp[1].1.fake.iter().filter(|f| !**f).count();
if T::is_f32() {
unsafe {
luminal_cudarc::cublas::result::sgemm_strided_batched(
Expand All @@ -61,10 +64,10 @@ impl<T: CudaFloat> Operator for Matmul<T> {
&1.0_f32 as *const f32,
*b.device_ptr() as *const f32,
if b_row_major { n } else { k },
0,
if b_dims == 2 { 0 } else { (n * k) as i64 },
*a.device_ptr() as *const f32,
if a_row_major { k } else { m },
a_strides[0].to_usize().unwrap() as i64,
if a_dims == 2 { 0 } else { (m * k) as i64 },
&0.0_f32 as *const f32,
*out.device_ptr_mut() as *mut f32,
n,
Expand All @@ -85,10 +88,10 @@ impl<T: CudaFloat> Operator for Matmul<T> {
&f16::from_f32(1.0) as *const f16,
*b.device_ptr() as *const f16,
if b_row_major { n } else { k },
0,
if b_dims == 2 { 0 } else { (n * k) as i64 },
*a.device_ptr() as *const f16,
if a_row_major { k } else { m },
a_strides[0].to_usize().unwrap() as i64,
if a_dims == 2 { 0 } else { (m * k) as i64 },
&f16::from_f32(0.0) as *const f16,
*out.device_ptr_mut() as *mut f16,
n,
Expand Down
81 changes: 79 additions & 2 deletions crates/luminal_cuda/src/other.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use std::{marker::PhantomData, sync::Arc};

use luminal::prelude::*;
use itertools::Itertools;
use luminal::prelude::{petgraph::visit::EdgeRef, *};
use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig};
use rustc_hash::FxHashMap;

use crate::{
binary::CudaSub,
compile_and_load_kernel, constant,
prim::{CudaAdd, CudaContiguous, CudaSumReduce},
prim::{CudaAdd, CudaContiguous, CudaCopyFromDevice, CudaCopyToDevice, CudaSumReduce},
CudaData, CudaFloat,
};

Expand Down Expand Up @@ -120,3 +121,79 @@ impl<T: CudaFloat> Compiler for ARangeCompiler<T> {
}
}
}

// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up
/// Compiler pass that removes back-to-back device copies from the graph.
///
/// It looks for edges that directly connect a `CudaCopyToDevice<T>` node to a
/// `CudaCopyFromDevice<T>` node (or the reverse), rewires the consumers of the
/// pair onto the node feeding the first copy, and deletes the copy nodes.
/// The `PhantomData<T>` field only carries the element type; the pass holds no state.
#[derive(Debug, Default)]
pub struct CopyCompiler<T>(PhantomData<T>);

impl<T: CudaFloat> Compiler for CopyCompiler<T> {
type Output = ();
/// Runs the cleanup over `graph`, passing `ids` to `remap` for every node
/// that gets replaced.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
// Phase 1: gather every (first, second) edge whose endpoints are a
// CopyTo -> CopyFrom pair or a CopyFrom -> CopyTo pair. The pairs are
// collected into a Vec up front so the graph can be mutated in the loop body.
for (first, second) in graph
.edge_indices()
.filter_map(|e| graph.edge_endpoints(e))
.filter(|(a, b)| {
(graph
.node_weight(*a)
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>()
&& graph
.node_weight(*b)
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>())
|| (graph
.node_weight(*a)
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>()
&& graph
.node_weight(*b)
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>())
})
// Keep only the first pair seen for each distinct `first` node, then
// only the first remaining pair for each distinct `second` node.
.unique_by(|n| n.0)
.unique_by(|n| n.1)
.collect::<Vec<_>>()
{
// Phase 2 guard: leave the pair alone if `first` still feeds any node
// that is not one of the two copy ops, or if `first` is marked no_delete.
// NOTE(review): the pair list was built before any removal; confirm a
// later pair can never name a `first` already removed by an earlier iteration.
if graph
.edges_directed(first, petgraph::Direction::Outgoing)
.filter(|e| graph.contains_node(e.target()))
.filter(|e| {
!graph
.node_weight(e.target())
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>()
&& !graph
.node_weight(e.target())
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>()
})
.count()
> 0
|| graph.no_delete.contains(&first)
{
continue;
}
// Phase 3: splice out the copies.
// NOTE(review): `[0]` assumes `first` has at least one source and panics
// otherwise; `source.0` is presumably the producing node's id — confirm.
let source = graph.get_sources(first)[0];
// Presumably re-points `second`'s outgoing edges so they originate from
// `source.0`, and updates `ids` to refer to `source.0` instead of
// `second` — both helpers are defined outside this file; verify.
move_outgoing_edge(second, source.0, graph);
remap(second, source.0, &mut ids, graph);
graph.remove_node(second);
// Give every destination still attached to `first` the same treatment.
// The ids are copied into a Vec first so the graph can be mutated while looping.
for dest in graph
.get_dests(first)
.iter()
.map(|(i, _)| *i)
.collect::<Vec<_>>()
{
move_outgoing_edge(dest, source.0, graph);
remap(dest, source.0, &mut ids, graph);
graph.remove_node(dest);
}
// Finally drop `first` itself, after its destinations have been removed.
graph.remove_node(first);
}
}
}
88 changes: 7 additions & 81 deletions crates/luminal_cuda/src/prim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,20 @@ crate::debug_type!(CudaSin<T>);
impl<T: CudaFloat> CudaSin<T> {
pub fn new(device: Arc<CudaDevice>) -> Self {
let type_name = T::type_name();
let code = format!(
"
Self {
function: compile_and_load_kernel(
format!(
"
#include \"cuda_fp16.h\"
extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *inp, int numel) {{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {{
out[i] = sin(inp[i]);
}}
}}"
);
Self {
function: compile_and_load_kernel(code, &device),
),
&device,
),
device,
_phantom: Default::default(),
}
Expand Down Expand Up @@ -1108,79 +1110,3 @@ impl<T: CudaFloat> Compiler for PrimitiveCompiler<T> {
}
}
}

// Sometimes CopyTo -> CopyFrom and CopyFrom -> CopyTo patterns remain, so let's clean them up
/// Compiler pass that removes back-to-back device copies from the graph.
///
/// It looks for edges that directly connect a `CudaCopyToDevice<T>` node to a
/// `CudaCopyFromDevice<T>` node (or the reverse), rewires the consumers of the
/// pair onto the node feeding the first copy, and deletes the copy nodes.
/// The `PhantomData<T>` field only carries the element type; the pass holds no state.
#[derive(Debug, Default)]
pub struct CopyCompiler<T>(PhantomData<T>);

impl<T: CudaFloat> Compiler for CopyCompiler<T> {
type Output = ();
/// Runs the cleanup over `graph`, passing `ids` to `remap` for every node
/// that gets replaced.
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
// Phase 1: gather every (first, second) edge whose endpoints are a
// CopyTo -> CopyFrom pair or a CopyFrom -> CopyTo pair. The pairs are
// collected into a Vec up front so the graph can be mutated in the loop body.
for (first, second) in graph
.edge_indices()
.filter_map(|e| graph.edge_endpoints(e))
.filter(|(a, b)| {
(graph
.node_weight(*a)
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>()
&& graph
.node_weight(*b)
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>())
|| (graph
.node_weight(*a)
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>()
&& graph
.node_weight(*b)
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>())
})
// Keep only the first pair seen for each distinct `first` node, then
// only the first remaining pair for each distinct `second` node.
.unique_by(|n| n.0)
.unique_by(|n| n.1)
.collect::<Vec<_>>()
{
// Phase 2 guard: leave the pair alone if `first` still feeds any node
// that is not one of the two copy ops, or if `first` is marked no_delete.
// NOTE(review): the pair list was built before any removal; confirm a
// later pair can never name a `first` already removed by an earlier iteration.
if graph
.edges_directed(first, petgraph::Direction::Outgoing)
.filter(|e| graph.contains_node(e.target()))
.filter(|e| {
!graph
.node_weight(e.target())
.unwrap()
.as_any()
.is::<CudaCopyFromDevice<T>>()
&& !graph
.node_weight(e.target())
.unwrap()
.as_any()
.is::<CudaCopyToDevice<T>>()
})
.count()
> 0
|| graph.no_delete.contains(&first)
{
continue;
}
// Phase 3: splice out the copies.
// NOTE(review): `[0]` assumes `first` has at least one source and panics
// otherwise; `source.0` is presumably the producing node's id — confirm.
let source = graph.get_sources(first)[0];
// Presumably re-points `second`'s outgoing edges so they originate from
// `source.0`, and updates `ids` to refer to `source.0` instead of
// `second` — both helpers are defined outside this file; verify.
move_outgoing_edge(second, source.0, graph);
remap(second, source.0, &mut ids, graph);
graph.remove_node(second);
// Give every destination still attached to `first` the same treatment.
// The ids are copied into a Vec first so the graph can be mutated while looping.
for dest in graph
.get_dests(first)
.iter()
.map(|(i, _)| *i)
.collect::<Vec<_>>()
{
move_outgoing_edge(dest, source.0, graph);
remap(dest, source.0, &mut ids, graph);
graph.remove_node(dest);
}
// Finally drop `first` itself, after its destinations have been removed.
graph.remove_node(first);
}
}
}
35 changes: 2 additions & 33 deletions crates/luminal_cuda/src/tests/fp16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,7 @@ use luminal::{
},
};

use crate::{binary_test, single_binary_test, single_unary_test, unary_test, CudaCompiler};

unary_test!(|a| a.sin(), |a| a.sin(), test_sin, f16);
unary_test!(|a| a.sqrt(), |a| a.sqrt(), test_sqrt, f16);
unary_test!(|a| a.recip(), |a| a.recip(), test_recip, f16);
unary_test!(|a| a * a, |a| a.clone() * a, test_square, f16);
single_unary_test!(|a| a.ln(), |a| a.ln(), test_ln, f16, 3); // For some reason ln fails on larger tensors

binary_test!(|a, b| a + b, |a, b| a + b, test_add, f16);
binary_test!(|a, b| a - b, |a, b| a - b, test_sub, f16);
binary_test!(|a, b| a * b, |a, b| a * b, test_mul, f16);
binary_test!(|a, b| a / b, |a, b| a * b.recip(), test_div, f16);
binary_test!(|a, b| a.max(b), |a, b| a.maximum(b), test_max, f16);
binary_test!(|a, b| a.min(b), |a, b| a.minimum(b), test_min, f16);
use crate::CudaCompiler;

#[test]
fn test_contiguous() {
Expand All @@ -54,24 +41,6 @@ fn test_contiguous() {
assert_close(&b.data(), &d_b.to_dtype::<f32>().as_vec());
}

/// Softmax along axis 1 of a 1x12 tensor, compiled with `CudaCompiler::<f16>`,
/// must match the CPU reference computed on the same input in f16.
#[test]
fn test_softmax() {
    let mut graph = Graph::new();
    let input = random_vec(12);

    // Graph side: build, compile for CUDA in f16, and run.
    let mut result = graph
        .tensor::<R2<1, 12>>()
        .set(input.clone())
        .softmax::<LAxis<1>>()
        .retrieve();
    graph.compile(CudaCompiler::<f16>::default(), &mut result);
    graph.execute();

    // Reference side: same data, same op, evaluated on the CPU device in f16.
    let cpu = Cpu::default();
    let expected = cpu
        .tensor_from_vec(input, (DConst::<1>, DConst::<12>))
        .to_dtype::<f16>()
        .softmax::<DAxis<1>>();

    // Compare in f32 so both sides share an element type.
    assert_close(&result.data(), &expected.to_dtype::<f32>().as_vec());
}

#[test]
fn test_rotate() {
let mut cx = Graph::new();
Expand Down Expand Up @@ -814,7 +783,7 @@ fn test_pad_contig() {
.set_dyn(a_data, &[m, k])
.retrieve();
let mut b: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> = a
.pad(&[(0, 0.into()), (0, Expression::from(16) - 'K')])
.pad(&[(0, 0.into()), (0, Expression::from(24) - 'K')])
.contiguous()
.retrieve();
let mut c: GraphTensor<(Dyn<'M'>, Dyn<'K'>)> =
Expand Down
Loading

0 comments on commit fa2b7ac

Please sign in to comment.