Skip to content

Commit

Permalink
add a passing and failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Jun 22, 2024
1 parent f61d53f commit 9f98da3
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions crates/luminal_cuda/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,64 @@ impl<T: CudaFloat> Compiler for GatherCompiler<T> {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
luminal::test_imports!();

type TR0 = GraphTensor<R0>;
type TR1<const A: usize> = GraphTensor<R1<A>>;
type TR2<const A: usize, const B: usize> = GraphTensor<R2<A, B>>;

#[test]
fn test_gather_compiler_r0() {
const CLASSES: usize = 10;
const TARGET: usize = 3;

let mut cx = Graph::new();
let mut input: TR0 = cx.tensor();
let embedder: TR2<CLASSES, TARGET> = cx.tensor();

let input_one_hot: TR1<CLASSES> = input
.graph()
.arange::<LConst<CLASSES>>()
.equals(input.expand());
let input_embedding: TR1<TARGET> = (input_one_hot.expand::<R2<CLASSES, TARGET>, _>()
* embedder)
.sum_reduce::<_, LAxis<0>>();
let mut loss: TR0 = input_embedding.sum_reduce();
let mut weights = vec![embedder.id];

cx.compile(
crate::CudaCompiler::<f32>::default(),
(&mut input, &mut loss, &mut weights),
);
}

#[test]
fn test_gather_compiler_r1() {
const CLASSES: usize = 10;
const TARGET: usize = 3;

let mut cx = Graph::new();
let mut input: TR1<1> = cx.tensor();
let embedder: TR2<CLASSES, TARGET> = cx.tensor();

let input_one_hot: TR2<1, CLASSES> = input
.graph()
.arange::<LConst<CLASSES>>()
.expand::<R2<1, CLASSES>, _>()
.equals(input.expand());
let input_embedding: TR2<1, TARGET> = (input_one_hot.expand::<R3<1, CLASSES, TARGET>, _>()
* embedder.expand())
.sum_reduce::<_, LAxis<1>>();
let mut loss: TR0 = input_embedding.sum_reduce();
let mut weights = vec![embedder.id];

cx.compile(
crate::CudaCompiler::<f32>::default(),
(&mut input, &mut loss, &mut weights),
);
}
}

0 comments on commit 9f98da3

Please sign in to comment.