linear bias error example
swfsql committed Jun 17, 2024
1 parent f61d53f commit 2e8a102
Showing 5 changed files with 94 additions and 11 deletions.
18 changes: 18 additions & 0 deletions examples/bias_test/Cargo.toml
@@ -0,0 +1,18 @@
[package]
name = "bias_test"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
metal = ["dep:luminal_metal"]
cuda = ["dep:luminal_cuda"]

[dependencies]
luminal = { path = "../.." }
luminal_training = { path = "../../crates/luminal_training" }
luminal_nn = { path = "../../crates/luminal_nn" }
rand = "0.8.5"
luminal_metal = { path = "../../crates/luminal_metal", optional = true }
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
51 changes: 51 additions & 0 deletions examples/bias_test/src/main.rs
@@ -0,0 +1,51 @@
use luminal::prelude::*;
use luminal_nn::Linear;
use luminal_training::Autograd;

/// A `Linear<1, 2>` layer with an explicit bias tensor added on top.
pub struct LinearBiased {
    pub linear: Linear<1, 2>,
    pub bias: GraphTensor<R1<2>>,
}

impl SerializeModule for LinearBiased {
    fn serialize(&self, s: &mut luminal::module::Serializer) {
        s.module("linear", &self.linear);
        s.tensor("bias", self.bias);
    }
}

impl InitModule for LinearBiased {
    fn initialize(cx: &mut Graph) -> Self {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        Self {
            linear: Linear::initialize(cx),
            // Initialize the bias with two values drawn uniformly from [-1, 1).
            bias: cx.named_tensor("Bias").set(
                (0..2)
                    .map(|_| rng.gen_range(-1_f32..1_f32))
                    .collect::<Vec<_>>(),
            ),
        }
    }
}

impl Module<GraphTensor<R1<1>>> for LinearBiased {
    type Output = GraphTensor<R1<2>>;
    fn forward(&self, x: GraphTensor<R1<1>>) -> Self::Output {
        let x: GraphTensor<R1<2>> = self.linear.forward(x);
        // Expand the bias to the output shape before adding it.
        let bias: GraphTensor<R1<2>> = self.bias.expand::<R1<2>, Axis<0>>();
        // let bias: GraphTensor<R1<2>> = self.bias.clone();
        x + bias
    }
}

fn main() {
    // Build the graph: one forward pass through the biased linear layer,
    // reduced to a loss.
    let mut cx = Graph::new();
    let model = LinearBiased::initialize(&mut cx);
    let input = cx.tensor::<R1<1>>();
    let output = model.forward(input).retrieve();
    let loss = output.sum_reduce().retrieve();

    // Ask the Autograd compiler for gradients of the loss w.r.t. all parameters.
    let weights = params(&model);
    let _grads = cx.compile(Autograd::new(&weights, loss), ());
}
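
This main.rs is a minimal reproduction of the linear-bias error named in the commit message: a `Linear<1, 2>` layer plus a hand-written bias, pushed through the `Autograd` compiler. The commented-out line in `forward` marks the variant being contrasted, where the stored bias is added directly instead of going through `expand`. A sketch of that alternative `forward`, under the same types (illustrative only, not part of this commit):

```rust
impl Module<GraphTensor<R1<1>>> for LinearBiased {
    type Output = GraphTensor<R1<2>>;
    fn forward(&self, x: GraphTensor<R1<1>>) -> Self::Output {
        // Add the stored bias as-is; both sides are already R1<2>,
        // so no expand step is involved.
        self.linear.forward(x) + self.bias.clone()
    }
}
```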
6 changes: 3 additions & 3 deletions src/hl_ops/reduction.rs
@@ -96,7 +96,7 @@ mod tests {
        let a_data = random_vec(6);
        let a = cx.tensor::<R2<2, 3>>();
        a.set(a_data.clone());
-        let b = a.sum_reduce::<_, LAxis<1>>();
+        let b = a.sum_reduce::<R1<2>, LAxis<1>>();
        b.retrieve();

        cx.execute();
@@ -114,7 +114,7 @@
        let a_data = random_vec(6);
        let a = cx.tensor::<R2<2, 3>>();
        a.set(a_data.clone());
-        let b = a.max_reduce::<_, LAxis<1>>();
+        let b = a.max_reduce::<R1<2>, LAxis<1>>();
        b.retrieve();

        cx.execute();
@@ -132,7 +132,7 @@
        let a_data = random_vec(6);
        let a = cx.tensor::<R2<2, 3>>();
        a.set(a_data.clone());
-        let b = a.mean_reduce::<_, LAxis<1>>();
+        let b = a.mean_reduce::<R1<2>, LAxis<1>>();
        b.retrieve();

        cx.execute();
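
The `_` destination shapes in these tests are spelled out as `R1<2>`, presumably because the blanket `ReduceShapeTo` impl added in `src/shape/broadcast.rs` below makes more than one destination shape satisfy the bound, so inference can no longer pick one. A self-contained sketch of the ambiguity, assuming `sum_reduce::<Dst, Ax>` keeps the signature used in these tests (illustrative only):

```rust
use luminal::prelude::*;

fn main() {
    let mut cx = Graph::new();
    let a = cx.tensor::<R2<2, 3>>().set(vec![0.0; 6]);
    // The usual reduce: drop axis 1, R2<2, 3> -> R1<2>.
    let _reduced = a.sum_reduce::<R1<2>, Axis<1>>().retrieve();
    // Under the blanket impl, an identity-shaped reduce also satisfies the
    // bound, so a bare `_` would no longer pin down the destination shape.
    let _same_shape = a.sum_reduce::<R2<2, 3>, Axis<1>>().retrieve();
}
```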
2 changes: 1 addition & 1 deletion src/shape/broadcast.rs
@@ -12,7 +12,7 @@ pub trait ReduceShape<Ax>: Sized + HasAxes<Ax> + ReduceShapeTo<Self::Reduced, Ax
    type Reduced: Shape + BroadcastShapeTo<Self, Ax>;
}

-impl ReduceShapeTo<(), Axis<0>> for () {}
+impl<S: Shape + HasAxes<AnyAxis>, AnyAxis> ReduceShapeTo<S, AnyAxis> for S {}
impl ReduceShape<Axis<0>> for () {
    type Reduced = ();
}
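
The blanket impl above lets any shape that has a given axis be treated as "reduced" to itself along that axis; previously only `()` along `Axis<0>` had that identity case. A minimal sketch of the kind of call this permits, shaped like the bias in the example above (illustrative only, not part of this commit):

```rust
use luminal::prelude::*;

fn main() {
    let mut cx = Graph::new();
    // A rank-1 tensor of length 2, like the bias tensor in LinearBiased.
    let a = cx.tensor::<R1<2>>().set(vec![1.0, 2.0]);
    // With the blanket impl, reducing R1<2> back to R1<2> along Axis<0>
    // now type-checks; before, only `()` could be "reduced" to its own shape.
    let _b = a.sum_reduce::<R1<2>, Axis<0>>().retrieve();
}
```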
28 changes: 21 additions & 7 deletions src/tests/test_prim.rs
@@ -267,9 +267,15 @@ fn test_sum_reduce() {
    let a = cx
        .tensor::<R3<2, 2, 3>>()
        .set([[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]]);
-    let b = a.sum_reduce::<_, crate::prelude::Axis<1>>().retrieve();
-    let c = a.sum_reduce::<_, crate::prelude::Axis<0>>().retrieve();
-    let d = a.sum_reduce::<_, crate::prelude::Axis<2>>().retrieve();
+    let b = a
+        .sum_reduce::<R2<2, 3>, crate::prelude::Axis<1>>()
+        .retrieve();
+    let c = a
+        .sum_reduce::<R2<2, 3>, crate::prelude::Axis<0>>()
+        .retrieve();
+    let d = a
+        .sum_reduce::<R2<2, 2>, crate::prelude::Axis<2>>()
+        .retrieve();
    cx.execute();

    let d_dev = Cpu::default();
@@ -290,7 +296,9 @@ fn test_sum_reduce2() {
        [[34.4, -96.0, 144.0], [43.0, 560.0, 180.0]],
        [[39.6, -120.0, 180.0], [49.5, 700.0, 225.0]],
    ]]);
-    let b = a.sum_reduce::<_, crate::prelude::Axis<3>>().retrieve();
+    let b = a
+        .sum_reduce::<R3<1, 2, 2>, crate::prelude::Axis<3>>()
+        .retrieve();
    cx.execute();

    let d_dev = Cpu::default();
@@ -309,9 +317,15 @@ fn test_max_reduce() {
    let a = cx
        .tensor::<R3<2, 2, 3>>()
        .set([[[1., 2., 3.], [1., 2., 3.]], [[1., 2., 3.], [1., 2., 3.]]]);
-    let b = a.max_reduce::<_, crate::prelude::Axis<1>>().retrieve();
-    let c = a.max_reduce::<_, crate::prelude::Axis<0>>().retrieve();
-    let d = a.max_reduce::<_, crate::prelude::Axis<2>>().retrieve();
+    let b = a
+        .max_reduce::<R2<2, 3>, crate::prelude::Axis<1>>()
+        .retrieve();
+    let c = a
+        .max_reduce::<R2<2, 3>, crate::prelude::Axis<0>>()
+        .retrieve();
+    let d = a
+        .max_reduce::<R2<2, 2>, crate::prelude::Axis<2>>()
+        .retrieve();
    cx.execute();

    let d_dev = Cpu::default();
