diff --git a/crates/luminal_training/src/autograd.rs b/crates/luminal_training/src/autograd.rs
index 42470d8b..f5af799e 100644
--- a/crates/luminal_training/src/autograd.rs
+++ b/crates/luminal_training/src/autograd.rs
@@ -199,14 +199,33 @@ fn add_grad(
     grad.shape.indexes = new_indexes;
 
     // Undo expands (sum reduce)
+    let mut min_idx_removed: Option<usize> = None;
+    let mut min_idx_removed_amount: Option<usize> = None;
+    // TODO: the rev() may no longer be required
     for i in fwd.shape.indexes.into_iter().rev() {
         if fwd.shape.fake[i] {
+            min_idx_removed = if let Some(prev) = min_idx_removed {
+                Some(prev.min(i))
+            } else {
+                Some(i)
+            };
+            let min_idx_removed = min_idx_removed.unwrap_or_default();
+            let i_diff = if i > min_idx_removed {
+                min_idx_removed_amount.unwrap_or_default()
+            } else {
+                0
+            };
             grad.id = graph
-                .add_op(SumReduce(i))
+                .add_op(SumReduce(i - i_diff))
                 .input(grad.id, 0, grad.shape)
                 .finish();
-            grad.shape.remove_dim(i);
+            grad.shape.remove_dim(i - i_diff);
             grad.shape = grad.shape.contiguous();
+            min_idx_removed_amount = if let Some(prev) = min_idx_removed_amount {
+                Some(prev + 1)
+            } else {
+                Some(1)
+            };
         }
     }
 
@@ -552,4 +571,35 @@ mod tests {
             .as_vec(),
         );
     }
+
+    #[test]
+    fn test_add_grad_decreasing_idx_r1() {
+        let mut cx = Graph::new();
+        let a: GraphTensor<R1<2>> = cx.tensor();
+        let a: GraphTensor<R3<2, 2, 2>> = a.expand::<_, LAxes2<0, 1>>();
+        let a: GraphTensor<R3<2, 2, 2>> = a.permute::<_, LAxes3<2, 1, 0>>();
+        assert_eq!(&a.shape.fake[..], &[false, true, true]); // has multiple fake dimensions
+        assert_eq!(&a.shape.indexes[..], &[0, 2, 1]); // not strictly increasing
+
+        // note: this tests the case where the rev indexes may decrease (0 <- 2) after increasing (2 <- 1)
+
+        let loss: GraphTensor<R0> = a.sum_reduce();
+        let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
+    }
+
+    #[test]
+    fn test_add_grad_decreasing_idx_r2() {
+        let mut cx = Graph::new();
+        let a: GraphTensor<R2<2, 2>> = cx.tensor();
+        let a: GraphTensor<R5<2, 2, 2, 2, 2>> = a.expand::<_, LAxes3<1, 2, 3>>();
+        let a: GraphTensor<R5<2, 2, 2, 2, 2>> = a.permute::<_, LAxes5<4, 1, 0, 3, 2>>();
+        assert_eq!(&a.shape.fake[..], &[false, false, true, true, true]); // has multiple fake dimensions
+        assert_eq!(&a.shape.indexes[..], &[1, 2, 0, 4, 3]); // not strictly increasing
+
+        // note: the difference from test_add_grad_decreasing_idx_r1 is that the rev
+        // indexes may increase (2 <- 0) after decreasing (0 <- 4) after increasing (4 <- 3)
+
+        let loss: GraphTensor<R0> = a.sum_reduce();
+        let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
+    }
 }
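
Note (reviewer addition, not part of the patch): each SumReduce removes one
dimension from grad's shape, so a fake dimension whose physical index lies
above an already-removed index has shifted left by the number of prior
removals below it; the patch compensates for this via min_idx_removed and
min_idx_removed_amount. Below is a minimal standalone sketch of the general
compensation, using a hypothetical helper name that is not part of the luminal
API (the concrete dimension sizes of 2 in the reconstructed tests above are
likewise assumptions; the shape assertions hold for any sizes):

    /// Given physical dim indexes in removal order, return the index to pass
    /// to each removal, compensating for the left shift caused by earlier
    /// removals at lower indexes. (Hypothetical helper, for illustration.)
    fn adjusted_removal_indexes(removal_order: &[usize]) -> Vec<usize> {
        let mut removed: Vec<usize> = Vec::new(); // physical indexes already removed
        removal_order
            .iter()
            .map(|&i| {
                // every prior removal below `i` shifts it left by one
                let shift = removed.iter().filter(|&&r| r < i).count();
                removed.push(i);
                i - shift
            })
            .collect()
    }

In test_add_grad_decreasing_idx_r2 the fake dims are visited in the order
[3, 4, 2] (indexes [1, 2, 0, 4, 3] reversed, keeping fake dims only), and the
sketch yields [3, 3, 2], matching the values the patched loop computes as
i - i_diff.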