Model trains but only if I don't JIT the step function? #971
Comments
I think the code can be further simplified:

from typing import Type
import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
# from jaxonmodels.layers.batch_norm import BatchNorm
class Downsample(eqx.Module):
conv: eqx.nn.Conv2d
# bn: BatchNorm
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
key: jt.PRNGKeyArray,
):
_, subkey = jax.random.split(key)
self.conv = eqx.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=stride,
use_bias=False,
key=subkey,
)
# self.bn = BatchNorm(out_channels, axis_name="batch")
def __call__(
self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
x = self.conv(x)
# x, state = self.bn(x, state)
return x, state
class BasicBlock(eqx.Module):
downsample: Downsample | None
conv1: eqx.nn.Conv2d
# bn1: BatchNorm
conv2: eqx.nn.Conv2d
# bn2: BatchNorm
expansion: int = eqx.field(static=True, default=1)
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
downsample: Downsample | None,
groups: int,
base_width: int,
dilation: int,
key: jt.PRNGKeyArray,
):
key, *subkeys = jax.random.split(key, 3)
self.conv1 = eqx.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
use_bias=False,
key=subkeys[0],
)
# self.bn1 = BatchNorm(input_size=out_channels, axis_name="batch")
self.conv2 = eqx.nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
use_bias=False,
key=subkeys[1],
)
# self.bn2 = BatchNorm(input_size=out_channels, axis_name="batch")
self.downsample = downsample
def __call__(self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State):
i = x
x = self.conv1(x)
# x, state = self.bn1(x, state)
x = jax.nn.relu(x)
x = self.conv2(x)
# x, state = self.bn2(x, state)
if self.downsample:
i, state = self.downsample(i, state)
x += i
x = jax.nn.relu(x)
return x, state
class Bottleneck(eqx.Module):
downsample: Downsample | None
conv1: eqx.nn.Conv2d
# bn1: BatchNorm
conv2: eqx.nn.Conv2d
# bn2: BatchNorm
conv3: eqx.nn.Conv2d
# bn3: BatchNorm
expansion: int = eqx.field(static=True, default=4)
def __init__(
self,
in_channels: int,
out_channels: int,
stride: int,
downsample: Downsample | None,
groups: int,
base_width: int,
dilation: int,
key: jt.PRNGKeyArray,
):
_, *subkeys = jax.random.split(key, 4)
width = int(out_channels * (base_width / 64.0)) * groups
self.conv1 = eqx.nn.Conv2d(
in_channels, width, kernel_size=1, use_bias=False, key=subkeys[0]
)
# self.bn1 = BatchNorm(width, axis_name="batch")
self.conv2 = eqx.nn.Conv2d(
width,
width,
kernel_size=3,
stride=stride,
groups=groups,
dilation=dilation,
padding=dilation,
use_bias=False,
key=subkeys[1],
)
# self.bn2 = BatchNorm(width, axis_name="batch")
self.conv3 = eqx.nn.Conv2d(
width,
out_channels * self.expansion,
kernel_size=1,
key=subkeys[2],
use_bias=False,
)
# self.bn3 = BatchNorm(out_channels * self.expansion, axis_name="batch")
self.downsample = downsample
def __call__(
self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
i = x
x = self.conv1(x)
# x, state = self.bn1(x, state)
x = jax.nn.relu(x)
x = self.conv2(x)
# x, state = self.bn2(x, state)
x = jax.nn.relu(x)
x = self.conv3(x)
# x, state = self.bn3(x, state)
if self.downsample:
i, state = self.downsample(i, state)
x += i
x = jax.nn.relu(x)
return x, state
class ResNet(eqx.Module):
conv1: eqx.nn.Conv2d
# bn: BatchNorm
mp: eqx.nn.MaxPool2d
layer1: list[BasicBlock | Bottleneck]
layer2: list[BasicBlock | Bottleneck]
layer3: list[BasicBlock | Bottleneck]
layer4: list[BasicBlock | Bottleneck]
avg: eqx.nn.AdaptiveAvgPool2d
fc: eqx.nn.Linear
running_internal_channels: int = eqx.field(static=True, default=64)
dilation: int = eqx.field(static=True, default=1)
def __init__(
self,
block: Type[BasicBlock | Bottleneck],
layers: list[int],
n_classes: int,
zero_init_residual: bool,
groups: int,
width_per_group: int,
replace_stride_with_dilation: list[bool] | None,
key: jt.PRNGKeyArray,
input_channels: int = 3,
):
key, *subkeys = jax.random.split(key, 10)
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
f"`replace_stride_with_dilation` should either be `None` "
f"or have a length of 3, got {replace_stride_with_dilation} instead."
)
self.conv1 = eqx.nn.Conv2d(
in_channels=input_channels,
out_channels=self.running_internal_channels,
kernel_size=7,
stride=2,
padding=3,
use_bias=False,
key=subkeys[0],
)
# self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
block,
64,
layers[0],
stride=1,
dilate=False,
groups=groups,
base_width=width_per_group,
key=subkeys[1],
)
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0],
groups=groups,
base_width=width_per_group,
key=subkeys[2],
)
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
groups=groups,
base_width=width_per_group,
key=subkeys[3],
)
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2],
groups=groups,
base_width=width_per_group,
key=subkeys[4],
)
self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])
if zero_init_residual:
# todo: init last bn layer with zero weights
pass
def _make_layer(
self,
block: Type[BasicBlock | Bottleneck],
out_channels: int,
blocks: int,
stride: int,
dilate: bool,
groups: int,
base_width: int,
key: jt.PRNGKeyArray,
) -> list[BasicBlock | Bottleneck]:
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if (
stride != 1
or self.running_internal_channels != out_channels * block.expansion
):
key, subkey = jax.random.split(key)
downsample = Downsample(
self.running_internal_channels,
out_channels * block.expansion,
stride,
subkey,
)
layers = []
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
stride=stride,
downsample=downsample,
groups=groups,
base_width=base_width,
dilation=previous_dilation,
key=subkey,
)
)
self.running_internal_channels = out_channels * block.expansion
for _ in range(1, blocks):
key, subkey = jax.random.split(key)
layers.append(
block(
in_channels=self.running_internal_channels,
out_channels=out_channels,
groups=groups,
base_width=base_width,
dilation=self.dilation,
stride=1,
downsample=None,
key=subkey,
)
)
return layers
def __call__(
self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
x = self.conv1(x)
# x, state = self.bn(x, state)
x = jax.nn.relu(x)
x = self.mp(x)
for layer in self.layer1:
x, state = layer(x, state)
for layer in self.layer2:
x, state = layer(x, state)
for layer in self.layer3:
x, state = layer(x, state)
for layer in self.layer4:
x, state = layer(x, state)
x = self.avg(x)
x = jnp.ravel(x)
x = self.fc(x)
return x, state
def resnet18(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
key, subkey = jax.random.split(key)
resnet, state = eqx.nn.make_with_state(ResNet)(
BasicBlock,
[2, 2, 2, 2],
n_classes,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
key=key,
)
# initializer = jax.nn.initializers.he_normal()
# is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
# get_weights = lambda m: [
# x.weight for x in jax.tree.leaves(m, is_leaf=is_conv2d) if is_conv2d(x)
# ]
# weights = get_weights(resnet)
# new_weights = [
# initializer(subkey, weight.shape, jnp.float32)
# for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
# ]
# resnet = eqx.tree_at(get_weights, resnet, new_weights)
return resnet, state
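For reference, a minimal usage sketch of the model defined above (not part of the original repro; shapes assume channels-first CIFAR-10 inputs, and the batched call mirrors the eqx.filter_vmap usage in the training loop below):

# Hypothetical quick check of the resnet18 factory above.
resnet, state = resnet18(key=jax.random.key(0), n_classes=10)
x = jnp.zeros((3, 32, 32))        # a single channels-first image
logits, state = resnet(x, state)  # per-sample forward pass -> shape (10,)

# Batched forward pass, as used in loss_fn further down:
xb = jnp.zeros((8, 3, 32, 32))
logits_b, state = eqx.filter_vmap(
    resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(xb, state)                      # -> shape (8, 10)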
import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
tf.random.set_seed(0)
(train, test), info = tfds.load(
"cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore
def preprocess(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
mean = tf.constant([0.4914, 0.4822, 0.4465])
std = tf.constant([0.2470, 0.2435, 0.2616])
img = (img - mean) / std # pyright: ignore
img = tf.transpose(img, perm=[2, 0, 1])
# label = tf.one_hot(label, depth=10)
return img, label
def preprocess_train(
img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
img = tf.image.random_crop(img, [32, 32, 3])
img = tf.image.random_flip_left_right(img) # pyright: ignore
return preprocess(img, label)
train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 2
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)
def loss_fn(
resnet: ResNet,
x: jt.Array,
y: jt.Array,
state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
logits, state = eqx.filter_vmap(
resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
return jnp.mean(loss), (logits, state)
@eqx.filter_jit
def step(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
new_r = eqx.apply_updates(resnet, updates)
return new_r, state, opt_state, loss_value, logits, grads, updates
def step_nj(
resnet: jt.PyTree,
state: eqx.nn.State,
x: jt.Array,
y: jt.Array,
optimizer: optax.GradientTransformation,
opt_state: optax.OptState,
):
(loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
loss_fn, has_aux=True
)(resnet, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, resnet)
new_r = eqx.apply_updates(resnet, updates)
return new_r, state, opt_state, loss_value, logits, grads, updates
resnet, state = resnet18(key=jax.random.key(0), n_classes=10)
learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)
opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))
key = jax.random.key(99)
n_epochs = 100
for epoch in range(n_epochs):
batch_count = len(train_dataset)
for i, (x, y) in enumerate(train_dataset):
y = jnp.array(y, dtype=jnp.int32)
new_j, state, os_j, loss, logits, g, u = step(
resnet, state, x, y, optimizer, opt_state
)
print("\n WJ", state, loss, logits)
y = jnp.array(y, dtype=jnp.int32)
new_nj, state, os_nj, loss, logits_nj, g_nj, u_nj = step_nj(
resnet, state, x, y, optimizer, opt_state
)
print("\n NJ", state, loss, logits_nj)
print(jnp.allclose(logits, logits_nj))
print(eqx.tree_equal(g, g_nj))
print(eqx.tree_equal(u, u_nj))
print(eqx.tree_equal(os_j, os_nj))
print(eqx.tree_equal(new_j, new_nj))
print(eqx.tree_equal(jax.tree.leaves(new_j), jax.tree.leaves(new_nj)))
l = jax.tree.leaves(new_j)
lj = jax.tree.leaves(new_nj)
print(len(l), len(lj))
for i in range(len(l)):
try:
print(i, jnp.linalg.norm(lj[i] - l[i]), jnp.allclose(lj[i], l[i]))
if jnp.isnan(jnp.linalg.norm(lj[i] - l[i])):
print(lj[i], l[i])
except:
print(i, lj[i], l[i], lj[i] == l[i])
break
break

The only differences I saw under JIT were that the gradients are slightly different (with a norm of about 1e-9, which I assume is just floating-point precision), but not always. I would be surprised if this were the source of the problem, but I'm just trying to narrow things down, since it's a very large setup currently.
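As an aside, exact pytree equality is a strict check here, since XLA may reorder floating-point operations under JIT. A tolerance-based comparison along these lines (a sketch; tolerances chosen arbitrarily) can make it easier to see whether the differences are just numerical noise:

def trees_close(a, b, rtol=1e-5, atol=1e-8):
    # Leaf-by-leaf comparison of the floating-point arrays in two pytrees.
    la = jax.tree.leaves(eqx.filter(a, eqx.is_inexact_array))
    lb = jax.tree.leaves(eqx.filter(b, eqx.is_inexact_array))
    return all(bool(jnp.allclose(x, y, rtol=rtol, atol=atol)) for x, y in zip(la, lb))

print(trees_close(g, g_nj))        # gradients: jitted vs. non-jitted
print(trees_close(new_j, new_nj))  # updated models: jitted vs. non-jitted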
Preface (not super relevant; can be skipped)
Ok, so this is a weird one, which took me HOURS to find. I wanted to implement ResNet and train it on CIFAR-10 for another YT video, so I started hacking away and ported the PyTorch implementation to Equinox. So far, so good. But I couldn't for the love of God get it to train. The PyTorch version had no problem training - even without any preprocessing of the data, no fancy-schmancy learning rate schedulers, just the most straightforward implementation you can think of.
I thought I was going crazy; I thought maybe it was because of the BatchNorm (because I saw a couple of open issues), so I implemented a slightly different version that matches the PyTorch version EXACTLY. But to no avail. I started to check the intermediate outputs of the network - maybe something was off there? No. Then I even turned off BatchNorm entirely in both networks. The PyTorch one - even without BatchNorm - trained with no problems at all. But not my version, so it's definitely not because of BatchNorm discrepancies. At this point it's basically just a large CNN with some residual connections.
Copy-pastable version of ResNet (jaxtyping required)
The issue
I was starting to get desperate and began checking the gradients, for which I needed to un-JIT the step function so I could see the gradient print statements. And there it was: after removing the eqx.filter_jit wrapper around the step function, the network started to train.

Copy-pasteable training loop code (requires tensorflow, tensorflow_datasets, clu, tqdm, jaxtyping)
The relevant part (tl;dr)
The issue is here:
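For context, this presumably refers to the filter_jit-wrapped training step from the repro above, reproduced here (type annotations omitted):

@eqx.filter_jit
def step(resnet, state, x, y, optimizer, opt_state):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    new_r = eqx.apply_updates(resnet, updates)
    return new_r, state, opt_state, loss_value, logits, grads, updates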
From my perspective, this looks just like standard JAX "boilerplate" code. I see no reason why JITting the step function would interfere with training the model.

My other attempts
So perhaps I could get rid of the state, I thought, since I don't even use BatchNorm anymore. But that made no difference. I also tried JITting only a smaller portion, as shown in the RNN example, but the equivalent version didn't improve the model either.
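A minimal sketch of what such a partially-JITted step might look like (this is an assumption about the variant tried, not the exact code from the post): only the loss/gradient computation is compiled, while the optimizer update and eqx.apply_updates run outside of JIT.

@eqx.filter_jit
def compute_loss_and_grads(resnet, state, x, y):
    # Only the forward/backward pass is compiled here.
    return eqx.filter_value_and_grad(loss_fn, has_aux=True)(resnet, x, y, state)

def step_partial(resnet, state, x, y, optimizer, opt_state):
    (loss_value, (logits, state)), grads = compute_loss_and_grads(resnet, state, x, y)
    # Optimizer update and parameter application stay un-jitted.
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value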
I spent all day on this and am now out of options; in German we'd say "es ist wie verhext" (it's as if it were jinxed). So perhaps anyone here has an idea? ANY help is HIGHLY appreciated.