Commit 076a165
Merge branch 'main' of https://github.com/jafioti/luminal
jafioti committed Apr 27, 2024
2 parents fb84e93 + 8bf379b commit 076a165
Showing 25 changed files with 1,217 additions and 66 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -12,7 +12,7 @@ num-traits = "0.2.16"
 petgraph = "0.6.4"
 rand = "0.8.5"
 urlencoding = "2.1.2"
-webbrowser = "0.8.10"
+webbrowser = "1.0.0"
 dyn-clone = "1.0.12"
 half = "*"
 tinyvec = "1.6.0"
2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@
 # [luminal](https://luminalai.com)
-![image](https://raw.githubusercontent.com/jafioti/luminal/main/dag.jpeg)
+![image](https://github.com/jafioti/luminal/blob/main/docs/dag.jpeg)
 [![Website](https://img.shields.io/badge/Docs-Website-blue?style=for-the-badge&color=0D9373)](https://luminalai.com)
 [![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/Sidekick-AI/dataflow/actions)
 [![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
39 changes: 29 additions & 10 deletions crates/luminal_cpu/src/binary.rs
@@ -196,19 +196,35 @@ pub struct GatherCompiler;
 impl Compiler for GatherCompiler {
     type Output = ();
     fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
-        let arange = op::<ARange>();
-        let eq = unary::<Equal>(arange);
-        let inp = node();
-        let mul = binary::<Mul>(inp.clone(), eq.clone());
+        let indexes = node();
+        let eq = binary::<Equal>(indexes.clone(), op::<ARange>());
+        let embedding = node();
+        let mul = binary::<Mul>(embedding.clone(), eq.clone());
         let sum_reduce = unary::<SumReduce>(mul.clone());
         let mut s = sum_reduce.clone().search(graph);
         while s.next_match() {
-            if s.check_no_delete(&[sum_reduce.id]) {
+            if s.check_no_delete(&[embedding.id]) {
                 continue;
             }
+            let emb_shape = graph
+                .edges_connecting(s.get(&embedding), s.get(&mul))
+                .next()
+                .unwrap()
+                .weight()
+                .as_data()
+                .unwrap()
+                .2;
+            let index_shape = graph
+                .edges_connecting(s.get(&indexes), s.get(&eq))
+                .next()
+                .unwrap()
+                .weight()
+                .as_data()
+                .unwrap()
+                .2;
             let embed_dim = graph
                 .graph
-                .edges_connecting(s.get(&inp), s.get(&mul))
+                .edges_connecting(s.get(&embedding), s.get(&mul))
                 .next()
                 .unwrap()
                 .weight()
@@ -218,11 +234,14 @@ impl Compiler for GatherCompiler {
                 .shape()[2]
                 .to_usize()
                 .unwrap();
-            let gather = graph.add_op(Gather { embed_dim }).finish();
-            move_incoming_edge(s.get(&eq), gather, &mut graph.graph);
-            graph.safe_remove_node(s.get(&eq), 1);
-            move_incoming_edge(s.get(&mul), gather, &mut graph.graph);
+
+            let gather = graph
+                .add_op(Gather { embed_dim })
+                .input(s.get(&indexes), 0, index_shape)
+                .input(s.get(&embedding), 0, emb_shape)
+                .finish();
             move_outgoing_edge(s.get(&sum_reduce), gather, &mut graph.graph);
             graph.remove_node(s.get(&sum_reduce));
+            s.try_delete();
         }
     }
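For context on the rewrite above: the compiler pattern-matches the decomposed form of a gather (compare the indexes against an ARange, multiply the one-hot result with the embedding matrix, then sum-reduce) and replaces it with a single Gather op, now wiring the index and embedding inputs in explicitly via `.input()`. Below is a minimal standalone sketch of the arithmetic identity the pattern relies on; this is plain Rust, not luminal's graph API, and `gather_via_mask` is a hypothetical helper name:

```rust
// Sketch of the identity behind GatherCompiler: for each index,
// `indexes == arange(vocab)` forms a one-hot mask; multiplying the
// embedding table by that mask and sum-reducing over the vocab
// dimension selects exactly one row, which is what the fused Gather
// op computes directly.
fn gather_via_mask(embedding: &[Vec<f32>], indexes: &[usize]) -> Vec<Vec<f32>> {
    let vocab = embedding.len();
    let dim = embedding[0].len();
    indexes
        .iter()
        .map(|&ix| {
            // Equal(indexes, ARange): one-hot mask over the vocab dimension.
            let mask: Vec<f32> = (0..vocab).map(|j| ((j == ix) as u8) as f32).collect();
            // Mul then SumReduce: mask * embedding, summed over vocab.
            (0..dim)
                .map(|d| (0..vocab).map(|j| mask[j] * embedding[j][d]).sum())
                .collect()
        })
        .collect()
}

fn main() {
    let table = vec![vec![0.0, 1.0], vec![2.0, 3.0], vec![4.0, 5.0]];
    let out = gather_via_mask(&table, &[2, 0]);
    assert_eq!(out, vec![vec![4.0, 5.0], vec![0.0, 1.0]]);
    println!("{out:?}");
}
```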
1 change: 0 additions & 1 deletion crates/luminal_cuda/Cargo.toml
@@ -16,7 +16,6 @@ luminal_cudarc = { version="0.10.0", features = [
 itertools = "0.12.1"
 rustc-hash = "1.1.0"
 num-traits = "0.2.18"
-fmt-derive = "0.1.1"
 
 [dev-dependencies]
 dfdx = { version = "0.13", features = ["f16"] }
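(This dependency removal pairs with the Debug changes in the files below: the fmt-derive crate's permissive Debug derive is replaced by a hand-rolled `debug_type!` macro added in crates/luminal_cuda/src/lib.rs.)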
10 changes: 6 additions & 4 deletions crates/luminal_cuda/src/binary.rs
@@ -1,6 +1,5 @@
 use std::{marker::PhantomData, sync::Arc};
 
-use fmt_derive::Debug;
 use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig};
 
 use luminal::{
@@ -16,14 +15,15 @@ use crate::{
     render_dyn_dim_inputs, CudaData, CudaFloat,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone)]
 pub struct CudaSub<T> {
     function: CudaFunction,
     device: Arc<CudaDevice>,
     dyn_symbols: Vec<char>,
     dyn_map: *const FxHashMap<char, usize>,
     _phantom: PhantomData<T>,
 }
+crate::debug_type!(CudaSub<T>);
 
 impl<T: CudaFloat> CudaSub<T> {
     pub fn new(
@@ -140,14 +140,15 @@ impl<T: CudaFloat> Compiler for SubtractionCompiler<T> {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone)]
 pub struct CudaEqual<T> {
     function: CudaFunction,
     device: Arc<CudaDevice>,
     dyn_symbols: Vec<char>,
     dyn_map: *const FxHashMap<char, usize>,
     _phantom: PhantomData<T>,
 }
+crate::debug_type!(CudaEqual<T>);
 
 impl<T: CudaFloat> CudaEqual<T> {
     pub fn new(
@@ -263,13 +264,14 @@ impl<T: CudaFloat> Compiler for EqualCompiler<T> {
     }
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone)]
 pub struct CudaGather<T> {
     function: CudaFunction,
     device: Arc<CudaDevice>,
     pub embed_dim: usize,
     _phantom: PhantomData<T>,
 }
+crate::debug_type!(CudaGather<T>);
 
 impl<T: CudaFloat> CudaGather<T> {
     pub fn new(device: Arc<CudaDevice>, embed_dim: usize) -> Self {
13 changes: 12 additions & 1 deletion crates/luminal_cuda/src/lib.rs
@@ -159,7 +159,7 @@ fn render_dyn_dim_inputs(shapes: &[ShapeTracker]) -> (Vec<char>, String) {
                     .into_iter()
                     .flat_map(|i| [i.0.into(), i.1.into()]),
             )
-            .chain(st.slices.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
+            .chain(st.mask.into_iter().flat_map(|i| [i.0.into(), i.1.into()]))
         })
         .flat_map(|d| d.to_symbols())
         .unique()
@@ -235,3 +235,14 @@ fn compile_and_load_kernel(mut code: String, device: &Arc<CudaDevice>) -> CudaFunction {
     }
     device.get_func(&name, &name).unwrap()
 }
+
+#[macro_export]
+macro_rules! debug_type {
+    ($t: ty) => {
+        impl<T> std::fmt::Debug for $t {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                write!(f, stringify!($t))
+            }
+        }
+    };
+}
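For context on the derive swaps across these files: the Debug impl this macro generates prints only the type's name, so it works even when a field or the type parameter `T` lacks a Debug impl, and it avoids the `T: Debug` bound that `#[derive(Debug)]` would impose. A minimal self-contained sketch of how it expands, with `OpaqueHandle` and `MyOp` as hypothetical stand-ins for the non-Debug field type and a struct shaped like `CudaSub<T>`:

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for a field type with no Debug impl.
struct OpaqueHandle;

// Copy of the macro added in this commit.
macro_rules! debug_type {
    ($t: ty) => {
        impl<T> std::fmt::Debug for $t {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, stringify!($t))
            }
        }
    };
}

// #[derive(Debug)] would fail on the non-Debug field and would also
// demand T: Debug, so the macro supplies the impl by hand.
struct MyOp<T> {
    _handle: OpaqueHandle,
    _phantom: PhantomData<T>,
}
debug_type!(MyOp<T>);

fn main() {
    // T = OpaqueHandle is itself not Debug, yet printing still works,
    // because the impl only emits the stringified type name.
    let op = MyOp::<OpaqueHandle> {
        _handle: OpaqueHandle,
        _phantom: PhantomData,
    };
    println!("{op:?}"); // prints: MyOp<T>
}
```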
4 changes: 2 additions & 2 deletions crates/luminal_cuda/src/matmul.rs
@@ -1,6 +1,5 @@
 use std::{marker::PhantomData, sync::Arc};
 
-use fmt_derive::Debug;
 use luminal_cudarc::{
     cublas::{sys::cublasOperation_t::*, CudaBlas},
     driver::{CudaDevice, DevicePtr, DevicePtrMut},
@@ -16,8 +15,9 @@ use luminal::{
     prelude::*,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone)]
 pub struct Matmul<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);
+crate::debug_type!(Matmul<T>);
 
 impl<T: CudaFloat> Operator for Matmul<T> {
     fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
7 changes: 3 additions & 4 deletions crates/luminal_cuda/src/other.rs
@@ -1,9 +1,7 @@
 use std::{marker::PhantomData, sync::Arc};
 
-use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig};
-
-use fmt_derive::Debug;
 use luminal::prelude::*;
+use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig};
 use rustc_hash::FxHashMap;
 
 use crate::{
@@ -13,14 +11,15 @@ use crate::{
     CudaData, CudaFloat,
 };
 
-#[derive(Clone, Debug)]
+#[derive(Clone)]
 pub struct CudaARange<T> {
     function: CudaFunction,
     device: Arc<CudaDevice>,
     pub size: BigExpression,
     dyn_map: *const FxHashMap<char, usize>,
     _phantom: PhantomData<T>,
 }
+crate::debug_type!(CudaARange<T>);
 
 impl<T: CudaFloat> CudaARange<T> {
    pub fn new(
(Diffs for the remaining 18 changed files are not shown.)