Forward differentiation nan gradient, nonsingular computation #1489

Open
unalmis opened this issue Dec 22, 2024 · 1 comment
Labels
AD related to automatic differentation optimization Adding or improving optimization methods question Further information is requested

Comments

@unalmis
Collaborator

unalmis commented Dec 22, 2024

Due to recent numerical changes in JAX, or possibly in the optimization routines, there is now a nan leak when optimizing some quantities with forward-mode differentiation. The issue does not occur with reverse-mode differentiation.

After debugging, I have identified the source of the nan gradient. It is peculiar.

  • The nan only occurs in forward mode.
  • All computations are well-behaved. In particular, let $f_i$ denote the operations whose composition defines the computation of interest $F$. Every $f_i$ outputs finite values, and its inputs keep a numerical buffer from the boundary of the domain on which $f_i$ is defined. That is, there is no computation like $f_i = \arccos(1 - \epsilon)$ with small $\epsilon \geq 0$, which would be problematic because the gradient $\nabla f_i$ blows up as $\epsilon \to 0$ and can then turn into nan.
  • The nan originates in a non-singular operation.
  • It appears on iteration 0 of the optimization.

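For reference, the kind of singular step the second bullet rules out is easy to reproduce in plain JAX. The sketch below (the helper names are illustrative, not from DESC) pushes a unit tangent through `jnp.arccos` with `jax.jvp`. The derivative $-1/\sqrt{1-x^2}$ is finite at $x = 1 - \epsilon$ but not at $x = 1$, and multiplying that infinite tangent by zero further downstream yields a nan tangent even though the primal value stays finite.

```python
import jax
import jax.numpy as jnp


def arccos_tangent(x):
    """Forward-mode derivative of arccos at x, computed with jax.jvp."""
    x = jnp.asarray(x, dtype=jnp.float32)
    _, tangent = jax.jvp(jnp.arccos, (x,), (jnp.ones_like(x),))
    return tangent


def masked_arccos_tangent(x):
    """Tangent of 0 * arccos(x): an infinite inner tangent times 0 gives nan."""
    x = jnp.asarray(x, dtype=jnp.float32)
    _, tangent = jax.jvp(lambda y: 0.0 * jnp.arccos(y), (x,), (jnp.ones_like(x),))
    return tangent


print(arccos_tangent(1.0 - 1e-3))   # large but finite
print(arccos_tangent(1.0))          # not finite
print(masked_arccos_tangent(1.0))   # nan, although 0 * arccos(1) = 0 is finite
```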
The nan is generated by the v_tau term here and here.

  1. Removing v_tau from that line alone avoids the issue.
  2. Redefining the integrand _v_tau to be 1, so that v_tau is computed as the length of the bounce integral $\int d\ell$, does not help: the nan persists.

Therefore, the nan gradient in forward mode arises from the computation of $d\ell / d\zeta = B / B^\zeta$. I have confirmed that this quantity is well-behaved. Moreover, other routines, such as the one computing $\epsilon_{\text{effective}}$, evaluate this quantity at exactly the same points, and the issue does not occur there.
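Why $B / B^\zeta$ should be benign can be checked on a toy model: as long as $B^\zeta$ stays bounded away from zero, the forward-mode tangent of $\int B / B^\zeta \, d\zeta$ is finite and agrees with reverse mode. This is a sketch under loudly stated assumptions, not DESC's bounce-integral code: the field $B(\zeta) = 1 + a\cos\zeta$, the component $B^\zeta = 2 + a\sin\zeta$, the fixed interval, and the helper `bounce_length` are all hypothetical, and the 8-point Gauss-Legendre rule only echoes `num_quad=8` from the MWE.

```python
import jax
import jax.numpy as jnp
import numpy as np

# 8-point Gauss-Legendre nodes and weights on [-1, 1].
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(8)


def bounce_length(a, z1=0.0, z2=2.0):
    """Toy length of a bounce integral, int_{z1}^{z2} B / B^zeta dzeta.

    B and B^zeta are hypothetical stand-ins for the real field quantities.
    """
    nodes = jnp.asarray(_NODES)
    weights = jnp.asarray(_WEIGHTS)
    # Map quadrature nodes from [-1, 1] to [z1, z2].
    zeta = 0.5 * (z2 - z1) * nodes + 0.5 * (z2 + z1)
    B = 1.0 + a * jnp.cos(zeta)
    B_sup_zeta = 2.0 + a * jnp.sin(zeta)  # stays positive for |a| < 2
    dl_dzeta = B / B_sup_zeta
    return 0.5 * (z2 - z1) * jnp.sum(weights * dl_dzeta)


a0 = 0.3
fwd = jax.jacfwd(bounce_length)(a0)  # forward mode
rev = jax.jacrev(bounce_length)(a0)  # reverse mode
print(fwd, rev)  # both finite and equal up to round-off
```

If a standalone port of the real `v_tau` computation behaves like this toy in isolation, the nan is more likely introduced by how the surrounding code is transformed than by the quotient itself.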

FYI, I have found the JAX nan debugging tools useful before, but they do not help with this problem. This may be related to jax-ml/jax#25519.
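When `jax.config.update("jax_debug_nans", True)` does not pinpoint the culprit, one manual fallback is to bisect the chain of operations $f_i$ directly: push a tangent through each stage with `jax.jvp` and report the first stage whose output tangent is non-finite. The sketch below assumes the computation can be split into a list of callables; `first_nonfinite_stage` and the three stages are hypothetical placeholders, with a deliberately bad `0 * arccos` stage whose primal output is finite but whose tangent is nan.

```python
import jax
import jax.numpy as jnp


def first_nonfinite_stage(stages, x):
    """Push a unit tangent through stages f_1, ..., f_n one at a time.

    Returns (index, tangent) for the first stage whose output tangent is
    non-finite, or (None, tangent) if every tangent stays finite.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    tangent = jnp.ones_like(x)
    for i, stage in enumerate(stages):
        # jax.jvp returns the stage's primal output and its forward-mode tangent.
        x, tangent = jax.jvp(stage, (x,), (tangent,))
        if not bool(jnp.all(jnp.isfinite(tangent))):
            return i, tangent
    return None, tangent


# Hypothetical stages; every primal output below is finite.
stages = [
    lambda x: x * x,                # 1 -> 1, tangent 2
    lambda y: jnp.sqrt(y),          # 1 -> 1, tangent 1
    lambda y: 0.0 * jnp.arccos(y),  # 1 -> 0, tangent 0 * (-inf) = nan
]

print(first_nonfinite_stage(stages, 1.0))      # flags stage index 2
print(first_nonfinite_stage(stages[:2], 1.0))  # index None: tangents stay finite
```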

Below is an MWE.

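import numpy as np

from desc.examples import get
from desc.grid import LinearGrid
from desc.objectives import (
    FixBoundaryR,
    FixBoundaryZ,
    FixIota,
    FixPressure,
    FixPsi,
    ForceBalance,
    GammaC,
    ObjectiveFunction,
)
from desc.optimize import Optimizer

eq0 = get("W7-X")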
eq1 = eq0.copy()
k = 2  # which modes to unfix
print()
print("---------------------------------------")
print(f"Optimizing boundary modes M, N <= {k}")
print("---------------------------------------")
modes_R = np.vstack(
    (
        [0, 0, 0],
        eq1.surface.R_basis.modes[np.max(np.abs(eq1.surface.R_basis.modes), 1) > k, :],
    )
)
modes_Z = eq1.surface.Z_basis.modes[np.max(np.abs(eq1.surface.Z_basis.modes), 1) > k, :]
constraints = (
    ForceBalance(eq=eq1),
    FixBoundaryR(eq=eq1, modes=modes_R),
    FixBoundaryZ(eq=eq1, modes=modes_Z),
    FixPressure(eq=eq1),
    FixIota(eq=eq1),
    FixPsi(eq=eq1),
)
grid = LinearGrid(
    rho=np.linspace(0.2, 1, 3), M=eq1.M_grid, N=eq1.N_grid, NFP=eq1.NFP, sym=False
)
objective = ObjectiveFunction(
    (
        # Both Gamma_c Nemov and Gamma_c Velasco are affected. Effective ripple is not.
        GammaC(
            # pick anything
            eq1,
            grid=grid,
            X=16,
            Y=32,
            Y_B=128,
            num_transit=5,
            num_well=20 * 10,
            num_quad=8,
            num_pitch=10,
            deriv_mode="fwd"
        ),
    )
)
optimizer = Optimizer("proximal-lsq-exact")
(eq1,), _ = optimizer.optimize(
    eq1,
    objective,
    constraints,
    ftol=1e-4,
    xtol=1e-6,
    gtol=1e-6,
    maxiter=1,  # increase maxiter to 50 for a better result
    verbose=3,
    options={"initial_trust_ratio": 2e-3},
)

@unalmis unalmis added bug optimization Adding or improving optimization methods AD related to automatic differentation labels Dec 22, 2024
@unalmis unalmis changed the title Forward mode differentiation leaks nan gradient from nonsingular computation Forward mode differentiation nan gradient, nonsingular computation Dec 22, 2024
@unalmis unalmis changed the title Forward mode differentiation nan gradient, nonsingular computation Forward differentiation nan gradient, nonsingular computation Dec 22, 2024
@mattjj

mattjj commented Jan 1, 2025

What backend is this on? (CPU, GPU, ...)

If this was indeed a regression due to some change in JAX or XLA (or some lower-level compiler like LLVM), I can try to bisect it against Google's internal monorepo. But to do that it would be really helpful to have a minimal runnable example (i.e., including imports, etc.). EDIT: and the fewer dependencies, the better! That is, if we could get a pure-JAX repro, that'd be easiest to bisect on.
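A pure-JAX repro of the kind requested might start from a small harness like the one below, which records the backend and JAX version and compares forward- and reverse-mode Jacobians of a scalar function. This is only a sketch: `environment`, `check_fwd_vs_rev`, and `F` are hypothetical names, and `F` is a placeholder for a self-contained port of the failing `v_tau` computation, which is not reproduced here.

```python
import jax
import jax.numpy as jnp


def environment():
    """Backend and version info to attach to a bug report."""
    return {"jax": jax.__version__, "backend": jax.default_backend()}


def check_fwd_vs_rev(F, x):
    """Compare jacfwd and jacrev of a scalar-valued function F at x."""
    fwd = jax.jacfwd(F)(x)
    rev = jax.jacrev(F)(x)
    return {
        "fwd_finite": bool(jnp.all(jnp.isfinite(fwd))),
        "rev_finite": bool(jnp.all(jnp.isfinite(rev))),
        "max_abs_diff": float(jnp.max(jnp.abs(fwd - rev))),
    }


def F(x):
    # Placeholder: replace with a standalone port of the failing computation.
    return jnp.sum(jnp.sin(x) / (2.0 + jnp.cos(x)))


print(environment())
print(check_fwd_vs_rev(F, jnp.linspace(0.0, 1.0, 5)))
```

A repro that reports `fwd_finite` as False while `rev_finite` is True, on a stated backend and version, is the sort of dependency-free case that can be bisected.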

@dpanici dpanici added the question Further information is requested label Jan 6, 2025
@unalmis unalmis removed the bug label Jan 22, 2025