Skip to content

Commit

Permalink
conv3d dim. missmatch resolved
Browse files Browse the repository at this point in the history
  • Loading branch information
NewBornRustacean committed May 3, 2024
1 parent ffb2baa commit a73c2b3
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions crates/luminal_nn/src/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ pub struct Conv3D<
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize
> {
pub weight: GraphTensor<R5<CH_OUT, CH_IN, KERNELX, KERNELY, KERNELZ>>,
Expand All @@ -205,6 +206,7 @@ impl<
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize
> InitModule
for Conv3D<
Expand All @@ -219,6 +221,7 @@ impl<
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
Expand Down Expand Up @@ -247,6 +250,7 @@ impl<
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
> SerializeModule
for Conv3D<
Expand All @@ -261,6 +265,7 @@ impl<
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
Expand All @@ -281,6 +286,7 @@ impl<
const DILATIONX: usize,
const DILATIONY: usize,
const DILATIONZ: usize,
const DIMX_TIMES_KERNELX: usize,
const DIMX_TIMES_DIMY_DIMZ_OUT: usize,
>
Conv3D<
Expand All @@ -295,6 +301,7 @@ impl<
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
>
{
Expand All @@ -315,20 +322,29 @@ impl<
STRIDEY.into(),
DILATIONY
)
.permute::<_, Axes5<0, 2, 4, 1, 3>>()
.pool_last_dim::<R6<CH_IN, DIMY_OUT, KERNELY, DIMZ_OUT, DIMX_OUT, KERNELX>>(
.permute::<_, Axes5<0, 2, 3, 4, 1>>()
.pool_last_dim::<R6<CH_IN, DIMY_OUT, DIMZ_OUT, KERNELY, DIMX_OUT, KERNELX>>(
KERNELX.into(),
STRIDEX.into(),
DILATIONX
)
.permute::<_, Axes6<0, 5, 2, 4, 3, 1>>()
.reshape::<R4<CH_IN, KERNELX, KERNELY, DIMX_TIMES_DIMY_DIMZ_OUT>>()
.pool_last_dim::<R5<CH_IN, KERNELX, KERNELY, DIMX_TIMES_DIMY_DIMZ_OUT, KERNELZ>>(
.dyn_reshape::<(Const<CH_IN>, Dyn<'-'>)>(vec![
CH_IN.into(),
DIMZ_OUT.into(),
KERNELY.into(),
(DIMX_OUT * KERNELX).into(),
DIMY_IN.into()
]);

let last_pool = input_pooled
.pool_last_dim::<R6<CH_IN, DIMZ_OUT, KERNELY, DIMX_TIMES_KERNELX, DIMY_IN, KERNELZ>>(
KERNELZ.into(),
STRIDEZ.into(),
DILATIONZ
)
.permute::<_, Axes5<0, 4, 1, 2, 3>>()
.permute::<_, Axes6<0, 2, 5, 3, 1, 4>>();

let reshaped = last_pool
.dyn_reshape::<(_, Dyn<'-'>)>(vec![
(CH_IN * KERNELX * KERNELY * KERNELZ).into(),
DIMX_TIMES_DIMY_DIMZ_OUT.into(),
Expand All @@ -339,7 +355,7 @@ impl<
CH_OUT.into(),
(CH_IN * KERNELX * KERNELY * KERNELZ).into(),
])
.matmul(input_pooled)
.matmul(reshaped)
.reshape::<R4<CH_OUT, DIMX_OUT, DIMY_OUT, DIMZ_OUT>>()
}
}
Expand Down Expand Up @@ -524,8 +540,11 @@ mod tests {
const DIMX_OUT: usize = ((DIMX_IN - (DILATIONX + 1) * (KERNELX - 1) - 1) / STRIDEX) + 1;
const DIMY_OUT: usize = ((DIMY_IN - (DILATIONY + 1) * (KERNELY - 1) - 1) / STRIDEY) + 1;
const DIMZ_OUT: usize = ((DIMZ_IN - (DILATIONZ + 1) * (KERNELZ - 1) - 1) / STRIDEZ) + 1;
const DIMX_TIMES_KERNELX: usize = DIMX_OUT * KERNELX;
const DIMX_TIMES_DIMY_DIMZ_OUT:usize = DIMX_OUT * DIMY_OUT * DIMZ_OUT;

println!("DIMX_OUT, DIMY_OUT , DIMZ_OUT: {:?}, {:?}, {:?}",DIMX_OUT, DIMY_OUT , DIMZ_OUT);

let inp1 = cx.tensor::<R4<CH_IN, DIMX_IN, DIMY_IN, DIMZ_IN>>();
inp1.set(vec![
// Example input data (5 channels, 2x3x5 volume)
Expand Down Expand Up @@ -558,6 +577,7 @@ mod tests {
DILATIONX,
DILATIONY,
DILATIONZ,
DIMX_TIMES_KERNELX,
DIMX_TIMES_DIMY_DIMZ_OUT,
> = Conv3D::initialize(&mut cx);
let weights = vec![
Expand Down

0 comments on commit a73c2b3

Please sign in to comment.