From 00be879d6ac219b2fe32cb0188dafbf12111328f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Sat, 23 Mar 2024 17:50:58 -0700 Subject: [PATCH 1/2] Fix CUDA compilation by defining Output type for Compiler trait --- crates/luminal_cuda/src/binary.rs | 3 +++ crates/luminal_cuda/src/matmul.rs | 1 + crates/luminal_cuda/src/other.rs | 1 + crates/luminal_cuda/src/prim.rs | 2 ++ crates/luminal_cuda/src/quantized.rs | 1 + 5 files changed, 8 insertions(+) diff --git a/crates/luminal_cuda/src/binary.rs b/crates/luminal_cuda/src/binary.rs index e029a010..9ab6d056 100644 --- a/crates/luminal_cuda/src/binary.rs +++ b/crates/luminal_cuda/src/binary.rs @@ -85,6 +85,7 @@ impl Operator for CudaSub { pub struct SubtractionCompiler(PhantomData); impl Compiler for SubtractionCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, _: To) { let dev = CudaDevice::new(0).unwrap(); let (lhs, rhs) = (node(), node()); @@ -208,6 +209,7 @@ impl Operator for CudaEqual { pub struct EqualCompiler(PhantomData); impl Compiler for EqualCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, _: To) { let dev = CudaDevice::new(0).unwrap(); let one = constant::(1.); @@ -343,6 +345,7 @@ impl Operator for CudaGather { pub struct GatherCompiler(PhantomData); impl Compiler for GatherCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, _: To) { let dev = CudaDevice::new(0).unwrap(); let indexes = node(); diff --git a/crates/luminal_cuda/src/matmul.rs b/crates/luminal_cuda/src/matmul.rs index d0af07ba..4e93e160 100644 --- a/crates/luminal_cuda/src/matmul.rs +++ b/crates/luminal_cuda/src/matmul.rs @@ -109,6 +109,7 @@ impl Compiler for MatMulCompiler where CudaData: Data, { + type Output = (); fn compile(&self, graph: &mut Graph, mut ids: To) { let dev = CudaDevice::new(0).unwrap(); // Look for the matmul pattern diff --git a/crates/luminal_cuda/src/other.rs b/crates/luminal_cuda/src/other.rs index c616a7b3..f656cfea 100644 --- a/crates/luminal_cuda/src/other.rs +++ b/crates/luminal_cuda/src/other.rs @@ -75,6 +75,7 @@ impl Operator for CudaARange { pub struct ARangeCompiler(PhantomData); impl Compiler for ARangeCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, _: To) { let dev = CudaDevice::new(0).unwrap(); // TODO: Make sure this actually checks the shape transformations to ensure pooling happens diff --git a/crates/luminal_cuda/src/prim.rs b/crates/luminal_cuda/src/prim.rs index dc0df081..3fe6b820 100644 --- a/crates/luminal_cuda/src/prim.rs +++ b/crates/luminal_cuda/src/prim.rs @@ -937,6 +937,7 @@ impl Operator for CudaMaxReduce { pub struct PrimitiveCompiler(PhantomData); impl Compiler for PrimitiveCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, mut ids: To) { let dev = CudaDevice::new(0).unwrap(); // Go through the graph and insert copy ops @@ -1146,6 +1147,7 @@ impl Compiler for PrimitiveCompiler { pub struct CopyCompiler(PhantomData); impl Compiler for CopyCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, mut ids: To) { for (first, second) in graph .edge_indices() diff --git a/crates/luminal_cuda/src/quantized.rs b/crates/luminal_cuda/src/quantized.rs index c2a340be..714b9ab1 100644 --- a/crates/luminal_cuda/src/quantized.rs +++ b/crates/luminal_cuda/src/quantized.rs @@ -259,6 +259,7 @@ impl CudaQuantizedCompiler { } impl Compiler for CudaQuantizedCompiler { + type Output = (); fn compile(&self, graph: &mut Graph, mut remap: To) { let device = CudaDevice::new(0).unwrap(); let mut weight_ids = self.0.clone(); From 6ed065d4b77c7e7143d13cdf4a6d8b465682b6bc Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Sat, 23 Mar 2024 18:05:32 -0700 Subject: [PATCH 2/2] Remove nn submodule imports and rely on prelude re-exporting symbols --- examples/llama/src/model.rs | 1 - examples/mistral/src/model.rs | 1 - examples/simple/src/main.rs | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/llama/src/model.rs b/examples/llama/src/model.rs index 92684118..5cd82668 100644 --- a/examples/llama/src/model.rs +++ b/examples/llama/src/model.rs @@ -11,7 +11,6 @@ pub const HEADS: usize = 32; pub const LAYERS: usize = 32; use luminal::{ - nn::{embedding::Embedding, norm::RMSNorm}, prelude::*, shape::symbolic::{BigExpression, Expression}, }; diff --git a/examples/mistral/src/model.rs b/examples/mistral/src/model.rs index 870f5461..5b0bc7b4 100644 --- a/examples/mistral/src/model.rs +++ b/examples/mistral/src/model.rs @@ -1,7 +1,6 @@ use std::{marker::PhantomData, ops::Div}; use luminal::{ - nn::{embedding::Embedding, norm::RMSNorm}, prelude::{binary::F32Pow, *}, shape::symbolic::{BigExpression, Expression}, }; diff --git a/examples/simple/src/main.rs b/examples/simple/src/main.rs index 957ec799..26be445b 100644 --- a/examples/simple/src/main.rs +++ b/examples/simple/src/main.rs @@ -1,4 +1,4 @@ -use luminal::{nn::linear::Linear, prelude::*}; +use luminal::prelude::*; fn main() { // Create a new graph