Model trains but only if I don't JIT the step function? #971

Open

Artur-Galstyan opened this issue Mar 12, 2025 · 1 comment

Comments

@Artur-Galstyan (Contributor)

Preface (not super relevant; can be skipped)

Ok, so this is a weird one, which took me HOURS to find. I wanted to implement ResNet and train it on CIFAR-10 for another YT video, so I started hacking away, porting the PyTorch implementation to Equinox. So far, so good. But I couldn't for the love of God get it to train. The PyTorch version had no problem training - even without any preprocessing of the data, no fancy-schmancy learning rate schedulers, just the most straightforward implementation you can think of.

I thought I was going crazy; I thought maybe it was because of the BatchNorm (because I saw a couple of open issues) - so I implemented a slightly different version that matches the PyTorch version EXACTLY. But to no avail. I started to check the intermediate outputs of the network - maybe something was off there? No. Then I even turned off BatchNorm entirely in both networks. The PyTorch one - even without BatchNorm - trained with no problems at all. But not my version; so it's definitely not down to BatchNorm discrepancies. At this point it's basically just a large CNN with some residual connections.

Copy-pastable version of ResNet (jaxtyping required)

from typing import Type

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt

# from jaxonmodels.layers.batch_norm import BatchNorm


class Downsample(eqx.Module):
    conv: eqx.nn.Conv2d
    # bn: BatchNorm

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        key: jt.PRNGKeyArray,
    ):
        _, subkey = jax.random.split(key)
        self.conv = eqx.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=stride,
            use_bias=False,
            key=subkey,
        )

        # self.bn = BatchNorm(out_channels, axis_name="batch")

    def __call__(
        self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
        x = self.conv(x)
        # x, state = self.bn(x, state)

        return x, state


class BasicBlock(eqx.Module):
    downsample: Downsample | None

    conv1: eqx.nn.Conv2d
    # bn1: BatchNorm

    conv2: eqx.nn.Conv2d
    # bn2: BatchNorm

    expansion: int = eqx.field(static=True, default=1)

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Downsample | None,
        groups: int,
        base_width: int,
        dilation: int,
        key: jt.PRNGKeyArray,
    ):
        key, *subkeys = jax.random.split(key, 3)

        self.conv1 = eqx.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            use_bias=False,
            key=subkeys[0],
        )
        # self.bn1 = BatchNorm(input_size=out_channels, axis_name="batch")

        self.conv2 = eqx.nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_bias=False,
            key=subkeys[1],
        )
        # self.bn2 = BatchNorm(input_size=out_channels, axis_name="batch")

        self.downsample = downsample

    def __call__(self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State):
        i = x

        x = self.conv1(x)
        # x, state = self.bn1(x, state)

        x = jax.nn.relu(x)

        x = self.conv2(x)
        # x, state = self.bn2(x, state)

        if self.downsample:
            i, state = self.downsample(i, state)

        x += i
        x = jax.nn.relu(x)

        return x, state


class Bottleneck(eqx.Module):
    downsample: Downsample | None

    conv1: eqx.nn.Conv2d
    # bn1: BatchNorm

    conv2: eqx.nn.Conv2d
    # bn2: BatchNorm

    conv3: eqx.nn.Conv2d
    # bn3: BatchNorm

    expansion: int = eqx.field(static=True, default=4)

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int,
        downsample: Downsample | None,
        groups: int,
        base_width: int,
        dilation: int,
        key: jt.PRNGKeyArray,
    ):
        _, *subkeys = jax.random.split(key, 4)

        width = int(out_channels * (base_width / 64.0)) * groups
        self.conv1 = eqx.nn.Conv2d(
            in_channels, width, kernel_size=1, use_bias=False, key=subkeys[0]
        )
        # self.bn1 = BatchNorm(width, axis_name="batch")

        self.conv2 = eqx.nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            groups=groups,
            dilation=dilation,
            padding=dilation,
            use_bias=False,
            key=subkeys[1],
        )

        # self.bn2 = BatchNorm(width, axis_name="batch")

        self.conv3 = eqx.nn.Conv2d(
            width,
            out_channels * self.expansion,
            kernel_size=1,
            key=subkeys[2],
            use_bias=False,
        )

        # self.bn3 = BatchNorm(out_channels * self.expansion, axis_name="batch")

        self.downsample = downsample

    def __call__(
        self, x: jt.Float[jt.Array, "c_in h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, "c_out*e h/s w/s"], eqx.nn.State]:
        i = x

        x = self.conv1(x)
        # x, state = self.bn1(x, state)
        x = jax.nn.relu(x)

        x = self.conv2(x)
        # x, state = self.bn2(x, state)
        x = jax.nn.relu(x)

        x = self.conv3(x)
        # x, state = self.bn3(x, state)

        if self.downsample:
            i, state = self.downsample(i, state)

        x += i
        x = jax.nn.relu(x)
        return x, state


class ResNet(eqx.Module):
    conv1: eqx.nn.Conv2d
    # bn: BatchNorm
    mp: eqx.nn.MaxPool2d

    layer1: list[BasicBlock | Bottleneck]
    layer2: list[BasicBlock | Bottleneck]
    layer3: list[BasicBlock | Bottleneck]
    layer4: list[BasicBlock | Bottleneck]

    avg: eqx.nn.AdaptiveAvgPool2d
    fc: eqx.nn.Linear

    running_internal_channels: int = eqx.field(static=True, default=64)
    dilation: int = eqx.field(static=True, default=1)

    def __init__(
        self,
        block: Type[BasicBlock | Bottleneck],
        layers: list[int],
        n_classes: int,
        zero_init_residual: bool,
        groups: int,
        width_per_group: int,
        replace_stride_with_dilation: list[bool] | None,
        key: jt.PRNGKeyArray,
        input_channels: int = 3,
    ):
        key, *subkeys = jax.random.split(key, 10)

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                f"`replace_stride_with_dilation` should either be `None` "
                f"or have a length of 3, got {replace_stride_with_dilation} instead."
            )

        self.conv1 = eqx.nn.Conv2d(
            in_channels=input_channels,
            out_channels=self.running_internal_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            use_bias=False,
            key=subkeys[0],
        )

        # self.bn = BatchNorm(self.running_internal_channels, axis_name="batch")
        self.mp = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(
            block,
            64,
            layers[0],
            stride=1,
            dilate=False,
            groups=groups,
            base_width=width_per_group,
            key=subkeys[1],
        )
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[2],
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[3],
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            groups=groups,
            base_width=width_per_group,
            key=subkeys[4],
        )

        self.avg = eqx.nn.AdaptiveAvgPool2d(target_shape=(1, 1))
        self.fc = eqx.nn.Linear(512 * block.expansion, n_classes, key=subkeys[-1])

        if zero_init_residual:
            # todo: init last bn layer with zero weights
            pass

    def _make_layer(
        self,
        block: Type[BasicBlock | Bottleneck],
        out_channels: int,
        blocks: int,
        stride: int,
        dilate: bool,
        groups: int,
        base_width: int,
        key: jt.PRNGKeyArray,
    ) -> list[BasicBlock | Bottleneck]:
        downsample = None
        previous_dilation = self.dilation

        if dilate:
            self.dilation *= stride
            stride = 1

        if (
            stride != 1
            or self.running_internal_channels != out_channels * block.expansion
        ):
            key, subkey = jax.random.split(key)
            downsample = Downsample(
                self.running_internal_channels,
                out_channels * block.expansion,
                stride,
                subkey,
            )
        layers = []

        key, subkey = jax.random.split(key)
        layers.append(
            block(
                in_channels=self.running_internal_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                groups=groups,
                base_width=base_width,
                dilation=previous_dilation,
                key=subkey,
            )
        )

        self.running_internal_channels = out_channels * block.expansion

        for _ in range(1, blocks):
            key, subkey = jax.random.split(key)
            layers.append(
                block(
                    in_channels=self.running_internal_channels,
                    out_channels=out_channels,
                    groups=groups,
                    base_width=base_width,
                    dilation=self.dilation,
                    stride=1,
                    downsample=None,
                    key=subkey,
                )
            )

        return layers

    def __call__(
        self, x: jt.Float[jt.Array, "c h w"], state: eqx.nn.State
    ) -> tuple[jt.Float[jt.Array, " n_classes"], eqx.nn.State]:
        x = self.conv1(x)
        # x, state = self.bn(x, state)
        x = jax.nn.relu(x)
        x = self.mp(x)

        for layer in self.layer1:
            x, state = layer(x, state)

        for layer in self.layer2:
            x, state = layer(x, state)

        for layer in self.layer3:
            x, state = layer(x, state)

        for layer in self.layer4:
            x, state = layer(x, state)

        x = self.avg(x)
        x = jnp.ravel(x)

        x = self.fc(x)

        return x, state


def resnet18(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    key, subkey = jax.random.split(key)
    resnet, state = eqx.nn.make_with_state(ResNet)(
        BasicBlock,
        [2, 2, 2, 2],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )

    # initializer = jax.nn.initializers.he_normal()
    # is_conv2d = lambda x: isinstance(x, eqx.nn.Conv2d)
    # get_weights = lambda m: [
    #     x.weight for x in jax.tree.leaves(m, is_leaf=is_conv2d) if is_conv2d(x)
    # ]
    # weights = get_weights(resnet)
    # new_weights = [
    #     initializer(subkey, weight.shape, jnp.float32)
    #     for weight, subkey in zip(weights, jax.random.split(key, len(weights)))
    # ]
    # resnet = eqx.tree_at(get_weights, resnet, new_weights)

    return resnet, state


def resnet34(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        BasicBlock,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet50(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet101(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnet152(key: jt.PRNGKeyArray, n_classes=1000) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 8, 36, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext50_32x4d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=32,
        width_per_group=4,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext101_32x8d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=32,
        width_per_group=8,
        replace_stride_with_dilation=None,
        key=key,
    )


def resnext101_64x4d(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=64,
        width_per_group=4,
        replace_stride_with_dilation=None,
        key=key,
    )


def wide_resnet50_2(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 6, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64 * 2,
        replace_stride_with_dilation=None,
        key=key,
    )


def wide_resnet101_2(
    key: jt.PRNGKeyArray, n_classes=1000
) -> tuple[ResNet, eqx.nn.State]:
    return eqx.nn.make_with_state(ResNet)(
        Bottleneck,
        [3, 4, 23, 3],
        n_classes,
        zero_init_residual=False,
        groups=1,
        width_per_group=64 * 2,
        replace_stride_with_dilation=None,
        key=key,
    )

The issue

I was starting to get desperate and began checking the gradients, for which I needed to un-JIT the step function so that the gradient print statements would actually show anything. And there it was: after removing the eqx.filter_jit wrapper around the step function, the network started to train.
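Side note: gradients can also be printed from inside a jitted function with jax.debug.print, which avoids un-JITting entirely. A minimal sketch, assuming grads is the pytree returned by eqx.filter_value_and_grad in the step function below:

import jax
import jax.numpy as jnp


def print_grad_norm(grads):
    # jax.debug.print stages the print into the jitted computation,
    # so this works even under eqx.filter_jit.
    leaves = jax.tree.leaves(grads)
    norm = jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))
    jax.debug.print("grad norm = {n}", n=norm)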

Copy-pastable training loop code (requires tensorflow, tensorflow_datasets, clu, tqdm, jaxtyping)

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from clu import metrics
from tqdm import tqdm

# from jaxonmodels.models.resnet import ResNet, resnet18
# copy paste here the code from the details above

(train, test), info = tfds.load(
    "cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore


def preprocess(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
    img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
    mean = tf.constant([0.4914, 0.4822, 0.4465])
    std = tf.constant([0.2470, 0.2435, 0.2616])
    img = (img - mean) / std # pyright: ignore

    img = tf.transpose(img, perm=[2, 0, 1])

    # label = tf.one_hot(label, depth=10)

    return img, label


def preprocess_train(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
    img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)  # pyright: ignore

    return preprocess(img, label)


train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 128
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)



def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

# @eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value, logits



class TrainMetrics(eqx.Module, metrics.Collection):
    loss: metrics.Average.from_output("loss")  # pyright: ignore
    accuracy: metrics.Accuracy


def eval(
    resnet: ResNet, test_dataset, state, key: jt.PRNGKeyArray
) -> TrainMetrics:
    eval_metrics = TrainMetrics.empty()
    for x, y in test_dataset:
        y = jnp.array(y, dtype=jnp.int32)
        loss, (logits, state) = loss_fn(resnet, x, y, state)
        eval_metrics = eval_metrics.merge(
            TrainMetrics.single_from_model_output(
                logits=logits, labels=y, loss=loss
            )
        )

    return eval_metrics


train_metrics = TrainMetrics.empty()

resnet, state = resnet18(key=jax.random.key(0), n_classes=10)

learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)

opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))

key = jax.random.key(99)
n_epochs = 100


for epoch in range(n_epochs):
    batch_count = len(train_dataset)

    pbar = tqdm(enumerate(train_dataset), total=batch_count, desc=f"Epoch {epoch}")
    for i, (x, y) in pbar:
        y = jnp.array(y, dtype=jnp.int32)
        resnet, state, opt_state, loss, logits = step(
            resnet, state, x, y, optimizer, opt_state
        )
        train_metrics = train_metrics.merge(
            TrainMetrics.single_from_model_output(
                logits=logits, labels=y, loss=loss
            )
        )

        vals = train_metrics.compute()
        pbar.set_postfix(
            {"loss": f"{vals['loss']:.4f}", "acc": f"{vals['accuracy']:.4f}"}
        )
    key, subkey = jax.random.split(key)
    eval_metrics = eval(resnet, test_dataset, state, subkey)
    evals = eval_metrics.compute()
    print(
        f"Epoch {epoch}: "
        f"test_loss={evals['loss']:.4f}, "
        f"test_acc={evals['accuracy']:.4f}"
    )

The relevant part (tl;dr)

The issue is here:

def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

# @eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    resnet = eqx.apply_updates(resnet, updates)
    return resnet, state, opt_state, loss_value, logits

From my perspective, this looks just like standard JAX "boilerplate" code. I see no reason why JITting the step function would interfere with training the model.
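For what it's worth, the invariant I'm relying on here is that JIT compilation shouldn't change numerics beyond floating-point tolerance. A tiny self-contained sketch of that expectation (f is just a made-up function for illustration):

import jax
import jax.numpy as jnp


def f(w, x):
    return jnp.tanh(w @ x)


w = jnp.full((4, 4), 0.1)
x = jnp.arange(4.0)
# The jitted and un-jitted results should agree up to float32 tolerance.
assert jnp.allclose(jax.jit(f)(w, x), f(w, x))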

My other attempts

So perhaps I could get rid of the state, I thought, since I don't even use BatchNorm anymore. But that made no difference. I also tried JITting a smaller portion, as shown in the RNN example:

@eqx.filter_value_and_grad
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    # Trains with respect to binary cross-entropy
    return -jnp.mean(y * jnp.log(pred_y) + (1 - y) * jnp.log(1 - pred_y))

But the equivalent version didn't improve the model.
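The exact code of that attempt isn't shown here; a hypothetical reconstruction would be to JIT only the loss-and-gradient computation rather than the whole step:

# Hypothetical reconstruction - not the exact code from my attempt.
@eqx.filter_jit
def compute_loss_and_grads(resnet, x, y, state):
    return eqx.filter_value_and_grad(loss_fn, has_aux=True)(resnet, x, y, state)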

I spent all day on this and am now out of options; in German we'd say "es ist wie verhext" ("it's as if it were jinxed"), so perhaps anyone here has an idea? ANY help is HIGHLY appreciated.

@lockwo (Contributor) commented Mar 13, 2025

I think the code can be further simplified

code
# The ResNet model code here was identical, line for line, to the
# copy-pastable version in the issue above (through resnet18), so it is
# omitted to avoid duplicating it.

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm

tf.random.set_seed(0)

(train, test), info = tfds.load(
    "cifar10", split=["train", "test"], with_info=True, as_supervised=True
) # pyright: ignore


def preprocess(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
    img = tf.cast(img, tf.float32) / 255.0 # pyright: ignore
    mean = tf.constant([0.4914, 0.4822, 0.4465])
    std = tf.constant([0.2470, 0.2435, 0.2616])
    img = (img - mean) / std # pyright: ignore

    img = tf.transpose(img, perm=[2, 0, 1])

    # label = tf.one_hot(label, depth=10)

    return img, label


def preprocess_train(
    img: jt.Float[tf.Tensor, "h w c"], label: jt.Int[tf.Tensor, ""]
) -> tuple[jt.Float[tf.Tensor, "h w c"], jt.Int[tf.Tensor, "1 n_classes"]]:
    img = tf.pad(img, [[4, 4], [4, 4], [0, 0]], mode="REFLECT")
    img = tf.image.random_crop(img, [32, 32, 3])
    img = tf.image.random_flip_left_right(img)  # pyright: ignore

    return preprocess(img, label)


train_dataset = train.map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
SHUFFLE_VAL = len(train_dataset) // 1000
BATCH_SIZE = 2
train_dataset = train_dataset.shuffle(SHUFFLE_VAL)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

test_dataset = test.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

train_dataset = tfds.as_numpy(train_dataset)
test_dataset = tfds.as_numpy(test_dataset)

def loss_fn(
    resnet: ResNet,
    x: jt.Array,
    y: jt.Array,
    state: eqx.nn.State,
) -> tuple[jt.Array, tuple[jt.Array, eqx.nn.State]]:
    logits, state = eqx.filter_vmap(
        resnet, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(x, state)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
    return jnp.mean(loss), (logits, state)

@eqx.filter_jit
def step(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    new_r = eqx.apply_updates(resnet, updates)
    return new_r, state, opt_state, loss_value, logits, grads, updates

def step_nj(
    resnet: jt.PyTree,
    state: eqx.nn.State,
    x: jt.Array,
    y: jt.Array,
    optimizer: optax.GradientTransformation,
    opt_state: optax.OptState,
):
    (loss_value, (logits, state)), grads = eqx.filter_value_and_grad(
        loss_fn, has_aux=True
    )(resnet, x, y, state)
    updates, opt_state = optimizer.update(grads, opt_state, resnet)
    new_r = eqx.apply_updates(resnet, updates)
    return new_r, state, opt_state, loss_value, logits, grads, updates

resnet, state = resnet18(key=jax.random.key(0), n_classes=10)

learning_rate = 0.1
weight_decay = 5e-4
optimizer = optax.sgd(learning_rate)

opt_state = optimizer.init(eqx.filter(resnet, eqx.is_inexact_array_like))

key = jax.random.key(99)
n_epochs = 100

for epoch in range(n_epochs):
    batch_count = len(train_dataset)

    for i, (x, y) in enumerate(train_dataset):
        y = jnp.array(y, dtype=jnp.int32)
        new_j, state, os_j, loss, logits, g, u = step(
            resnet, state, x, y, optimizer, opt_state
        )
        print("\n WJ", state, loss, logits)
        y = jnp.array(y, dtype=jnp.int32)
        new_nj, state, os_nj, loss, logits_nj, g_nj, u_nj = step_nj(
            resnet, state, x, y, optimizer, opt_state
        )
        print("\n NJ", state, loss, logits_nj)
        print(jnp.allclose(logits, logits_nj))
        print(eqx.tree_equal(g, g_nj))
        print(eqx.tree_equal(u, u_nj))
        print(eqx.tree_equal(os_j, os_nj))
        print(eqx.tree_equal(new_j, new_nj))
        print(eqx.tree_equal(jax.tree.leaves(new_j), jax.tree.leaves(new_nj)))
        l = jax.tree.leaves(new_j)
        lj = jax.tree.leaves(new_nj)
        print(len(l), len(lj))
        for i in range(len(l)):
          try:
            print(i, jnp.linalg.norm(lj[i] - l[i]), jnp.allclose(lj[i], l[i]))
            if jnp.isnan(jnp.linalg.norm(lj[i] - l[i])):
              print(lj[i], l[i])
          except:
            print(i, lj[i], l[i], lj[i] == l[i])
        break
    break

The only difference I saw under jit is that the gradients are slightly different (with a norm of ~1e-9 I assume that's just within floating-point precision), but not always. I would be surprised if this were the source of the problem, but I'm just trying to narrow it down since it's a very large setup currently.

WJ State() 2.333593 [[-0.04504904  0.03363977 -0.03377024 -0.02731024 -0.00935392 -0.07189757
   0.0312047   0.04465355 -0.03419897  0.01027971]
 [-0.04424123  0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
   0.03016024  0.04261264 -0.03077785  0.0194398 ]]

 NJ State() 2.333593 [[-0.04504904  0.03363976 -0.03377024 -0.02731025 -0.00935392 -0.07189757
   0.0312047   0.04465355 -0.03419897  0.01027971]
 [-0.04424123  0.03772916 -0.02880739 -0.01243044 -0.00705145 -0.06763338
   0.03016024  0.04261264 -0.03077785  0.0194398 ]]
True
False
False
True
False
False
22 22
0 2.735993e-08 True
1 1.10683e-08 True
2 1.1603525e-08 True
3 1.2323362e-08 True
4 1.2514869e-08 True
5 1.3517954e-08 True
6 2.0453514e-08 True
7 2.238987e-08 True
8 1.1969572e-08 True
9 9.68429e-09 True
10 1.0157422e-08 True
11 1.7057491e-08 True
12 1.5252258e-08 True
13 1.0765473e-08 True
14 6.8855863e-09 True
15 6.1813026e-09 True
16 9.383501e-09 True
17 5.7024003e-09 True
18 4.0014845e-09 True
19 1.3972744e-09 True
20 2.6966664e-09 True
21 0.0 True
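Aside: eqx.tree_equal compares exactly by default (rtol and atol default to zero), so the False results above are expected given differences on the order of 1e-9. A tolerance-based pytree comparison, assuming g and g_nj are the gradient pytrees from the loop above, would look something like:

import jax
import jax.numpy as jnp


def trees_allclose(a, b, rtol=1e-5, atol=1e-8):
    # True iff every pair of corresponding array leaves is within tolerance.
    flags = jax.tree.map(
        lambda x, y: bool(jnp.allclose(x, y, rtol=rtol, atol=atol)), a, b
    )
    return all(jax.tree.leaves(flags))


print(trees_allclose(g, g_nj))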
