Reorganized nn module
jafioti committed Apr 1, 2024
1 parent 784fe20 commit 21c1e72
Showing 31 changed files with 209 additions and 358 deletions.
8 changes: 3 additions & 5 deletions .github/workflows/test.yml
@@ -10,17 +10,15 @@ env:
CARGO_TERM_COLOR: always

jobs:
cpu_test:
name: CPU Tests
unit_test:
name: Unit Tests
runs-on: ubuntu-latest
timeout-minutes: 20

steps:
- uses: actions/checkout@v3
- name: Build
run: cargo build --no-default-features --verbose
- name: Run tests
run: cargo test --no-default-features --verbose
run: cargo test --workspace --verbose
# macos_test:
# name: MacOS Tests
# runs-on: macos-13
4 changes: 3 additions & 1 deletion Cargo.toml
@@ -1,4 +1,6 @@
workspace = { members = ["crates/luminal_training"] }
[workspace]
members = [ "crates/luminal_nn","crates/luminal_training" ]
exclude = [ "crates/luminal_metal","crates/luminal_cuda", "examples/llama", "examples/mistral", "examples/simple" ]

[package]
name = "luminal"
16 changes: 16 additions & 0 deletions crates/luminal_nn/Cargo.toml
@@ -0,0 +1,16 @@
[package]
name = "luminal_nn"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
itertools = "0.12.1"
luminal = {path="../.."}
rustc-hash = "1.1.0"
rand = "0.8.5"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
paste = "1.0.14"
8 changes: 4 additions & 4 deletions src/nn/activation.rs → crates/luminal_nn/src/activation.rs
@@ -1,4 +1,4 @@
use crate::prelude::*;
use luminal::prelude::*;

/// Rectified Linear Unit activation function
pub struct ReLU;
@@ -87,12 +87,12 @@ impl<S: ConstShape> Module<GraphTensor<S>> for Tanh {
#[cfg(test)]
mod tests {
use super::ReLU;
use crate::{
nn::linear::Linear,
use crate::Linear;
use dfdx::prelude::{Module as DfdxModule, *};
use luminal::{
prelude::{Module, *},
tests::assert_close,
};
use dfdx::prelude::{Module as DfdxModule, *};

#[test]
fn test_relu_and_linear() {
10 changes: 5 additions & 5 deletions src/nn/convolution.rs → crates/luminal_nn/src/convolution.rs
@@ -1,4 +1,4 @@
use crate::prelude::*;
use luminal::prelude::*;
use rand::{thread_rng, Rng};

pub struct Conv1D<
@@ -48,7 +48,7 @@ impl<
> SerializeModule
for Conv1D<CHANNELS_IN, CHANNELS_OUT, KERNEL, STRIDE, DILATION, CHANNELS_IN_TIMES_KERNEL>
{
fn serialize(&self, s: &mut crate::serialization::Serializer) {
fn serialize(&self, s: &mut luminal::serialization::Serializer) {
s.tensor("weight", self.weight);
}
}
@@ -156,7 +156,7 @@ impl<
CHANNELS_IN_TIMES_KERNELX_KERNELY,
>
{
fn serialize(&self, s: &mut crate::serialization::Serializer) {
fn serialize(&self, s: &mut luminal::serialization::Serializer) {
s.tensor("weight", self.weight);
}
}
@@ -218,8 +218,8 @@ impl<

#[cfg(test)]
mod tests {
use super::Conv1D;
use crate::{nn::convolution::Conv2D, prelude::*, tests::assert_close};
use super::{Conv1D, Conv2D};
use luminal::{prelude::*, tests::assert_close};

#[test]
fn test_conv1d_simple() {
8 changes: 4 additions & 4 deletions src/nn/embedding.rs → crates/luminal_nn/src/embedding.rs
@@ -1,4 +1,4 @@
use crate::prelude::*;
use luminal::prelude::*;

pub struct Embedding<const N: usize, const DIM: usize> {
pub weight: GraphTensor<R2<N, DIM>>,
@@ -13,7 +13,7 @@ impl<const A: usize, const B: usize> InitModule for Embedding<A, B> {
}

impl<const A: usize, const B: usize> SerializeModule for Embedding<A, B> {
fn serialize(&self, s: &mut crate::serialization::Serializer) {
fn serialize(&self, s: &mut luminal::serialization::Serializer) {
s.tensor("weight", self.weight);
}
}
@@ -49,11 +49,11 @@ mod tests {
tensor::{Cpu, TensorFromVec},
};

use crate::prelude::Module;
use luminal::prelude::Module;

use super::Embedding;
use dfdx::nn::BuildOnDevice;
crate::test_imports!();
luminal::test_imports!();

#[test]
fn test_embedding() {
45 changes: 45 additions & 0 deletions crates/luminal_nn/src/lib.rs
@@ -0,0 +1,45 @@
use luminal::prelude::*;

mod activation;
pub use activation::*;
mod convolution;
pub use convolution::*;
mod embedding;
pub use embedding::*;
mod linear;
pub use linear::*;
mod norm;
pub use norm::*;
mod transformer;
pub use transformer::*;

pub struct Repeated<T, const N: usize> {
pub modules: Vec<T>,
}

impl<T: InitModule, const N: usize> InitModule for Repeated<T, N> {
fn initialize(cx: &mut Graph) -> Self {
Self {
modules: (0..N).map(|_| InitModule::initialize(cx)).collect(),
}
}
}

impl<T: SerializeModule, const N: usize> SerializeModule for Repeated<T, N> {
fn serialize(&self, s: &mut Serializer) {
for (i, l) in self.modules.iter().enumerate() {
s.module(&format!("layer{i}"), l);
}
}
}

impl<I, T: Module<I, Output = I>, const N: usize> Module<I> for Repeated<T, N> {
type Output = I;

fn forward(&self, mut input: I) -> Self::Output {
for m in &self.modules {
input = m.forward(input);
}
input
}
}
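
The new `Repeated` wrapper in `luminal_nn/src/lib.rs` stacks `N` copies of a module and threads the input through them in order. Below is a minimal usage sketch under the crate layout introduced in this commit; the layer sizes and input data are illustrative only, and the exact shape bounds on `Linear`'s `Module` impl may constrain which tensor shapes compile.

```rust
use luminal::prelude::*;
use luminal_nn::{Linear, Repeated};

fn main() {
    let mut cx = Graph::new();
    // Stack four identical unbiased 8x8 linear layers; `Repeated::forward`
    // passes the output of each layer into the next.
    let model: Repeated<Linear<8, 8>, 4> = InitModule::initialize(&mut cx);
    // A small batch of two 8-dimensional rows (illustrative values).
    let input = cx.tensor::<R2<2, 8>>().set(vec![0.5; 16]);
    let output = model.forward(input).retrieve();
    cx.execute();
    println!("{:?}", output.data());
}
```

Because `Repeated` only requires `T: Module<I, Output = I>`, any square layer (input and output shapes identical) can be stacked this way without changing the surrounding graph code.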
6 changes: 3 additions & 3 deletions src/nn/linear.rs → crates/luminal_nn/src/linear.rs
@@ -1,6 +1,6 @@
use rand::{thread_rng, Rng};

use crate::prelude::*;
use luminal::prelude::*;

/// A simple unbiased linear layer
pub struct Linear<const A: usize, const B: usize> {
@@ -24,7 +24,7 @@ impl<const A: usize, const B: usize> InitModule for Linear<A, B> {
}

impl<const A: usize, const B: usize> SerializeModule for Linear<A, B> {
fn serialize(&self, s: &mut crate::serialization::Serializer) {
fn serialize(&self, s: &mut luminal::serialization::Serializer) {
s.tensor("weight", self.weight);
}
}
@@ -43,7 +43,7 @@ where
#[cfg(test)]
mod tests {
use super::Linear;
use crate::{prelude::*, tests::assert_close};
use luminal::{prelude::*, tests::assert_close};
#[test]
fn test_linear() {
let mut cx = Graph::new();
4 changes: 2 additions & 2 deletions src/nn/norm.rs → crates/luminal_nn/src/norm.rs
@@ -1,13 +1,13 @@
use std::{marker::PhantomData, ops::Mul};

use crate::prelude::*;
use luminal::prelude::*;

/// A simple layer norm layer. Calls `tensor.layer_norm::<DIM>()`.
#[derive(Default)]
pub struct LayerNorm<Ax: Axes>(PhantomData<Ax>);

impl<Ax: Axes> InitModule for LayerNorm<Ax> {
fn initialize(_: &mut crate::prelude::Graph) -> Self {
fn initialize(_: &mut luminal::prelude::Graph) -> Self {
Self::default()
}
}
src/nn/transformer/attention.rs → crates/luminal_nn/src/transformer/attention.rs
@@ -1,6 +1,7 @@
use std::ops::Mul;

use crate::{nn::linear::Linear, prelude::*};
use crate::Linear;
use luminal::prelude::*;

// This is still single head attention because I need a runtime reshape, like the try_reshape in dfdx
pub struct MultiHeadSelfAttention<
@@ -184,11 +185,11 @@ impl<

#[cfg(test)]
mod tests {
use crate::{
use dfdx::prelude::{Module as DfdxModule, *};
use luminal::{
prelude::{Module, *},
tests::assert_close,
};
use dfdx::prelude::{Module as DfdxModule, *};

use super::MultiHeadSelfAttention;
#[test]
@@ -212,8 +213,8 @@ mod tests {
.weight
.set(vec![1., 22., 3., 1., 2., 3., 1., 2., 3.]);

let a = cx.tensor::<(Dyn<'d'>, crate::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, crate::shape::Const<3>)>();
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
let b = model.forward((e, a, e));

a.set_dyn(
src/nn/transformer/decoder.rs → crates/luminal_nn/src/transformer/decoder.rs
@@ -1,7 +1,5 @@
use crate::{
nn::{activation::ReLU, linear::Linear},
prelude::*,
};
use crate::{Linear, ReLU};
use luminal::prelude::*;

use super::attention::MultiHeadSelfAttention;

@@ -180,7 +178,7 @@ mod tests {
tensor_ops::PermuteTo,
};

use crate::{
use luminal::{
prelude::{Module, *},
tests::assert_close,
};
@@ -241,8 +239,8 @@ mod tests {
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);

let a = cx.tensor::<(Dyn<'d'>, crate::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, crate::shape::Const<3>)>();
let a = cx.tensor::<(Dyn<'d'>, Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, Const<3>)>();
let b = model.forward((a, e));

a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
src/nn/transformer/encoder.rs → crates/luminal_nn/src/transformer/encoder.rs
@@ -1,7 +1,5 @@
use crate::{
nn::{activation::ReLU, linear::Linear, Repeated},
prelude::*,
};
use crate::{Linear, ReLU, Repeated};
use luminal::prelude::*;

use super::attention::MultiHeadSelfAttention;

@@ -74,7 +72,7 @@ mod tests {
tensor_ops::PermuteTo,
};

use crate::{
use luminal::{
prelude::{Module, *},
tests::assert_close,
};
@@ -116,7 +114,7 @@
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);

let a = cx
.tensor::<(Dyn<'s'>, crate::shape::Const<3>)>()
.tensor::<(Dyn<'s'>, luminal::shape::Const<3>)>()
.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
let b = model.forward(a).retrieve();

src/nn/transformer/mod.rs → crates/luminal_nn/src/transformer/mod.rs
@@ -1,4 +1,4 @@
use crate::prelude::*;
use luminal::prelude::*;

mod attention;
pub use attention::*;
@@ -83,10 +83,11 @@ mod tests {
tensor_ops::PermuteTo,
};

use crate::{
use luminal::{
prelude::{Module, *},
tests::assert_close,
};
use rand::{thread_rng, Rng};

use super::Transformer;
#[test]
@@ -174,8 +175,8 @@
.weight
.set(vec![-1., 12., 3., -1., 2., -3., 11., 2., 3., 3., -1., 2.]);

let a = cx.tensor::<(Dyn<'d'>, crate::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, crate::shape::Const<3>)>();
let a = cx.tensor::<(Dyn<'d'>, luminal::shape::Const<3>)>();
let e = cx.tensor::<(Dyn<'e'>, luminal::shape::Const<3>)>();
let b = model.forward((a, e));

a.set_dyn(vec![-1., 2., 3., 3., 3., -1.], &[2, 3]);
@@ -388,4 +389,35 @@

assert_close(&b.data(), &d_b.as_vec());
}

#[test]
fn test_serialization() {
let mut rng = thread_rng();
let enc_data = (0..(24 * 32)).map(|_| rng.gen()).collect::<Vec<f32>>();
let trg_data = (0..(20 * 32)).map(|_| rng.gen()).collect::<Vec<f32>>();

let mut cx = Graph::new();
let model: Transformer<32, 5, 4, 4, 3, 2> = InitModule::initialize(&mut cx);
let enc = cx.tensor::<R2<24, 32>>().set(enc_data.clone()).keep();
let trg = cx.tensor::<R2<20, 32>>().set(trg_data.clone()).keep();
let mut out1 = model.forward((trg, enc)).retrieve();
cx.compile(CPUCompiler::default(), &mut out1);

cx.execute_no_delete();

let param_dict = ParamDictSaver.save(&model, &mut cx);
let out1 = out1.data();

let mut cx = Graph::new();
let model: Transformer<32, 5, 4, 4, 3, 2> = InitModule::initialize(&mut cx);
ParamDictLoader::new(param_dict).load(&model, &mut cx);
let enc = cx.tensor::<R2<24, 32>>().set(enc_data);
let trg = cx.tensor::<R2<20, 32>>().set(trg_data);
let mut out2 = model.forward((trg, enc)).retrieve();

cx.compile(CPUCompiler::default(), &mut out2);
cx.execute();

assert_close(&out1, &out2.data());
}
}
1 change: 1 addition & 0 deletions crates/luminal_training/Cargo.toml
@@ -14,3 +14,4 @@ rustc-hash = "1.1.0"
dfdx = { version = "0.13", features = ["f16"] }
paste = "1.0.14"
rand = "0.8.5"
luminal_nn = { path = "../luminal_nn" }