diff --git a/crates/luminal_nn/src/convolution.rs b/crates/luminal_nn/src/convolution.rs index 29fce529..7087e4cf 100644 --- a/crates/luminal_nn/src/convolution.rs +++ b/crates/luminal_nn/src/convolution.rs @@ -2,32 +2,29 @@ use luminal::prelude::*; use rand::{thread_rng, Rng}; pub struct Conv1D< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNEL: usize, - const STRIDE: usize, - const DILATION: usize, - const CHANNELS_IN_TIMES_KERNEL: usize, + const STRIDE: usize = KERNEL, + const DILATION: usize = 0, > { - pub weight: GraphTensor>, + pub weight: GraphTensor>, } impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNEL: usize, const STRIDE: usize, const DILATION: usize, - const CHANNELS_IN_TIMES_KERNEL: usize, - > InitModule - for Conv1D + > InitModule for Conv1D { fn initialize(cx: &mut Graph) -> Self { // Init weight as uniform(-1, 1) let mut rng = thread_rng(); Self { weight: cx.named_tensor("Weight").set( - (0..(CHANNELS_IN * CHANNELS_OUT * KERNEL)) + (0..(CH_IN * CH_OUT * KERNEL)) .map(|_| rng.gen_range(-1_f32..1_f32)) .collect::>(), ), @@ -36,14 +33,12 @@ impl< } impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNEL: usize, const STRIDE: usize, const DILATION: usize, - const CHANNELS_IN_TIMES_KERNEL: usize, - > SerializeModule - for Conv1D + > SerializeModule for Conv1D { fn serialize(&self, s: &mut luminal::module::Serializer) { s.tensor("weight", self.weight); @@ -52,74 +47,63 @@ impl< // Single impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNEL: usize, const STRIDE: usize, const DILATION: usize, - const CHANNELS_IN_TIMES_KERNEL: usize, - > Conv1D + > Conv1D { pub fn forward( &self, - input: GraphTensor>, - ) -> GraphTensor> { - self.weight.matmul( - input - .pool_last_dim::>( - KERNEL.into(), - STRIDE.into(), - DILATION, - ) - .permute::<_, Axes3<0, 2, 1>>() - .reshape::>(), - ) + input: GraphTensor>, + ) -> GraphTensor> { + self.weight + .dyn_reshape::<(Const, Dyn<'-'>)>(vec![CH_OUT.into(), (CH_IN * KERNEL).into()]) + .matmul( + input + .pool_last_dim::>( + KERNEL.into(), + STRIDE.into(), + DILATION, + ) + .permute::<_, Axes3<0, 2, 1>>() + .dyn_reshape(vec![(CH_IN * KERNEL).into(), DIM_OUT.into()]), + ) } } pub struct Conv2D< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNELX: usize, const KERNELY: usize, - const STRIDEX: usize, - const STRIDEY: usize, - const DILATIONX: usize, - const DILATIONY: usize, - const CHANNELS_IN_TIMES_KERNELX_KERNELY: usize, + const STRIDEX: usize = KERNELX, + const STRIDEY: usize = KERNELY, + const DILATIONX: usize = 0, + const DILATIONY: usize = 0, > { - pub weight: GraphTensor>, + pub weight: GraphTensor>, } impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNELX: usize, const KERNELY: usize, const STRIDEX: usize, const STRIDEY: usize, const DILATIONX: usize, const DILATIONY: usize, - const CHANNELS_IN_TIMES_KERNELX_KERNELY: usize, > InitModule - for Conv2D< - CHANNELS_IN, - CHANNELS_OUT, - KERNELX, - KERNELY, - STRIDEX, - STRIDEY, - DILATIONX, - DILATIONY, - CHANNELS_IN_TIMES_KERNELX_KERNELY, - > + for Conv2D { fn initialize(cx: &mut Graph) -> Self { // Init weight as uniform(-1, 1) let mut rng = thread_rng(); Self { weight: cx.named_tensor("Weight").set( - (0..(CHANNELS_IN * CHANNELS_OUT * KERNELX * KERNELY)) + (0..(CH_IN * CH_OUT * KERNELX * KERNELY)) .map(|_| rng.gen_range(-1_f32..1_f32)) .collect::>(), ), @@ -128,27 +112,16 @@ impl< } impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNELX: usize, const KERNELY: usize, const STRIDEX: usize, const STRIDEY: usize, const DILATIONX: usize, const DILATIONY: usize, - const CHANNELS_IN_TIMES_KERNELX_KERNELY: usize, > SerializeModule - for Conv2D< - CHANNELS_IN, - CHANNELS_OUT, - KERNELX, - KERNELY, - STRIDEX, - STRIDEY, - DILATIONX, - DILATIONY, - CHANNELS_IN_TIMES_KERNELX_KERNELY, - > + for Conv2D { fn serialize(&self, s: &mut luminal::module::Serializer) { s.tensor("weight", self.weight); @@ -157,56 +130,50 @@ impl< // Single impl< - const CHANNELS_IN: usize, - const CHANNELS_OUT: usize, + const CH_IN: usize, + const CH_OUT: usize, const KERNELX: usize, const KERNELY: usize, const STRIDEX: usize, const STRIDEY: usize, const DILATIONX: usize, const DILATIONY: usize, - const CHANNELS_IN_TIMES_KERNELX_KERNELY: usize, - > - Conv2D< - CHANNELS_IN, - CHANNELS_OUT, - KERNELX, - KERNELY, - STRIDEX, - STRIDEY, - DILATIONX, - DILATIONY, - CHANNELS_IN_TIMES_KERNELX_KERNELY, - > + > Conv2D { pub fn forward< const DIMX_IN: usize, const DIMY_IN: usize, const DIMX_OUT: usize, const DIMY_OUT: usize, - const DIMX_TIMES_DIMY_OUT: usize, >( &self, - input: GraphTensor>, - ) -> GraphTensor> { + input: GraphTensor>, + ) -> GraphTensor> { let input_pooled = input - .pool_last_dim::>( + .pool_last_dim::>( KERNELY.into(), STRIDEY.into(), DILATIONY, ) .permute::<_, Axes4<0, 2, 3, 1>>() - .pool_last_dim::>( + .pool_last_dim::>( KERNELX.into(), STRIDEX.into(), DILATIONX, ) .permute::<_, Axes5<0, 4, 2, 3, 1>>() - .reshape::>(); + .dyn_reshape::<(_, Dyn<'-'>)>(vec![ + (CH_IN * KERNELX * KERNELY).into(), + (DIMX_OUT * DIMY_OUT).into(), + ]); self.weight + .dyn_reshape::<(Const, Dyn<'-'>)>(vec![ + CH_OUT.into(), + (CH_IN * KERNELX * KERNELY).into(), + ]) .matmul(input_pooled) - .reshape::>() + .reshape::>() } } @@ -219,27 +186,19 @@ mod tests { fn test_conv1d_simple() { let mut cx = Graph::new(); - const CHANNELS_IN: usize = 1; - const CHANNELS_OUT: usize = 1; + const CH_IN: usize = 1; + const CH_OUT: usize = 1; const KERNEL: usize = 2; - const STRIDE: usize = 2; - const DILATION: usize = 0; + const STRIDE: usize = KERNEL; const DIM_IN: usize = 6; - const DIM_OUT: usize = ((DIM_IN - (DILATION + 1) * (KERNEL - 1) - 1) / STRIDE) + 1; - const CHANNELS_IN_TIMES_KERNEL: usize = CHANNELS_IN * KERNEL; + const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1; - let model: Conv1D< - CHANNELS_IN, - CHANNELS_OUT, - KERNEL, - KERNEL, - DILATION, - CHANNELS_IN_TIMES_KERNEL, - > = Conv1D::initialize(&mut cx); - model.weight.set(vec![0.0316, -0.2057]); + let model = Conv1D::::initialize(&mut cx); + model.weight.set([[[0.0316, -0.2057]]]); - let inp1 = cx.tensor::>(); - inp1.set(vec![3., 0., 9., 6., 0., 6.]); + let inp1 = cx + .tensor::>() + .set([[3., 0., 9., 6., 0., 6.]]); let out1 = model.forward::(inp1).retrieve(); cx.execute(); @@ -251,23 +210,14 @@ mod tests { fn test_conv1d() { let mut cx = Graph::new(); - const CHANNELS_IN: usize = 8; - const CHANNELS_OUT: usize = 4; + const CH_IN: usize = 8; + const CH_OUT: usize = 4; const KERNEL: usize = 2; const STRIDE: usize = 2; - const DILATION: usize = 0; const DIM_IN: usize = 12; - const DIM_OUT: usize = ((DIM_IN - (DILATION + 1) * (KERNEL - 1) - 1) / STRIDE) + 1; - const CHANNELS_IN_TIMES_KERNEL: usize = CHANNELS_IN * KERNEL; + const DIM_OUT: usize = ((DIM_IN - (KERNEL - 1) - 1) / STRIDE) + 1; - let model: Conv1D< - CHANNELS_IN, - CHANNELS_OUT, - KERNEL, - KERNEL, - DILATION, - CHANNELS_IN_TIMES_KERNEL, - > = Conv1D::initialize(&mut cx); + let model = Conv1D::::initialize(&mut cx); model.weight.set(vec![ -0.1700, -0.2000, 0.1000, -0.0200, 0.1000, 0.0200, -0.2100, -0.2300, -0.0600, 0.1500, 0.1200, 0.1000, 0.1800, 0.0600, -0.1700, -0.0400, 0.1000, -0.0200, -0.1700, 0.1000, @@ -278,7 +228,7 @@ mod tests { 0.0700, -0.1200, 0.1400, 0.2200, ]); - let inp1 = cx.tensor::>(); + let inp1 = cx.tensor::>(); inp1.set(vec![ 1., 2., 6., 4., 8., 1., 6., 0., 1., 0., 6., 4., 3., 4., 9., 3., 8., 8., 5., 5., 0., 4., 2., 7., 6., 4., 2., 2., 8., 0., 7., 3., 0., 0., 7., 2., 3., 3., 1., 9., 5., 4., 5., 5., @@ -305,22 +255,20 @@ mod tests { fn test_conv2d() { let mut cx = Graph::new(); - const CHANNELS_IN: usize = 5; - const CHANNELS_OUT: usize = 2; + const CH_IN: usize = 5; + const CH_OUT: usize = 2; const KERNELX: usize = 2; const KERNELY: usize = 2; - const STRIDEX: usize = 2; - const STRIDEY: usize = 2; + const STRIDEX: usize = KERNELX; + const STRIDEY: usize = KERNELY; const DILATIONX: usize = 0; const DILATIONY: usize = 0; const DIMX_IN: usize = 16; const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1; const DIMY_IN: usize = 9; const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1; - const DIMX_TIMES_DIMY_OUT: usize = DIMX_OUT * DIMY_OUT; - const CHANNELS_IN_TIMES_KERNELX_KERNELY: usize = CHANNELS_IN * KERNELX * KERNELY; - let inp1 = cx.tensor::>(); + let inp1 = cx.tensor::>(); inp1.set(vec![ 8., 8., 5., 7., 0., 6., 5., 3., 0., 7., 0., 6., 6., 7., 7., 5., 0., 6., 9., 4., 0., 8., 8., 5., 7., 6., 2., 8., 9., 5., 0., 3., 1., 1., 8., 4., 1., 1., 5., 6., 9., 3., 2., 9., @@ -357,7 +305,7 @@ mod tests { 3., 1., 5., 9., 1., 6., 5., 4., 2., 1., 2., 1., 1., 4., 7., 2., ]); - let exp_out1 = cx.tensor::>(); + let exp_out1 = cx.tensor::>(); exp_out1.set(vec![ 3.9600, -0.3300, -1.7800, 4.0400, 1.5300, 0.2900, 2.8700, 3.0000, 0.9600, -1.8700, 4.5900, 3.9700, 1.2800, 1.1800, 3.7800, 2.8500, 0.5500, 0.5600, 3.9800, 1.3200, @@ -370,17 +318,7 @@ mod tests { exp_out1.retrieve(); - let model: Conv2D< - CHANNELS_IN, - CHANNELS_OUT, - KERNELX, - KERNELY, - STRIDEX, - STRIDEY, - DILATIONX, - DILATIONY, - CHANNELS_IN_TIMES_KERNELX_KERNELY, - > = Conv2D::initialize(&mut cx); + let model: Conv2D = Conv2D::initialize(&mut cx); model.weight.set(vec![ 0.1600, 0.2000, 0.1900, -0.1100, 0.0100, -0.0300, -0.1200, -0.0800, -0.1300, -0.0300, 0.1600, -0.1700, -0.0000, 0.1900, 0.1300, 0.0300, -0.1500, 0.0900, 0.0100, 0.0200, @@ -389,7 +327,7 @@ mod tests { ]); let out1 = model - .forward::(inp1) + .forward::(inp1) .retrieve(); cx.execute();