From 3d2dd74fde5e7037589fe4d91b884a4f90bbe191 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 20 Jun 2024 12:31:45 -0400
Subject: [PATCH 1/4] add test (failing)

---
 crates/luminal_training/src/autograd.rs | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/crates/luminal_training/src/autograd.rs b/crates/luminal_training/src/autograd.rs
index 42470d8b..f2a6934d 100644
--- a/crates/luminal_training/src/autograd.rs
+++ b/crates/luminal_training/src/autograd.rs
@@ -552,4 +552,20 @@ mod tests {
             .as_vec(),
         );
     }
+
+    #[test]
+    fn test_add_grad_decreasing_idx() {
+        let mut cx = Graph::new();
+        let a: GraphTensor<R1<3>> = cx.tensor();
+        let a: GraphTensor<R3<3, 3, 3>> = a.expand::<_, LAxes2<0, 1>>();
+        let a: GraphTensor<R3<3, 3, 3>> = a.permute::<_, LAxes3<2, 1, 0>>();
+        // a.shape.fake = [false, true, true]
+        // a.shape.indexes = [0, 2, 1] // note that the idx isn't necessarily increasing (0,1,2)
+        let b: GraphTensor<R3<3, 3, 3>> = cx.tensor();
+        let weights = vec![a.id, b.id];
+
+        let m: GraphTensor<R3<3, 3, 3>> = a * b;
+        let loss: GraphTensor<R0> = m.sum_reduce();
+        let _grads = cx.compile(Autograd::new(weights, loss), ());
+    }
 }

From 0ac8515733e4f24e44b9c3dca2fd300e62838da0 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 20 Jun 2024 15:29:46 -0400
Subject: [PATCH 2/4] simplify and add a new test

---
 crates/luminal_training/src/autograd.rs | 33 ++++++++++++++++++++++++---------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/crates/luminal_training/src/autograd.rs b/crates/luminal_training/src/autograd.rs
index f2a6934d..892b4f39 100644
--- a/crates/luminal_training/src/autograd.rs
+++ b/crates/luminal_training/src/autograd.rs
@@ -554,18 +554,33 @@ mod tests {
     }
 
     #[test]
-    fn test_add_grad_decreasing_idx() {
+    fn test_add_grad_decreasing_idx_r1() {
         let mut cx = Graph::new();
         let a: GraphTensor<R1<3>> = cx.tensor();
         let a: GraphTensor<R3<3, 3, 3>> = a.expand::<_, LAxes2<0, 1>>();
         let a: GraphTensor<R3<3, 3, 3>> = a.permute::<_, LAxes3<2, 1, 0>>();
-        // a.shape.fake = [false, true, true]
-        // a.shape.indexes = [0, 2, 1] // note that the idx isn't necessarily increasing (0,1,2)
-        let b: GraphTensor<R3<3, 3, 3>> = cx.tensor();
-        let weights = vec![a.id, b.id];
-
-        let m: GraphTensor<R3<3, 3, 3>> = a * b;
-        let loss: GraphTensor<R0> = m.sum_reduce();
-        let _grads = cx.compile(Autograd::new(weights, loss), ());
+        assert_eq!(&a.shape.fake[..], &[false, true, true]); // has multiple fake dimensions
+        assert_eq!(&a.shape.indexes[..], &[0, 2, 1]); // not strictly increasing
+
+        // note: this tests the case when the rev indexes may decrease (2 -> 0) after increasing (1 -> 2)
+
+        let loss: GraphTensor<R0> = a.sum_reduce();
+        let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
+    }
+
+    #[test]
+    fn test_add_grad_decreasing_idx_r2() {
+        let mut cx = Graph::new();
+        let a: GraphTensor<R2<3, 3>> = cx.tensor();
+        let a: GraphTensor<R5<3, 3, 3, 3, 3>> = a.expand::<_, LAxes3<1, 2, 3>>();
+        let a: GraphTensor<R5<3, 3, 3, 3, 3>> = a.permute::<_, LAxes5<4, 1, 0, 3, 2>>();
+        assert_eq!(&a.shape.fake[..], &[false, false, true, true, true]); // has multiple fake dimensions
+        assert_eq!(&a.shape.indexes[..], &[1, 2, 0, 4, 3]); // not strictly increasing
+
+        // note: the difference in this test to test_add_grad_decreasing_idx_r1
+        // is that the rev indexes may increase (0 -> 2) after it has decreased (4 -> 0) after it has increased (3 -> 4)
+
+        let loss: GraphTensor<R0> = a.sum_reduce();
+        let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
     }
 }
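Note on the failure mode the tests above exercise: `add_grad` undoes each expand
by emitting a `SumReduce(i)` and calling `grad.shape.remove_dim(i)`. Removing a
dimension shifts every later dimension one position to the left, so walking
`fwd.shape.indexes` in reverse only yields valid removal positions while each
visited index is smaller than the ones already removed; once a later visit
targets a larger index (here, 2 after 1), the recorded position is stale. A
minimal standalone sketch of the shift, using a plain Vec as a stand-in rather
than luminal's ShapeTracker (names are illustrative only):

    fn main() {
        // Mirrors the r1 test: fake = [false, true, true], indexes = [0, 2, 1],
        // so the reversed walk visits fake dim 1 first, then fake dim 2.
        let mut dims = vec!["real", "fake1", "fake2"];

        // Removing dim 1 shifts "fake2" from position 2 down to position 1.
        dims.remove(1);
        assert_eq!(dims, ["real", "fake2"]);

        // Reusing the original index 2 would now panic in this Vec model
        // (index out of bounds); the valid position is 2 - 1 = 1, compensating
        // for the one dim already removed at a smaller index.
        dims.remove(2 - 1);
        assert_eq!(dims, ["real"]);
    }

PATCH 3 below adds exactly this kind of compensation.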
From e0196e61189a71e2e2d2e8262699e885a67f26b0 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 20 Jun 2024 15:31:47 -0400
Subject: [PATCH 3/4] try fix

---
 crates/luminal_training/src/autograd.rs | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/crates/luminal_training/src/autograd.rs b/crates/luminal_training/src/autograd.rs
index 892b4f39..94e72124 100644
--- a/crates/luminal_training/src/autograd.rs
+++ b/crates/luminal_training/src/autograd.rs
@@ -199,14 +199,33 @@ fn add_grad(
     grad.shape.indexes = new_indexes;
 
     // Undo expands (sum reduce)
+    let mut min_idx_removed: Option<usize> = None;
+    let mut min_idx_removed_amount: Option<usize> = None;
+    // TODO: the rev() may no longer be required
     for i in fwd.shape.indexes.into_iter().rev() {
         if fwd.shape.fake[i] {
+            min_idx_removed = if let Some(prev) = min_idx_removed {
+                Some(prev.min(i))
+            } else {
+                Some(i)
+            };
+            let min_idx_removed = min_idx_removed.unwrap_or_default();
+            let i_diff = if i > min_idx_removed {
+                min_idx_removed_amount.unwrap_or_default()
+            } else {
+                0
+            };
             grad.id = graph
-                .add_op(SumReduce(i))
+                .add_op(SumReduce(i - i_diff))
                 .input(grad.id, 0, grad.shape)
                 .finish();
-            grad.shape.remove_dim(i);
+            grad.shape.remove_dim(i - i_diff);
             grad.shape = grad.shape.contiguous();
+            min_idx_removed_amount = if let Some(prev) = min_idx_removed_amount {
+                Some(prev + 1)
+            } else {
+                Some(1)
+            };
         }
     }
 

From 6913f5e9ea1953d49e32809db8c418311c3418e3 Mon Sep 17 00:00:00 2001
From: Thiago Machado
Date: Thu, 20 Jun 2024 15:37:35 -0400
Subject: [PATCH 4/4] improve comments

---
 crates/luminal_training/src/autograd.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/crates/luminal_training/src/autograd.rs b/crates/luminal_training/src/autograd.rs
index 94e72124..f5af799e 100644
--- a/crates/luminal_training/src/autograd.rs
+++ b/crates/luminal_training/src/autograd.rs
@@ -581,7 +581,7 @@ mod tests {
         assert_eq!(&a.shape.fake[..], &[false, true, true]); // has multiple fake dimensions
         assert_eq!(&a.shape.indexes[..], &[0, 2, 1]); // not strictly increasing
 
-        // note: this tests the case when the rev indexes may decrease (2 -> 0) after increasing (1 -> 2)
+        // note: this tests the case when the rev indexes may decrease (0 <- 2) after increasing (2 <- 1)
 
         let loss: GraphTensor<R0> = a.sum_reduce();
         let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
@@ -597,7 +597,7 @@ mod tests {
         assert_eq!(&a.shape.indexes[..], &[1, 2, 0, 4, 3]); // not strictly increasing
 
         // note: the difference in this test to test_add_grad_decreasing_idx_r1
-        // is that the rev indexes may increase (0 -> 2) after it has decreased (4 -> 0) after it has increased (3 -> 4)
+        // is that the rev indexes may increase (2 <- 0) after it has decreased (0 <- 4) after it has increased (4 <- 3)
 
         let loss: GraphTensor<R0> = a.sum_reduce();
         let _grads = cx.compile(Autograd::new(vec![a.id], loss), ());
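For reference, here is the bookkeeping added in PATCH 3 extracted into a
standalone pure function and checked against the visit orders of both tests.
This is a simplification of mine, not code from the repository: a plain counter
replaces the patch's Option<usize> min_idx_removed_amount, which starts at None
and is read with unwrap_or_default(), i.e. 0.

    /// Given the fake-dim indexes in the order the reversed loop visits them,
    /// return the position actually passed to SumReduce/remove_dim at each step.
    fn removal_positions(visited: &[usize]) -> Vec<usize> {
        let mut min_idx_removed: Option<usize> = None; // smallest index removed so far
        let mut removed_so_far = 0; // total removals so far
        let mut out = Vec::new();
        for &i in visited {
            let min_idx = min_idx_removed.map_or(i, |m| m.min(i));
            min_idx_removed = Some(min_idx);
            // As in the patch: if some earlier removal happened at a smaller
            // index, shift i down by the total number of removals so far
            // (this assumes all of them were below i); otherwise remove at i.
            let i_diff = if i > min_idx { removed_so_far } else { 0 };
            out.push(i - i_diff);
            removed_so_far += 1;
        }
        out
    }

    fn main() {
        // r1 test: indexes [0, 2, 1] reversed visit fake dims 1, then 2.
        assert_eq!(removal_positions(&[1, 2]), vec![1, 1]);
        // r2 test: indexes [1, 2, 0, 4, 3] reversed visit fake dims 3, 4, 2.
        assert_eq!(removal_positions(&[3, 4, 2]), vec![3, 3, 2]);
    }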