
Truncated backprop using filtered-transformations #954

Open · SNMS95 opened this issue Feb 21, 2025 · 3 comments
Labels: question (User queries)

Comments

SNMS95 commented Feb 21, 2025

Hi everyone,

I am working on meta-learning and want to implement truncated backpropagation to estimate the meta-level gradients. Here is my current approach:

import jax.flatten_util
import optax
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp

NSTEPS = 500
KIND = 'lax'


def inner_loss(params):
    """Loss function"""
    flat_params = jax.flatten_util.ravel_pytree(params)[0]
    return jnp.sum(flat_params**2)


@eqx.filter_jit
def inner_optimization(params, opt_state, num_steps):

    def inner_opt_step(carry, _):
        """Inner optimization function"""
        params, opt_state = carry
        # Get the gradients
        inner_grads = eqx.filter_grad(inner_loss)(params)
        # Update the parameters
        updates, opt_state = inner_opt.update(inner_grads, opt_state)
        params = eqx.apply_updates(params, updates)
        return (params, opt_state), None

    # Scan over the inner optimization steps
    init = (params, opt_state)
    (final_params, final_opt_state), _ = eqxi.scan(inner_opt_step, init,
                                                   None, length=num_steps, kind=KIND)
    return final_params, final_opt_state

def truncated_inner_optimization(params, opt_state, num_steps, num_steps_truncated):
    """Truncated inner optimization"""

    @eqx.filter_custom_jvp
    def wrapped_inner_optimization(params, opt_state, num_steps):
        final_params, final_opt_state = inner_optimization(
            params, opt_state, num_steps)
        return final_params, final_opt_state

    @wrapped_inner_optimization.def_jvp
    def wrapped_inner_optimization_jvp(primals, tangents):
        """Identity JVP: pass the incoming tangents through unchanged,
        so autodiff treats these inner steps as the identity map."""
        primal_out = wrapped_inner_optimization(*primals)
        # Tangents for (params, opt_state); this is where backprop is truncated
        tangent_out = tangents[:2]
        return primal_out, tangent_out

    # Run the first num_steps_truncated steps through the wrapped fn,
    # which autodiff treats as the identity map (no backprop through them)
    final_params, final_opt_state = eqx.filter_jit(
        wrapped_inner_optimization)(params, opt_state, num_steps_truncated)

    # Backprop proceeds normally through the remaining steps
    final_params, final_opt_state = inner_optimization(
        final_params, final_opt_state, num_steps - num_steps_truncated)
    return final_params, final_opt_state

def outer_loss2(params):
    """Outer loss function"""
    # Run optimization
    final_params, _ = truncated_inner_optimization(
        params, inner_opt_state, NSTEPS, NSTEPS - 100)
    # Compute the loss
    flat_final_params = jax.flatten_util.ravel_pytree(final_params)[0]
    return jnp.sum(flat_final_params**3)

# Test
# Create a pytree for the parameters
params = {'w': jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))}
inner_opt = optax.adam(1e-3)
inner_opt_state = inner_opt.init(params)
outer_grads = eqx.filter_grad(outer_loss2)(params)
print(outer_grads)

When I checked the memory usage, this method seems to work, but it feels hacky.
Is there a better way to do this?

patrick-kidger (Owner) commented

This looks pretty good to me! It's probably worth tweaking things slightly to have a single JIT wrapping the whole thing (including your final outer_grads = eqx.filter_grad(...)), but otherwise I think this looks about as good as it gets :)
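
For instance, a minimal sketch of that suggestion, reusing the definitions above (the meta_grads_fn name is illustrative, not from the issue's code):

@eqx.filter_jit
def meta_grads_fn(params):
    # One JIT around the entire meta-gradient computation,
    # including the filter_grad call
    return eqx.filter_grad(outer_loss2)(params)

outer_grads = meta_grads_fn(params)

The nested filter_jit calls inside are fine; under the outermost JIT they simply get traced into the same compiled computation.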


SNMS95 commented Feb 25, 2025

Hi @patrick-kidger

Here is an even more minimal example, but it errors out.

import jax
import jax.flatten_util
import jax.numpy as jnp
import equinox as eqx
from collections import namedtuple

class Container(eqx.Module):
    inner_model: eqx.Module
    outer_model: namedtuple

    def __init__(self, outer_param_dict):
        self.inner_model = eqx.nn.MLP(1, 1, 5, 3, key=jax.random.PRNGKey(0))
        # Convert dictionary to namedtuple
        self.outer_model = namedtuple('OuterModel', outer_param_dict.keys())(**outer_param_dict)

    def __call__(self, x):
        return self.inner_model(x)
    
def new_inner_model(model):
    """Update the inner model"""
    inner_params, inner_static = eqx.partition(model, inner_filter)
    val = eqx.combine(inner_params, inner_static)(jnp.array([1.0]))
    inner_params = jax.tree.map(lambda x: x*val, inner_params)
    return eqx.combine(inner_params, inner_static)

def wrapper(fn):
    """Wrap to do truncated backprop"""

    @eqx.filter_custom_jvp
    def wrapped_fn(*args):
        return fn(*args)
    
    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        return primals_out, tangents

    return wrapped_fn

new_inner_model_wrapped = wrapper(new_inner_model)

def loss(outer_params, outer_static):
    a = outer_params.outer_model.a
    b = outer_params.outer_model.b
    # Do something to the inner params
    model = eqx.combine(outer_params, outer_static)
    inner_param, inner_static = eqx.partition(model, inner_filter)
    inner_param = jax.tree.map(lambda x: x*(a + b**2), inner_param)
    model = eqx.combine(inner_param, inner_static)

    # model = new_inner_model(model)  # This works
    model = new_inner_model_wrapped(model)  # This doesn't work

    inner_param, inner_static = eqx.partition(model, inner_filter)
    flat_inner_params = jax.flatten_util.ravel_pytree(inner_param)[0]
    return a + b + jnp.sum(flat_inner_params) + a + b

# test
import jax.tree as jt
outer_param_dict = {'a': jnp.array(1.0), 'b': jnp.array(2.0)}
container = Container(outer_param_dict)

base_filter = jt.map(lambda _: False, container)
def _bilevel_igor_filters(model):
    task_filter = jt.map(eqx.is_array, model.inner_model)
    task_filter = eqx.tree_at(
        lambda tree: tree.inner_model, base_filter, task_filter)
    meta_filter = jt.map(eqx.is_array_like, model.outer_model)
    meta_filter = eqx.tree_at(
        lambda tree: tree.outer_model, base_filter, meta_filter)
    return task_filter, meta_filter

inner_filter, outer_filter = _bilevel_igor_filters(container)

outer_params, outer_static = eqx.partition(container, outer_filter)
eqx.filter_grad(loss)(outer_params, outer_static).outer_model

Do you know what should be done to make this work?

patrick-kidger (Owner) commented

The output of an eqx.filter_custom_jvp must still consist only of JAX types. In this case you're returning the output of new_inner_model, which still has its static components.

(I think)
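
If that is the cause, one possible fix is to keep the static components outside the custom-JVP boundary entirely: partition before the call, pass and return only the array pytree, and recombine afterwards. A sketch under that assumption (the wrapper_arrays_only and apply names are illustrative, and it assumes fn preserves the pytree structure of its input, as new_inner_model does):

def wrapper_arrays_only(fn, filter_spec):
    """Truncated-backprop wrapper that only passes arrays through the JVP."""

    @eqx.filter_custom_jvp
    def wrapped_fn(arrays, static):
        out = fn(eqx.combine(arrays, static))
        # Keep only the array leaves; statics never cross the boundary
        out_arrays, _ = eqx.partition(out, filter_spec)
        return out_arrays

    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        # Identity JVP over the array pytree (valid because fn is
        # structure-preserving)
        return primals_out, tangents[0]

    def apply(model):
        arrays, static = eqx.partition(model, filter_spec)
        out_arrays = wrapped_fn(arrays, static)
        # Reattach the static components after the custom-JVP call
        return eqx.combine(out_arrays, static)

    return apply

new_inner_model_wrapped = wrapper_arrays_only(new_inner_model, eqx.is_array)

With this, new_inner_model_wrapped(model) in loss returns a full model again, and only the array leaves ever pass through the custom JVP.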
