Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/jafioti/luminal
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Mar 29, 2024
2 parents c95c514 + a30212b commit bf64f3c
Show file tree
Hide file tree
Showing 8 changed files with 9 additions and 3 deletions.
3 changes: 3 additions & 0 deletions crates/luminal_cuda/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl<T: CudaFloat> Operator for CudaSub<T> {
pub struct SubtractionCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for SubtractionCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
let (lhs, rhs) = (node(), node());
Expand Down Expand Up @@ -208,6 +209,7 @@ impl<T: CudaFloat> Operator for CudaEqual<T> {
pub struct EqualCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for EqualCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
let one = constant::<T>(1.);
Expand Down Expand Up @@ -343,6 +345,7 @@ impl<T: CudaFloat> Operator for CudaGather<T> {
pub struct GatherCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for GatherCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
let indexes = node();
Expand Down
1 change: 1 addition & 0 deletions crates/luminal_cuda/src/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl<T: CudaFloat + 'static> Compiler for MatMulCompiler<T>
where
CudaData<T>: Data,
{
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
let dev = CudaDevice::new(0).unwrap();
// Look for the matmul pattern
Expand Down
1 change: 1 addition & 0 deletions crates/luminal_cuda/src/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl<T: CudaFloat> Operator for CudaARange<T> {
pub struct ARangeCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for ARangeCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let dev = CudaDevice::new(0).unwrap();
// TODO: Make sure this actually checks the shape transformations to ensure pooling happens
Expand Down
2 changes: 2 additions & 0 deletions crates/luminal_cuda/src/prim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,7 @@ impl<T: CudaFloat> Operator for CudaMaxReduce<T> {
pub struct PrimitiveCompiler<T>(PhantomData<T>);

impl<T: CudaFloat> Compiler for PrimitiveCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
let dev = CudaDevice::new(0).unwrap();
// Go through the graph and insert copy ops
Expand Down Expand Up @@ -1146,6 +1147,7 @@ impl<T: CudaFloat> Compiler for PrimitiveCompiler<T> {
pub struct CopyCompiler<T>(PhantomData<T>);

impl<T: CudaFloat> Compiler for CopyCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut ids: To) {
for (first, second) in graph
.edge_indices()
Expand Down
1 change: 1 addition & 0 deletions crates/luminal_cuda/src/quantized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ impl<T> CudaQuantizedCompiler<T> {
}

impl<T: CudaFloat + Default> Compiler for CudaQuantizedCompiler<T> {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, mut remap: To) {
let device = CudaDevice::new(0).unwrap();
let mut weight_ids = self.0.clone();
Expand Down
1 change: 0 additions & 1 deletion examples/llama/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ pub const HEADS: usize = 32;
pub const LAYERS: usize = 32;

use luminal::{
nn::{embedding::Embedding, norm::RMSNorm},
prelude::*,
shape::symbolic::{BigExpression, Expression},
};
Expand Down
1 change: 0 additions & 1 deletion examples/mistral/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{marker::PhantomData, ops::Div};

use luminal::{
nn::{embedding::Embedding, norm::RMSNorm},
prelude::{binary::F32Pow, *},
shape::symbolic::{BigExpression, Expression},
};
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use luminal::{nn::linear::Linear, prelude::*};
use luminal::prelude::*;

fn main() {
// Create a new graph
Expand Down

0 comments on commit bf64f3c

Please sign in to comment.