Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/jafioti/luminal
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Apr 29, 2024
2 parents 15ba813 + 0efbc51 commit e2be277
Show file tree
Hide file tree
Showing 34 changed files with 3,332 additions and 330 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ num-traits = "0.2.16"
petgraph = "0.6.4"
rand = "0.8.5"
urlencoding = "2.1.2"
webbrowser = "0.8.10"
webbrowser = "1.0.0"
dyn-clone = "1.0.12"
half = "*"
tinyvec = "1.6.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Website](https://img.shields.io/badge/Docs-Website-blue?style=for-the-badge&color=0D9373)](https://luminalai.com)
[![CI Status](https://img.shields.io/github/actions/workflow/status/jafioti/luminal/test.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/Sidekick-AI/dataflow/actions)
[![Current Crates.io Version](https://img.shields.io/crates/v/luminal.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/luminal)
[![](https://dcbadge.vercel.app/api/server/VQf3j8WWNd)](https://discord.gg/VQf3j8WWNd)
[![discord](https://dcbadge.vercel.app/api/server/VQf3j8WWNd)](https://discord.gg/VQf3j8WWNd)

**Deep learning at the speed of light.**

Expand Down
39 changes: 29 additions & 10 deletions crates/luminal_cpu/src/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,19 +196,35 @@ pub struct GatherCompiler;
impl Compiler for GatherCompiler {
type Output = ();
fn compile<To: ToIdsMut>(&self, graph: &mut Graph, _: To) {
let arange = op::<ARange>();
let eq = unary::<Equal>(arange);
let inp = node();
let mul = binary::<Mul>(inp.clone(), eq.clone());
let indexes = node();
let eq = binary::<Equal>(indexes.clone(), op::<ARange>());
let embedding = node();
let mul = binary::<Mul>(embedding.clone(), eq.clone());
let sum_reduce = unary::<SumReduce>(mul.clone());
let mut s = sum_reduce.clone().search(graph);
while s.next_match() {
if s.check_no_delete(&[sum_reduce.id]) {
if s.check_no_delete(&[embedding.id]) {
continue;
}
let emb_shape = graph
.edges_connecting(s.get(&embedding), s.get(&mul))
.next()
.unwrap()
.weight()
.as_data()
.unwrap()
.2;
let index_shape = graph
.edges_connecting(s.get(&indexes), s.get(&eq))
.next()
.unwrap()
.weight()
.as_data()
.unwrap()
.2;
let embed_dim = graph
.graph
.edges_connecting(s.get(&inp), s.get(&mul))
.edges_connecting(s.get(&embedding), s.get(&mul))
.next()
.unwrap()
.weight()
Expand All @@ -218,11 +234,14 @@ impl Compiler for GatherCompiler {
.shape()[2]
.to_usize()
.unwrap();
let gather = graph.add_op(Gather { embed_dim }).finish();
move_incoming_edge(s.get(&eq), gather, &mut graph.graph);
graph.safe_remove_node(s.get(&eq), 1);
move_incoming_edge(s.get(&mul), gather, &mut graph.graph);

let gather = graph
.add_op(Gather { embed_dim })
.input(s.get(&indexes), 0, index_shape)
.input(s.get(&embedding), 0, emb_shape)
.finish();
move_outgoing_edge(s.get(&sum_reduce), gather, &mut graph.graph);
graph.remove_node(s.get(&sum_reduce));
s.try_delete();
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/luminal_cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ luminal_cudarc = { version="0.10.0", features = [
itertools = "0.12.1"
rustc-hash = "1.1.0"
num-traits = "0.2.18"
fmt-derive = "0.1.1"
regex = "1.10.4"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
Expand Down
26 changes: 21 additions & 5 deletions crates/luminal_cuda/src/binary.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::{marker::PhantomData, sync::Arc};
use std::{any::Any, marker::PhantomData, sync::Arc};

use fmt_derive::Debug;
use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig};

use luminal::{
Expand All @@ -16,14 +15,15 @@ use crate::{
render_dyn_dim_inputs, CudaData, CudaFloat,
};

#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct CudaSub<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(CudaSub);

impl<T: CudaFloat> CudaSub<T> {
pub fn new(
Expand Down Expand Up @@ -80,6 +80,13 @@ impl<T: CudaFloat> Operator for CudaSub<T> {

vec![Tensor::new(CudaData(out))]
}

fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "elementwise" {
return Some(Box::new("input0 - input1".to_string()));
}
None
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -140,14 +147,15 @@ impl<T: CudaFloat> Compiler for SubtractionCompiler<T> {
}
}

#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct CudaEqual<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
dyn_symbols: Vec<char>,
dyn_map: *const FxHashMap<char, usize>,
_phantom: PhantomData<T>,
}
crate::debug_type!(CudaEqual);

impl<T: CudaFloat> CudaEqual<T> {
pub fn new(
Expand Down Expand Up @@ -204,6 +212,13 @@ impl<T: CudaFloat> Operator for CudaEqual<T> {

vec![Tensor::new(CudaData(out))]
}

fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
if key == "elementwise" {
return Some(Box::new("(float)(input0 == input1)".to_string()));
}
None
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -263,13 +278,14 @@ impl<T: CudaFloat> Compiler for EqualCompiler<T> {
}
}

#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct CudaGather<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
pub embed_dim: usize,
_phantom: PhantomData<T>,
}
crate::debug_type!(CudaGather);

impl<T: CudaFloat> CudaGather<T> {
pub fn new(device: Arc<CudaDevice>, embed_dim: usize) -> Self {
Expand Down
Loading

0 comments on commit e2be277

Please sign in to comment.