diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs
index c79fe9fa..e01e365b 100644
--- a/crates/luminal_cuda/src/binary.rs
+++ b/crates/luminal_cuda/src/binary.rs
@@ -396,3 +396,64 @@ impl Compiler for GatherCompiler {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    luminal::test_imports!();
+
+    type TR0 = GraphTensor<R0>;
+    type TR1<const A: usize> = GraphTensor<R1<A>>;
+    type TR2<const A: usize, const B: usize> = GraphTensor<R2<A, B>>;
+
+    #[test]
+    fn test_gather_compiler_r0() {
+        const CLASSES: usize = 10;
+        const TARGET: usize = 3;
+
+        let mut cx = Graph::new();
+        let mut input: TR0 = cx.tensor();
+        let embedder: TR2<CLASSES, TARGET> = cx.tensor();
+
+        let input_one_hot: TR1<CLASSES> = input
+            .graph()
+            .arange::<LConst<CLASSES>>()
+            .equals(input.expand());
+        let input_embedding: TR1<TARGET> = (input_one_hot.expand::<R2<CLASSES, TARGET>, _>()
+            * embedder)
+            .sum_reduce::<_, LAxis<0>>();
+        let mut loss: TR0 = input_embedding.sum_reduce();
+        let mut weights = vec![embedder.id];
+
+        cx.compile(
+            crate::CudaCompiler::<f16>::default(),
+            (&mut input, &mut loss, &mut weights),
+        );
+    }
+
+    #[test]
+    fn test_gather_compiler_r1() {
+        const CLASSES: usize = 10;
+        const TARGET: usize = 3;
+
+        let mut cx = Graph::new();
+        let mut input: TR1<1> = cx.tensor();
+        let embedder: TR2<CLASSES, TARGET> = cx.tensor();
+
+        let input_one_hot: TR2<1, CLASSES> = input
+            .graph()
+            .arange::<LConst<CLASSES>>()
+            .expand::<R2<1, CLASSES>, _>()
+            .equals(input.expand());
+        let input_embedding: TR2<1, TARGET> = (input_one_hot.expand::<R3<1, CLASSES, TARGET>, _>()
+            * embedder.expand())
+            .sum_reduce::<_, LAxis<1>>();
+        let mut loss: TR0 = input_embedding.sum_reduce();
+        let mut weights = vec![embedder.id];
+
+        cx.compile(
+            crate::CudaCompiler::<f16>::default(),
+            (&mut input, &mut loss, &mut weights),
+        );
+    }
+}