Commit 1a43b74
Updated cuda
jafioti committed Apr 20, 2024
1 parent 1c40dd0 commit 1a43b74
Showing 9 changed files with 125 additions and 184 deletions.
4 changes: 3 additions & 1 deletion crates/luminal_cuda/Cargo.toml
@@ -16,8 +16,10 @@ luminal_cudarc = { version="0.10.0", features = [
itertools = "0.12.1"
rustc-hash = "1.1.0"
num-traits = "0.2.18"
+fmt-derive = "0.1.1"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
rand = "0.8.5"
paste = "1.0.14"
paste = "1.0.14"
luminal_nn = {path="../../crates/luminal_nn"}
25 changes: 9 additions & 16 deletions crates/luminal_cuda/src/binary.rs
@@ -1,5 +1,6 @@
use std::{marker::PhantomData, sync::Arc};

+use fmt_derive::Debug;
use luminal_cudarc::driver::{CudaDevice, CudaFunction, DeviceRepr, LaunchAsync, LaunchConfig};

use luminal::{
@@ -15,7 +16,7 @@ use crate::{
render_dyn_dim_inputs, CudaData, CudaFloat,
};

-#[derive(LuminalEqTrue, LuminalPrint, Clone)]
+#[derive(Clone, Debug)]
pub struct CudaSub<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
@@ -81,7 +82,7 @@ impl<T: CudaFloat> Operator for CudaSub<T> {
}
}

-#[derive(LuminalPrint, Default)]
+#[derive(Debug, Default)]
pub struct SubtractionCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for SubtractionCompiler<T> {
@@ -139,7 +140,7 @@ impl<T: CudaFloat> Compiler for SubtractionCompiler<T> {
}
}

-#[derive(LuminalEqTrue, LuminalPrint, Clone)]
+#[derive(Clone, Debug)]
pub struct CudaEqual<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
@@ -205,7 +206,7 @@ impl<T: CudaFloat> Operator for CudaEqual<T> {
}
}

-#[derive(LuminalPrint, Default)]
+#[derive(Debug, Default)]
pub struct EqualCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for EqualCompiler<T> {
@@ -262,7 +263,7 @@ impl<T: CudaFloat> Compiler for EqualCompiler<T> {
}
}

-#[derive(LuminalPrint, Clone, LuminalEqFalse)]
+#[derive(Clone, Debug)]
pub struct CudaGather<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
@@ -294,13 +295,7 @@ extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *weights
impl<T: CudaFloat> Operator for CudaGather<T> {
fn process(&mut self, inputs: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
// Inp 1 should be Vec<f32> and inp 2 should be a CudaSlice<T>
-let indexes = inputs[0]
-.0
-.borrowed()
-.data
-.as_any()
-.downcast_ref::<Vec<f32>>()
-.unwrap();
+let indexes = inputs[0].0.borrowed().downcast_ref::<Vec<f32>>().unwrap();
let weights = get_buffer_from_tensor::<T>(&inputs[1].0);

let mut indexes_buffer = unsafe { self.device.alloc::<f32>(indexes.len()).unwrap() };
@@ -335,13 +330,11 @@ impl<T: CudaFloat> Operator for CudaGather<T> {
.unwrap();
}

-vec![Tensor {
-data: Box::new(CudaData(out)),
-}]
+vec![Tensor::new(CudaData(out))]
}
}

-#[derive(LuminalPrint, Default)]
+#[derive(Debug, Default)]
pub struct GatherCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for GatherCompiler<T> {
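The derive changes in this file (paired with the fmt-derive dependency added to Cargo.toml above) swap luminal's LuminalPrint/LuminalEqTrue/LuminalEqFalse macros for fmt_derive::Debug. A plausible motivation, stated here as an assumption about fmt-derive 0.1's documented behavior: std's #[derive(Debug)] requires every field to implement Debug, which handle types like CudaFunction do not, while fmt-derive's drop-in Debug macro renders such fields opaquely instead of failing to compile. A minimal sketch under that assumption, with OpaqueHandle standing in for CudaFunction:

```rust
use std::sync::Arc;

// Shadows std's derive in derive position, exactly as this diff does.
use fmt_derive::Debug;

// Stand-in for a handle type with no `std::fmt::Debug` impl.
struct OpaqueHandle;

#[derive(Clone, Debug)]
pub struct KernelOp {
    function: Arc<OpaqueHandle>, // would break std's #[derive(Debug)]
    name: String,
}

fn main() {
    let op = KernelOp {
        function: Arc::new(OpaqueHandle),
        name: "sub".into(),
    };
    // Assumed behavior: the non-Debug field is printed opaquely.
    println!("{:?}", op);
}
```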
37 changes: 17 additions & 20 deletions crates/luminal_cuda/src/lib.rs
@@ -20,8 +20,6 @@ use std::{collections::hash_map::DefaultHasher, ffi::c_void, fmt::Write, hash::H

use luminal::{op::InputTensor, prelude::*};

-use self::symbolic::{BigExpression, Term};
-
pub type CudaCompiler<T> = (
prim::PrimitiveCompiler<T>,
binary::SubtractionCompiler<T>,
@@ -80,16 +78,6 @@ impl<T: CudaFloat> Data for CudaData<T> {
}
}

-impl Data for CudaData<u8> {
-fn as_any(&self) -> &dyn std::any::Any {
-self
-}
-
-fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
-self
-}
-}
-
impl CudaFloat for f16 {
fn from_f32(a: f32) -> Self {
f16::from_f32(a)
@@ -105,6 +93,21 @@ impl CudaFloat for f16 {
}
}

+impl CudaFloat for u8 {
+fn from_f32(a: f32) -> Self {
+a as u8
+}
+fn to_f32(self) -> f32 {
+self as f32
+}
+fn is_f32() -> bool {
+false
+}
+fn type_name() -> &'static str {
+"uint8_t"
+}
+}
+
fn expr_to_cuda_string(expr: BigExpression) -> String {
let mut symbols = vec![];
for term in expr.terms {
@@ -195,14 +198,8 @@ fn hash<T: std::hash::Hash>(obj: T) -> u64 {
hasher.finish()
}

-fn get_buffer_from_tensor<'a, T: 'static>(tensor: &'a InputTensor) -> &'a CudaSlice<T> {
-&tensor
-.borrowed()
-.data
-.as_any()
-.downcast_ref::<CudaData<T>>()
-.unwrap()
-.0
+fn get_buffer_from_tensor<'a, T: CudaFloat>(tensor: &'a InputTensor) -> &'a CudaSlice<T> {
+&tensor.borrowed().downcast_ref::<CudaData<T>>().unwrap().0
}

fn input_dyn_dims(
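The new CudaFloat impl for u8 maps the Rust element type to the CUDA-side type name "uint8_t", which the kernel-generation code splices into CUDA source, as in the gather hunk's extern \"C\" __global__ void kernel({type_name} ...) signature. A minimal self-contained sketch of that pattern; render_kernel is a hypothetical helper for illustration, not a function from this crate:

```rust
// Trait mirrors the methods shown in this diff.
trait CudaFloat {
    fn from_f32(a: f32) -> Self;
    fn to_f32(self) -> f32;
    fn is_f32() -> bool;
    fn type_name() -> &'static str;
}

impl CudaFloat for u8 {
    fn from_f32(a: f32) -> Self { a as u8 }
    fn to_f32(self) -> f32 { self as f32 }
    fn is_f32() -> bool { false }
    fn type_name() -> &'static str { "uint8_t" }
}

// Hypothetical helper: substitute the element type into a CUDA kernel
// signature the way the gather hunk above does.
fn render_kernel<T: CudaFloat>() -> String {
    let type_name = T::type_name();
    format!("extern \"C\" __global__ void kernel({type_name} *out, const {type_name} *weights)")
}

fn main() {
    assert_eq!(
        render_kernel::<u8>(),
        "extern \"C\" __global__ void kernel(uint8_t *out, const uint8_t *weights)"
    );
    println!("{}", render_kernel::<u8>());
}
```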
3 changes: 2 additions & 1 deletion crates/luminal_cuda/src/matmul.rs
@@ -1,5 +1,6 @@
use std::{marker::PhantomData, sync::Arc};

+use fmt_derive::Debug;
use luminal_cudarc::{
cublas::{sys::cublasOperation_t::*, CudaBlas},
driver::{CudaDevice, DevicePtr, DevicePtrMut},
@@ -15,7 +16,7 @@ use luminal::{
prelude::*,
};

-#[derive(LuminalPrint, LuminalEqFalse, Clone)]
+#[derive(Clone, Debug)]
pub struct Matmul<T>(Arc<CudaBlas>, Arc<CudaDevice>, PhantomData<T>);

impl<T: CudaFloat> Operator for Matmul<T> {
11 changes: 5 additions & 6 deletions crates/luminal_cuda/src/other.rs
@@ -2,7 +2,8 @@ use std::{marker::PhantomData, sync::Arc};

use luminal_cudarc::driver::{CudaDevice, CudaFunction, LaunchAsync, LaunchConfig};

-use luminal::{op::*, prelude::*, shape::symbolic::BigExpression};
+use fmt_derive::Debug;
+use luminal::prelude::*;
use rustc_hash::FxHashMap;

use crate::{
@@ -12,7 +13,7 @@ use crate::{
CudaData, CudaFloat,
};

-#[derive(LuminalPrint, Clone, LuminalEqFalse)]
+#[derive(Clone, Debug)]
pub struct CudaARange<T> {
function: CudaFunction,
device: Arc<CudaDevice>,
@@ -65,13 +66,11 @@ impl<T: CudaFloat> Operator for CudaARange<T> {
.unwrap();
}

-vec![Tensor {
-data: Box::new(CudaData(out)),
-}]
+vec![Tensor::new(CudaData(out))]
}
}

-#[derive(LuminalPrint, Default)]
+#[derive(Debug, Default)]
pub struct ARangeCompiler<T: CudaFloat>(PhantomData<T>);

impl<T: CudaFloat> Compiler for ARangeCompiler<T> {
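A recurring pattern across these hunks: call sites that built Tensor { data: Box::new(...) } and dug through .data.as_any().downcast_ref::<T>() now go through Tensor::new and a downcast_ref helper. A self-contained sketch of that type-erased-storage refactor using std::any; the Tensor type below is a stand-in for illustration, not luminal's actual definition:

```rust
use std::any::Any;

// Stand-in tensor: data is type-erased behind `dyn Any`.
pub struct Tensor {
    data: Box<dyn Any>,
}

impl Tensor {
    // New constructor, mirroring `Tensor::new(CudaData(out))` in the diff.
    pub fn new<T: Any>(data: T) -> Self {
        Tensor { data: Box::new(data) }
    }

    // Helper that collapses the old `.data.as_any().downcast_ref()` chain.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        self.data.downcast_ref::<T>()
    }
}

fn main() {
    // Old style: Tensor { data: Box::new(vec![1.0f32, 2.0]) }
    // New style:
    let t = Tensor::new(vec![1.0f32, 2.0]);
    let values = t.downcast_ref::<Vec<f32>>().unwrap();
    assert_eq!(values, &vec![1.0, 2.0]);
}
```

Centralizing the downcast in one helper also lets the bound tighten from T: 'static to T: CudaFloat, as the new get_buffer_from_tensor signature in lib.rs shows.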
