Bulk derivative (big Jacobian) in MLX #1927

Open
sck-at-ucy opened this issue Mar 5, 2025 · 4 comments

@sck-at-ucy

I’m developing a physics-informed neural network (PINN) in MLX that requires computing second derivatives with respect to the input coordinates (e.g., $(r, z)$) across a batch of collocation points. Right now, I’m using the “index trick”: calling `mx.grad(...)` inside a Python loop for each point and dimension, which works with compilation globally disabled but fails under MLX’s compiled (JIT) mode with “Cannot vjp primitive” or similar errors.

My question: Does MLX have (or plan to add) a “bulk derivative” or “big Jacobian” function, akin to JAX’s `jacfwd` / `jacrev`, that computes $\partial Y / \partial X$ in one shot?

For example, if $X$ has shape $(N, d)$ and the network outputs $Y$ with shape $(N, k)$, we’d want a direct call that returns the Jacobian of all $Y$ entries w.r.t. all $X$ entries in a compile-friendly manner (without the repeated single-point loops). That might avoid the “Cannot vjp sum” issues I see, since a single large derivative pass could preserve the compiled graph more cleanly.

In short:

- Use case: PDE derivatives in PINNs (or other operator methods) that require partial derivatives w.r.t. each input coordinate across many points.
- Problem (suspected): repeated `for i in range(N): mx.grad(...)` calls trigger “Cannot vjp” or “Not implemented” errors under compiled mode.
- Potential solution: a built-in “bulk” or “batched” derivative function that merges the loops internally and preserves the AD graph for compilation.
- Question: Is there a function or planned feature in MLX that does this (or can approximate it)? And if so, is it likely to fix the “Cannot vjp sum” errors we see in compiled mode?

@awni I look forward to your wisdom and advice (and any roadmap info) on full PDE autodiff in MLX’s compiled environment.

@awni
Member

awni commented Mar 5, 2025

@sck-at-ucy take a look at this discussion #154

Summary there:

We don't have a plan to add them, but we could if needed. It's not so difficult to implement them in terms of `mx.vjp`, `mx.jvp` and `mx.vmap`. Here's a quick implementation for each you could use for now:

```python
import mlx.core as mx

def jacrev(f):
    def jacfn(x):
        # Needed for the size of the output
        y = f(x)
        def vjpfn(cotan):
            # One VJP per output basis vector gives one row of J
            return mx.vjp(f, (x,), (cotan,))[1][0]
        return mx.vmap(vjpfn, in_axes=0)(mx.eye(len(y)))
    return jacfn

def jacfwd(f):
    def jacfn(x):
        def jvpfn(tan):
            # One JVP per input basis vector gives one column of J
            return mx.jvp(f, (x,), (tan,))[1][0]
        # vmap stacks the columns along axis 0, so transpose to get J with shape (m, n)
        return mx.vmap(jvpfn, in_axes=0)(mx.eye(len(x))).T
    return jacfn


def hessian(f):
    def hessfn(x):
        def hvp(tan):
            return mx.jvp(mx.grad(f), (x,), (tan,))[1][0]
        # The Hessian is symmetric, so the stacking order does not matter here
        return mx.vmap(hvp, in_axes=0)(mx.eye(len(x)))
    return hessfn


print(jacrev(mx.sin)(mx.array([1.0, 2.0, 3.0])))
print(jacfwd(mx.sin)(mx.array([1.0, 2.0, 3.0])))
```

@awni
Member

awni commented Mar 5, 2025

> Python loop for each point and dimension, which works with compilation globally disabled but fails under MLX’s compiled (JIT) mode with “Cannot vjp primitive” or similar errors.

This seems like a bug somewhere to me, possibly in MLX. Are you able to share something that reproduces it?

@sck-at-ucy
Author

Yes, I would be happy to share the code. Will do that after finishing teaching tonight and cleaning the code. If indeed this is a bug (either mine or possibly in MLX), it would make me so happy because it would remove a main obstacle I have been facing with autograd for PDEs.

@sck-at-ucy
Author

So the code is a bit long. It makes an attempt to use the implementations you suggested above.

The motivation for what I am trying to do is to be able to compute partial differential operators over the domain instead of point-wise. This opens up the possibilities to replace the MLP with something a bit more sophisticated.
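For a sense of scale with the domain-wide approach in the script below: `field_fn` maps a vector of length $4N$ to one of length $3N$, so `jac_fn` materializes a $(3N, 4N)$ Jacobian and `hess_fn` a $(3N, 4N, 4N)$ array, i.e. $48N^3$ entries, even though a pointwise MLP only couples coordinates within the same point. A back-of-the-envelope calculation, assuming float32 (4 bytes per entry) and ignoring any intermediate buffers, with hypothetical helper names:

```python
def hessian_entries(N: int) -> int:
    """Entries in the (3N, 4N, 4N) array built by hess_fn."""
    return (3 * N) * (4 * N) * (4 * N)  # = 48 * N**3


def pointwise_hessian_entries(N: int) -> int:
    """Entries if only the per-point (3, 4, 4) Hessian blocks were kept."""
    return N * 3 * 4 * 4  # = 48 * N


for N in (4, 100, 1000):
    full = hessian_entries(N)
    blocks = pointwise_hessian_entries(N)
    print(
        f"N={N:5d}: full={full:>16,d} entries ({full * 4 / 1e9:10.3f} GB), "
        f"per-point blocks={blocks:,d} entries"
    )
```

At $N = 1000$ the full array alone would be $4.8 \times 10^{10}$ entries (about 192 GB in float32), which is consistent with the script keeping $N = 4$ "to avoid big Hessian".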

The example I share, however, uses only the MLP to keep things simple.

The PDE is for fluid flow, so it involves the Laplacian for the viscous terms. That's where things go wrong: in computing the second derivatives.

I might be doing something stupid that I cannot see, but it also looks like there might be an internal MLX issue in how it tries to unify with internal reshape logic during `mx.vjp`/`mx.vmap`.

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
#from Lamb_optimizer import Lamb

###############################################################################
# 1) PDE parameters
###############################################################################
r_in = 1.0
r_out= 2.0
L     = 30.0
rho   = 1.0
nu    = 1.0e-2

###############################################################################
# 2) jacrev and double_jacrev with shape checks
###############################################################################
def jacrev(f):
    """
    Reverse-mode Jacobian for f: R^n -> R^m => shape(m,n).
    """
    def jacfn(x):
        print(f"[jacfn] jacrev => x.shape={x.shape}")
        y = f(x)  # expect shape(m,)
        # shape check or debug print
        print(f"[jacfn] jacrev => y.shape={y.shape}")
        m = y.shape[0]
        # we do vmap(...) over an (m,m) identity
        cotangent_eye = mx.eye(m)  # shape(m,m)
        def vjpfn(cotan):
            # shape(cotan)= (m,)
            # shape result => (n,)
            out = mx.vjp(f, (x,), (cotan,))[1][0]
            print(f"[jacfn] vjpfn => out.shape={out.shape}")
            return out

        J = mx.vmap(vjpfn, in_axes=0)(cotangent_eye) # shape(m,n)
        # shape check
        n = x.shape[0]
        print(f"[jacfn] jacrev => J.shape={J.shape}")
        return J
    return jacfn

def double_jacrev(f):
    base_jac = jacrev(f)      # first derivative
    def h(x):
        # shape check in the intermediate function
        j = base_jac(x)       # shape(m,n) ideally
        print(f"[h] base_jac => j.shape={j.shape}")
        m, n = j.shape
        print(f'm,n: {m} {n}')
        out_flat = mx.reshape(j, (m*n,))
        print(f"[h] flattened => {out_flat.shape}")
        return out_flat

    def ddjf(x):
        # second derivative
        j2 = jacrev(h)(x)
        print(f"[ddjf] j2.shape={j2.shape}")
        # now we expect j2 => shape(m*n,n). Then reshape => (m,n,n).
        m = f(x).shape[0]    # e.g. 3*N
        n = x.shape[0]       # e.g. 4*N
        out = mx.reshape(j2, (m,n,n))
        print(f"[ddjf] final Hessian => out.shape={out.shape}")
        return out

    return ddjf


###############################################################################
# 3) MLP model
###############################################################################
class MLPBatchPINN(nn.Module):
    def __init__(self, input_dim=4, hidden_dims=[64,64,64], output_dim=3, activation=nn.silu):
        super().__init__()
        self.layers = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            self.layers.append(nn.Linear(in_dim, h_dim))
            in_dim = h_dim
        self.out_layer = nn.Linear(in_dim, output_dim)
        self.activation = activation

    def __call__(self, X):
        # X shape => (N,4)
        for layer in self.layers:
            X = layer(X)
            X = self.activation(X)
        X = self.out_layer(X)
        # shape => (N,3)
        print(f'model: {X.shape}')
        return X

###############################################################################
# 4) WeightedPINN_Batch with shape checks in field_fn, pde_fn
###############################################################################
class WeightedPINN_Batch(nn.Module):
    """
    PDE with second derivatives:
    field_fn: (4N,)->(3N,)
    jac_fn => shape(3N,4N)
    hess_fn => shape(3N,4N,4N)
    """
    def __init__(self, core_model, alpha_min=0.1, alpha_max=5.0):
        super().__init__()
        self.core_model = core_model
        self.logit_alpha_cont = mx.array(0.0)
        self.logit_alpha_mom  = mx.array(0.0)
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max

        self.jac_fn  = None
        self.hess_fn = None

    def alpha_cont(self):
        return self.alpha_min + (self.alpha_max-self.alpha_min)*mx.sigmoid(self.logit_alpha_cont)
    def alpha_mom(self):
        return self.alpha_min + (self.alpha_max-self.alpha_min)*mx.sigmoid(self.logit_alpha_mom)

    def forward_batch(self, X):
        return self.core_model(X)  # shape(N,3)

    def field_fn(self, x_flat):
        """
        x_flat => shape(4N,)
        -> (N,4) => forward => (N,3) => flatten => (3N,)
        with shape checks
        """
        # 1) ensure x_flat is 1D
        assert x_flat.ndim == 1, f"[field_fn] x_flat must be 1D, got shape{ x_flat.shape}"

        total_size = x_flat.shape[0]
        # must be multiple of 4
        assert total_size%4==0, f"[field_fn] x_flat size {total_size} not multiple of 4"
        N = total_size//4

        # reshape => (N,4)
        X = mx.reshape(x_flat, (N,4))
        # forward => (N,3)
        out = self.core_model(X)
        print(f"[field_fn] out.shape={out.shape}")
        #assert out.shape==(N,3), f"[field_fn] expected (N,3)=({N},3), got {out.shape}"

        # flatten => shape(3N,)
        out_flat = mx.reshape(out, (3*N,))
        print(f"[field_fn] out_flat.shape={out_flat.shape}")
        #assert out_flat.shape==(3*N,), f"[field_fn] expected(3N,) => {(3*N,)}, got{out_flat.shape}"
        return out_flat

    def build_derivatives(self):
        """Construct first & second derivative closures with shape checks."""
        self.jac_fn  = jacrev(self.field_fn)
        self.hess_fn = double_jacrev(self.field_fn)

    def pde_fn(self, x_flat):
        """
        returns shape(3N,) PDE
        shape checks for each step
        """
        total_size = x_flat.shape[0]
        assert total_size%4==0, f"[pde_fn] x_flat size {total_size} not multiple of 4"
        N = total_size//4

        # 1) field => shape(3N,)
        out_flat = self.field_fn(x_flat)
        print(f'out_flat {out_flat.shape}')
        assert out_flat.shape==(3*N,), f"[pde_fn] out_flat must be (3N,) => {(3*N,)}, got {out_flat.shape}"

        # 2) jac => shape(3N,4N)
        J = self.jac_fn(x_flat)
        print(f'J {J.shape}')
        expected_jac_shape = (3*N,4*N)
        assert J.shape==expected_jac_shape, f"[pde_fn] J => expect{expected_jac_shape}, got {J.shape}"

        # 3) hess => shape(3N,4N,4N)
        H = self.hess_fn(x_flat)
        print(f"[pde_fn] H.shape={H.shape}")
        expected_hess_shape = (3*N,4*N,4*N)
        assert H.shape==expected_hess_shape, f"[pde_fn] H => expect{expected_hess_shape}, got {H.shape}"

        # we do PDE logic => continuity + momentum
        # i_r_rows => (0,3,6,...)
        i_r_rows = mx.arange(0,3*N,3)  # shape(N,)
        i_z_rows = i_r_rows+1
        i_p_rows = i_r_rows+2

        r_cols   = mx.arange(0,4*N,4) # shape(N,)
        z_cols   = r_cols +1

        # gather field
        u_r_val = mx.take(out_flat, i_r_rows, axis=0)
        u_z_val = mx.take(out_flat, i_z_rows, axis=0)
        r_vals  = mx.take(x_flat, r_cols, axis=0)

        # gather first derivatives => shape(N,)
        # partial_J_ur => shape(N,4N)
        partial_J_ur = mx.take(J, i_r_rows, axis=0)
        # du_r/dr => shape(N,)
        du_r_dr = mx.take_along_axis(partial_J_ur,mx.expand_dims(r_cols,1), axis=1)
        du_r_dr = mx.reshape(du_r_dr,(N,))
        # du_r/dz
        du_r_dz = mx.take_along_axis(partial_J_ur,mx.expand_dims(z_cols,1), axis=1)
        du_r_dz = mx.reshape(du_r_dz,(N,))

        partial_J_uz= mx.take(J, i_z_rows, axis=0)
        du_z_dr= mx.take_along_axis(partial_J_uz,mx.expand_dims(r_cols,1), axis=1)
        du_z_dr= mx.reshape(du_z_dr,(N,))
        du_z_dz= mx.take_along_axis(partial_J_uz,mx.expand_dims(z_cols,1), axis=1)
        du_z_dz= mx.reshape(du_z_dz,(N,))

        partial_J_p= mx.take(J, i_p_rows, axis=0)
        dp_dr= mx.take_along_axis(partial_J_p,mx.expand_dims(r_cols,1), axis=1)
        dp_dr= mx.reshape(dp_dr,(N,))
        dp_dz= mx.take_along_axis(partial_J_p,mx.expand_dims(z_cols,1), axis=1)
        dp_dz= mx.reshape(dp_dz,(N,))

        # gather second derivatives => shape(N,)
        partial_H_ur = mx.take(H, i_r_rows, axis=0) # (N,4N,4N)
        d2ur_dr2 = gather_2D(partial_H_ur, r_cols, r_cols)
        d2ur_dz2 = gather_2D(partial_H_ur, z_cols, z_cols)

        partial_H_uz = mx.take(H, i_z_rows, axis=0)
        d2uz_dr2 = gather_2D(partial_H_uz, r_cols, r_cols)
        d2uz_dz2 = gather_2D(partial_H_uz, z_cols, z_cols)

        # Continuity => (1/r)(u_r + r du_r/dr) + du_z/dz
        cont = (1.0/r_vals)*(u_r_val + r_vals*du_r_dr) + du_z_dz

        # radial => adv + press + nu*(lapl(u_r)-u_r/r^2)
        lapl_ur = d2ur_dr2 + (1.0/r_vals)*du_r_dr + d2ur_dz2
        minus_ur_r2 = -u_r_val/(r_vals*r_vals)
        adv_r = u_r_val*du_r_dr + u_z_val*du_r_dz
        press_r= -(1.0/rho)*dp_dr
        visc_r = nu*(lapl_ur + minus_ur_r2)
        r_mom = adv_r + press_r + visc_r

        # axial => adv + press + nu*lapl(u_z)
        lapl_uz= d2uz_dr2 + (1.0/r_vals)*du_z_dr + d2uz_dz2
        adv_z = u_r_val*du_z_dr + u_z_val*du_z_dz
        press_z= -(1.0/rho)*dp_dz
        visc_z= nu*lapl_uz
        z_mom= adv_z + press_z + visc_z

        pde_array= mx.stack([cont, r_mom, z_mom], axis=1)
        pde_flat= mx.reshape(pde_array, (3*N,))
        print(f'pde_flat {pde_flat.shape}')
        return pde_flat

    def pde_loss(self, X):
        N= X.shape[0]
        x_flat= mx.reshape(X,(4*N,))
        PDEvals= self.pde_fn(x_flat) # (3N,)

        PDEvals_resh= mx.reshape(PDEvals,(N,3))
        cont_part= PDEvals_resh[:,0]
        r_part   = PDEvals_resh[:,1]
        z_part   = PDEvals_resh[:,2]
        cont_loss= mx.mean(cont_part**2)
        mom_loss = mx.mean(r_part**2+z_part**2)

        return self.alpha_cont()*cont_loss + self.alpha_mom()*mom_loss

    def boundary_loss(self, X_bc):
        print("[boundary_loss] X_bc.shape=", X_bc.shape)
        out_bc = self.core_model(X_bc)
        print("[boundary_loss] out_bc.shape=", out_bc.shape)
        bc_vals= out_bc[:,0:2]  # (u_r,u_z)
        return mx.mean(bc_vals**2)

###############################################################################
# 5) gather_2D => shape(N,) picking row_idx[i],col_idx[i] from (N,A,B).
###############################################################################
def gather_2D(tensor_3d, row_idx, col_idx):
    N, A, B = tensor_3d.shape
    tens_flat= mx.reshape(tensor_3d, (N, A*B))
    linear_idx= row_idx*B+ col_idx
    lin_idx_2d= mx.expand_dims(linear_idx,1) # shape(N,1)
    out_2d= mx.take_along_axis(tens_flat, lin_idx_2d, axis=1)
    out_1d= mx.reshape(out_2d,(N,))
    return out_1d

###############################################################################
# 6) The Train Function w/ shape checks
###############################################################################
def train_example(num_epochs=200):
    #mx.disable_compile()

    # MLP
    net_core= MLPBatchPINN(input_dim=4, hidden_dims=[64,64,64], output_dim=3)
    pinn= WeightedPINN_Batch(net_core)

    # build first & second derivatives
    pinn.build_derivatives()

    # domain
    N= 4  # let's do small to avoid big Hessian
    r_= mx.random.uniform(r_in, r_out, (N,))
    z_= mx.random.uniform(0.0, L, (N,))
    X_interior= mx.stack([r_,z_, mx.zeros_like(r_), mx.zeros_like(r_)], axis=1)
    print(f'[X_interior] {X_interior.shape}')

    M=5
    r_bc= mx.full((M,), r_in)
    z_bc= mx.random.uniform(0.0, L, (M,))
    X_bc= mx.stack([r_bc, z_bc, mx.zeros_like(r_bc), mx.zeros_like(r_bc)], axis=1)
    print(f'[X_bc] {X_bc.shape}')

    #optimizer= Lamb(weight_decay=0.005, learning_rate=1e-3, eps=1e-12)
    optimizer = optim.Adam(learning_rate=1e-3)

    # state => [pinn.state, optimizer.state, X_interior, X_bc]
    state= [pinn.state, optimizer.state, X_interior, X_bc]
    mx.eval(state)

    @partial(mx.compile, inputs=state, outputs=state)
    def train_step():
        pinn_state, opt_state, X_int, X_bc_ = state
        print(f'X_int.shape: {X_int.shape}, X_bc.shape {X_bc.shape}')

        def loss_fn():
            pde_l= pinn.pde_loss(X_int)   #<<<--- Commenting out this & returning only '10.0 * bc_l' runs
            bc_l= pinn.boundary_loss(X_bc_)
            return pde_l + 10.0* bc_l
            #return 10.0 * bc_l    # you need to also comment out the pde_l line to run w/o issue

        loss_val, grads= nn.value_and_grad(pinn, loss_fn)()
        optimizer.update(pinn, grads)

        new_pinn_state= pinn.state
        new_opt_state= optimizer.state
        return loss_val, (new_pinn_state, new_opt_state, X_int, X_bc_)

    for epoch in range(num_epochs):
        loss_val, new_state= train_step()
        if (epoch+1)%50==0:
            print(f"Epoch {epoch+1}, loss={float(loss_val):.6f}")
        mx.eval(new_state)

    print("Done!")
    return pinn


if __name__=="__main__":
    trained= train_example(num_epochs=200)
```
