Revert "simplify"
This reverts commit 408f995.
evanatyourservice committed Dec 18, 2024
1 parent 46dc78e commit 568af0d
Showing 4 changed files with 102 additions and 16 deletions.
image_classification_jax/run_experiment.py (65 additions, 9 deletions)
@@ -22,6 +22,7 @@
 from optax.contrib._schedule_free import schedule_free_eval_params
 import tensorflow_datasets as tfds
 import tensorflow as tf
+from psgd_jax import hessian_helper

 from image_classification_jax.utils.imagenet_pipeline import (
     create_split,
@@ -69,6 +70,8 @@ def run_experiment(
     n_epochs: int = 150,
     optimizer: optax.GradientTransformation = optax.adamw(1e-3),
     compute_in_bfloat16: bool = False,
+    l2_regularization: float = 0.0,
+    randomize_l2_reg: bool = False,
     apply_z_loss: bool = True,
     model_type: str = "resnet18",
     n_layers: int = 12,
@@ -77,6 +80,8 @@
     n_empty_registers: int = 0,
     dropout_rate: float = 0.0,
     using_schedule_free: bool = False,
+    psgd_calc_hessian: bool = False,
+    psgd_precond_update_prob: float = 1.0,
 ):
     """Run an image classification experiment.
@@ -347,6 +352,23 @@ def loss_fn(params, batch_stats, rng, images, labels):
         if apply_z_loss:
             loss += z_loss(logits).mean() * 1e-4

+        if l2_regularization > 0:
+            to_l2 = []
+            for key, value in _sorted_items(flatten_dict(_get_params_dict(params))):
+                path = "/" + "/".join(key)
+                if "kernel" in path:
+                    to_l2.append(jnp.linalg.norm(value))
+            l2_loss = jnp.linalg.norm(jnp.array(to_l2))
+
+            if randomize_l2_reg:
+                rng, subkey = jax.random.split(rng)
+                multiplier = jax.random.uniform(
+                    subkey, dtype=jnp.float32, minval=0.0, maxval=2.0
+                )
+                l2_loss *= multiplier
+
+            loss += l2_regularization * l2_loss
+
         return loss, (new_model_state, logits, orig_loss)

     @partial(pmap, axis_name="batch", donate_argnums=(1,))
@@ -365,16 +387,50 @@ def train_step(rng, state, batch):
             accuracy: float, mean accuracy.
             grad_norm: float, mean gradient
         """
-        rng, subkey = jax.random.split(rng)
-        (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(
-            state.params, state.batch_stats, subkey, batch["image"], batch["label"]
-        )
-        # mean gradients across devices
-        grads = jax.lax.pmean(grads, axis_name="batch")
+        if psgd_calc_hessian:
+            rng, subkey1, subkey2 = jax.random.split(rng, 3)
+
+            # use psgd hessian helper to calc hvp and pass into psgd
+            subkey1 = jax.lax.all_gather(subkey1, "batch")
+            # same key on all devices for random vector and precond update prob
+            subkey1 = subkey1[0]
+            (_, aux), grads, hvp, vector, update_precond = hessian_helper(
+                subkey1,
+                state.step,
+                loss_fn,
+                state.params,
+                loss_fn_extra_args=(
+                    state.batch_stats,
+                    subkey2,
+                    batch["image"],
+                    batch["label"],
+                ),
+                has_aux=True,
+                preconditioner_update_probability=psgd_precond_update_prob,
+            )
+
+            grads = jax.lax.pmean(grads, axis_name="batch")
+            hvp = jax.lax.pmean(hvp, axis_name="batch")
+
+            updates, new_opt_state = optimizer.update(
+                grads,
+                state.opt_state,
+                state.params,
+                Hvp=hvp,
+                vector=vector,
+                update_preconditioner=update_precond,
+            )
+        else:
+            rng, subkey = jax.random.split(rng)
+            (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(
+                state.params, state.batch_stats, subkey, batch["image"], batch["label"]
+            )
+            # mean gradients across devices
+            grads = jax.lax.pmean(grads, axis_name="batch")

-        updates, new_opt_state = optimizer.update(
-            grads, state.opt_state, state.params
-        )
+            updates, new_opt_state = optimizer.update(
+                grads, state.opt_state, state.params
+            )

         new_model_state, logits, loss = aux
         accuracy = jnp.mean(jnp.argmax(logits, -1) == batch["label"])
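Note on the restored pieces above. The l2 penalty is not the usual sum of squared weights: it collects the l2 norm of each "kernel" parameter and then takes the norm of that vector of norms, which equals the global l2 norm over all kernel entries; the optional multiplier drawn uniformly from [0, 2) has mean 1, so randomize_l2_reg adds per-step noise without changing the penalty's expected value. On the PSGD path, hessian_helper wraps loss_fn, shares one key across devices so every replica uses the same random probe vector, and returns gradients plus a Hessian-vector product that the optimizer consumes through the extra Hvp, vector, and update_preconditioner arguments to optimizer.update. Conceptually, an HVP like this can be taken by forward-mode differentiation of the gradient, as in the minimal sketch below (not the library's internals; loss_fn and extra_args are stand-ins for the arguments above):

    import jax

    def grads_and_hvp(loss_fn, params, vector, *extra_args):
        # Forward-over-reverse: a jvp of the gradient function along the
        # probe `vector` yields the gradients and the Hessian-vector
        # product in a single pass.
        grad_fn = lambda p: jax.grad(loss_fn)(p, *extra_args)
        return jax.jvp(grad_fn, (params,), (vector,))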
image_classification_jax/run_experiment_test.py (33 additions, 5 deletions)

@@ -1,18 +1,42 @@
 import optax
-from run_experiment import run_experiment
+from image_classification_jax.run_experiment import run_experiment
+from psgd_jax.affine import affine


 if __name__ == "__main__":
-    optimizer = optax.contrib.schedule_free_adamw(warmup_steps=256, b1=0.95)
+    base_lr = 0.001
+    warmup = 256
+    lr = optax.join_schedules(
+        schedules=[
+            optax.linear_schedule(0.0, base_lr, warmup),
+            optax.constant_schedule(base_lr),
+        ],
+        boundaries=[warmup],
+    )
+
+    psgd_opt = optax.chain(
+        optax.clip_by_global_norm(1.0),
+        affine(
+            lr,
+            preconditioner_update_probability=1.0,
+            b1=0.0,
+            weight_decay=0.0,
+            max_size_triangular=0,
+            max_skew_triangular=0,
+            precond_init_scale=1.0,
+        ),
+    )
+
+    optimizer = optax.contrib.schedule_free(psgd_opt, learning_rate=lr, b1=0.95)

     run_experiment(
         log_to_wandb=True,
         wandb_entity="",
         wandb_project="image_classification_jax",
         wandb_config_update={  # extra logging info for wandb
-            "optimizer": "adamw",
-            "lr": 0.0025,
-            "warmup": 256,
+            "optimizer": "psgd_affine",
+            "lr": base_lr,
+            "warmup": warmup,
             "b1": 0.95,
             "schedule_free": True,
         },
@@ -22,6 +46,8 @@
         n_epochs=10,
         optimizer=optimizer,
         compute_in_bfloat16=False,
+        l2_regularization=0.0001,
+        randomize_l2_reg=False,
         apply_z_loss=True,
         model_type="vit",
         n_layers=4,
@@ -30,4 +56,6 @@
         n_empty_registers=0,
         dropout_rate=0.0,
         using_schedule_free=True,  # set to True if optimizer wrapped with schedule_free
+        psgd_calc_hessian=False,  # set to True if using PSGD and want to calc and pass in hessian
+        psgd_precond_update_prob=1.0,
     )
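The test now exercises PSGD's affine preconditioner under optax's schedule-free wrapper: momentum is disabled inside affine (b1=0.0) because schedule_free supplies its own interpolation with b1=0.95, and the learning rate only warms up linearly for 256 steps before optax.join_schedules switches it to a constant, since schedule-free methods are designed to replace decaying schedules. With using_schedule_free=True, evaluation must use the schedule-free averaged parameters rather than the raw training parameters, which is what the schedule_free_eval_params import in run_experiment.py is for. A minimal sketch of that evaluation step, with opt_state, params, model, and images as hypothetical stand-ins:

    from optax.contrib._schedule_free import schedule_free_eval_params

    # Schedule-free trains at an interpolated point; evaluation should use
    # the averaged iterate recovered from the optimizer state.
    eval_params = schedule_free_eval_params(opt_state, params)
    logits = model.apply({"params": eval_params}, images, train=False)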
pyproject.toml (2 additions, 1 deletion)

@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "image-classification-jax"
-version = "0.1.3"
+version = "0.1.2"
 description = "Run image classification experiments in JAX with ViT, resnet, cifar10, cifar100, imagenette, and imagenet."
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
@@ -39,6 +39,7 @@ dependencies = [
     "tensorflow-cpu",
     "tensorflow-datasets",
     "wandb",
+    "psgd-jax",
 ]

 [project.urls]
requirements.txt (2 additions, 1 deletion)

@@ -5,4 +5,5 @@ optax
 numpy
 tensorflow-cpu
 tensorflow-datasets
-wandb
+wandb
+psgd-jax