Forward differentiation nan gradient, nonsingular computation #1489
Labels: AD (related to automatic differentiation), optimization (adding or improving optimization methods), question (further information is requested)
Due to recent numerical changes in JAX, or maybe in the optimization routines, there is now a `nan` leak when optimizing some quantities with forward-mode differentiation. The issue does not occur with reverse-mode differentiation.

After debugging, I have identified the source of the `nan` gradient. It is peculiar.
The `nan` comes from a non-singular operation. It is generated by the `v_tau` term here and here.

- Removing `v_tau` from that line alone avoids the issue.
- Setting `_v_tau` to be 1, so that `v_tau` is computed as the length of the bounce integral, the `nan` still persists.

Therefore, the `nan` gradient in forward mode arises from the computation of $d\ell / d\zeta = B / B^\zeta$. I have confirmed that this quantity is well-behaved. Also, this quantity can be computed at the exact same points in other routines, e.g. with $\epsilon_{\text{effective}}$, and there the issue does not occur.
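One way to check the claim that $d\ell / d\zeta = B / B^\zeta$ is well-behaved under both differentiation modes is to compare its forward-mode (`jax.jacfwd`) and reverse-mode (`jax.jacrev`) Jacobians against the analytic quotient rule. The sketch below is an illustration only: `dl_dzeta`, `ratio_of_params`, and `compare_modes` are hypothetical names, and the smooth parameterization of $B$ and $B^\zeta$ by a vector `p` is an assumption standing in for the real field-line computation, not the project's code.

```python
import jax
import jax.numpy as jnp


def dl_dzeta(B, B_sup_zeta):
    # Hypothetical stand-in for d(ell)/d(zeta) = B / B^zeta.
    return B / B_sup_zeta


def ratio_of_params(p):
    # Hypothetical smooth parameterization of |B| and B^zeta by a parameter vector p.
    B = 1.0 + p[0] ** 2               # |B| >= 1, never zero
    B_sup_zeta = 2.0 + jnp.cos(p[1])  # B^zeta in [1, 3], never zero
    return dl_dzeta(B, B_sup_zeta)


def compare_modes(f, x):
    # Forward-mode (jacfwd) and reverse-mode (jacrev) Jacobians of f at x,
    # plus a flag that is True if either Jacobian contains a nan.
    J_fwd = jax.jacfwd(f)(x)
    J_rev = jax.jacrev(f)(x)
    has_nan = bool(jnp.isnan(J_fwd).any()) or bool(jnp.isnan(J_rev).any())
    return J_fwd, J_rev, has_nan


p = jnp.array([0.3, 1.2])
J_fwd, J_rev, has_nan = compare_modes(ratio_of_params, p)

# Analytic quotient rule: d(B / Bz) = dB / Bz - B * dBz / Bz**2,
# with dB/dp0 = 2 p0 and dBz/dp1 = -sin(p1).
B = 1.0 + p[0] ** 2
Bz = 2.0 + jnp.cos(p[1])
J_exact = jnp.array([2.0 * p[0] / Bz, B * jnp.sin(p[1]) / Bz**2])

print(has_nan)
print(jnp.allclose(J_fwd, J_rev, atol=1e-6))
print(jnp.allclose(J_fwd, J_exact, atol=1e-5))
```

Applied to the real function that produces `v_tau`, a check like `compare_modes` would show whether the `nan` appears only in `jacfwd`, which points at tangent propagation rather than the primal value. One mechanism that may be worth ruling out is an intermediate whose local derivative is infinite while its tangent is zero: $0 \cdot \infty$ is `nan` in IEEE arithmetic even when the final quantity is smooth.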
FYI, I have found the JAX `nan` debugging tools useful before, but for this problem they are not. Maybe this is related to jax-ml/jax#25519. Below is an MWE.