Skip to content

Commit

Permalink
Fixed cuda
Browse files (browse the repository at this point in the history)
  • Loading branch information
jafioti committed Jul 25, 2024
1 parent 6590920 commit 50d1bb4
Show file tree
Hide file tree
Showing 20 changed files with 1,049 additions and 1,186 deletions.
8 changes: 5 additions & 3 deletions crates/luminal_cuda/Cargo.toml
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,17 +9,19 @@ license = "MIT OR Apache-2.0"

[dependencies]
luminal = { path = "../.." }
cudarc = { version="0.11.1", features = [
cudarc = { version = "0.11.1", features = [
"f16",
"cuda-version-from-build-system",
]}
] }
itertools = "0.12.1"
rustc-hash = "1.1.0"
num-traits = "0.2.18"
regex = "1.10.4"
indicatif = "0.17.8"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
rand = "0.8.5"
paste = "1.0.14"
luminal_nn = {path="../../crates/luminal_nn"}
luminal_nn = { path = "../../crates/luminal_nn" }
candle-core = "0.5.0"
41 changes: 17 additions & 24 deletions crates/luminal_cuda/src/binary.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -376,7 +376,7 @@ impl<T: CudaFloat> Compiler for GatherCompiler<T> {
.as_data()
.unwrap()
.2;
let embed_dim = emb_shape.shape().last().unwrap().to_usize().unwrap();
let embed_dim = emb_shape.dims().last().unwrap().to_usize().unwrap();
let index_shape = graph
.edges_connecting(s.get(&indexes), s.get(&ind_copy))
.next()
Expand All @@ -402,27 +402,21 @@ mod tests {
use super::*;
luminal::test_imports!();

type TR0 = GraphTensor<R0>;
type TR1<const A: usize> = GraphTensor<R1<A>>;
type TR2<const A: usize, const B: usize> = GraphTensor<R2<A, B>>;

#[test]
fn test_gather_compiler_r0() {
const CLASSES: usize = 2;
const TARGET: usize = 1;

let mut cx = Graph::new();
let mut input: TR0 = cx.tensor();
let embedder: TR2<CLASSES, TARGET> = cx.tensor();
let mut input = cx.tensor(());
let embedder = cx.tensor((CLASSES, TARGET));

let input_one_hot: TR1<CLASSES> = input
let input_one_hot = input
.graph()
.arange::<LConst<CLASSES>>()
.equals(input.expand());
let input_embedding: TR1<TARGET> = (input_one_hot.expand::<R2<CLASSES, TARGET>, _>()
* embedder)
.sum_reduce::<_, LAxis<0>>();
let mut loss: TR0 = input_embedding.sum_reduce();
.arange(CLASSES)
.equals(input.expand(0, CLASSES));
let input_embedding = (input_one_hot.expand(1, TARGET) * embedder).sum_reduce(0);
let mut loss = input_embedding.sum_reduce(0);
let mut weights = vec![embedder.id];

cx.compile(
Expand All @@ -437,18 +431,17 @@ mod tests {
const TARGET: usize = 1;

let mut cx = Graph::new();
let mut input: TR1<1> = cx.tensor();
let embedder: TR2<CLASSES, TARGET> = cx.tensor();
let mut input = cx.tensor(1);
let embedder = cx.tensor((CLASSES, TARGET));

let input_one_hot: TR2<1, CLASSES> = input
let input_one_hot = input
.graph()
.arange::<LConst<CLASSES>>()
.expand::<R2<1, CLASSES>, _>()
.equals(input.expand());
let input_embedding: TR2<1, TARGET> = (input_one_hot.expand::<R3<1, CLASSES, TARGET>, _>()
* embedder.expand())
.sum_reduce::<_, LAxis<1>>();
let mut loss: TR0 = input_embedding.sum_reduce();
.arange(CLASSES)
.expand(0, 1)
.equals(input.expand(1, CLASSES));
let input_embedding =
(input_one_hot.expand(2, TARGET) * embedder.expand(0, 1)).sum_reduce(1);
let mut loss = input_embedding.sum_reduce(0);
let mut weights = vec![embedder.id];

cx.compile(
Expand Down
Loading

0 comments on commit 50d1bb4

Please sign in to comment.