
Adding ViT weights #1145

Merged: 16 commits merged into keras-team:master on Dec 28, 2022
Conversation

@DavidLandup0 (Contributor) commented Dec 15, 2022

What does this PR do?

Since the last branch got mangled during a rebase, I'm opening a new one.

Since we've got ViTs now, here are the official JAX weights, ported for KCV ViTs as H5 files:

Inference on the ImageNetV2 validation set (ms/step, top-1 and top-5 accuracy):

Tiny16: 87ms/step - loss: 1.6895 - accuracy: 0.6067 - sparse_top_k_categorical_accuracy: 0.8285
S16: 178ms/step - loss: 1.2600 - accuracy: 0.6884 - sparse_top_k_categorical_accuracy: 0.8916
B16: 464ms/step - loss: 1.0631 - accuracy: 0.7293 - sparse_top_k_categorical_accuracy: 0.9199
L16: 1s/step - loss: 0.9622 - accuracy: 0.7545 - sparse_top_k_categorical_accuracy: 0.9321
S32: 54ms/step - loss: 1.6650 - accuracy: 0.6090 - sparse_top_k_categorical_accuracy: 0.8335
B32: 122ms/step - loss: 1.3684 - accuracy: 0.6650 - sparse_top_k_categorical_accuracy: 0.8781

Note: This is on ImageNetV2, which is much harder than the regular validation set. The reported accuracies should be ~10-13% higher on the validation set that KCV models are usually reported against. For example, 68% for ViT/S on ImageNetV2 is roughly equivalent to ~80% on the regular ImageNet validation set.

The new examples/benchmarking/imagenet_v2.py script reports accuracy for V2. Since I don't have access to the GCS bucket where KCV keeps ImageNet, could we update it to also report accuracy for Val, V2, and ReaL? I find that the discrepancy in reporting across these different sets in papers tends to be misleading and hard to keep track of. We'd probably want to report all three rather than just one.

The PR also updates vit.py to call parse_weights() using the model aliases - though there are no hashes yet. Can you help with uploading them @ianstenbit?

@LukeWood @tanzhenyu

@LukeWood (Contributor) commented:

Just to be sure - the originals use the same rescaling process as ours, right?

@DavidLandup0 (Contributor, Author) commented Dec 15, 2022

If you mean whether the original models used rescaling or not - no, they didn't. To use these weights, set include_rescaling=False when initializing a ViT.

Edit: They use a Rescaling(scale=1./127.5, offset=-1) instead of Rescaling(scale=1./255.0)
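Concretely, the two variants (a minimal sketch using the arguments above):

    from tensorflow.keras import layers

    # KCV's default include_rescaling behavior: [0, 255] -> [0, 1]
    default_rescale = layers.Rescaling(scale=1.0 / 255.0)

    # What the original JAX ViT weights expect: [0, 255] -> [-1, 1]
    vit_rescale = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)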

@DavidLandup0 (Contributor, Author) commented Dec 15, 2022

BTW, as far as retraining goes, ViT Tiny16 on a script very similar to basic_training with:

  • Horizontal Flip + Mixup (0.2 alpha)
  • Cosine Decay Scheduler with 10k steps
  • Adam with a base lr of 0.0002 * replicas (TPUv3-8)
  • 20 epochs

[Training curve screenshot: IMG_20221215_210513]

Ended at 32% top-1 accuracy and stopped improving, since it only uses MixUp + horizontal flips. (edited: I misread the number initially)

As per https://arxiv.org/abs/2205.01580, you can get 76.8% top-1 accuracy on S16 in just 90 epochs (7h on a TPUv3-8) with RandAug, horizontal flips, and MixUp.

Here is their config: https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_s16_i1k.py

It can be emulated almost perfectly using our basic_training.py script if we replace SGD with Adam, lower the RandAug intensity, and set 10k steps for the scheduler. @ianstenbit, can we do a run to test this?
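For reference, a minimal sketch of that recipe (the layer and schedule names here are my assumptions about the matching TF/KCV APIs, not code taken from basic_training.py):

    import tensorflow as tf
    import keras_cv

    NUM_REPLICAS = 8                  # TPUv3-8
    BASE_LR = 0.0002 * NUM_REPLICAS
    DECAY_STEPS = 10_000              # cosine decay horizon

    flip = tf.keras.layers.RandomFlip("horizontal")
    mix_up = keras_cv.layers.MixUp(alpha=0.2)

    def augment(images, labels):
        # MixUp consumes a dict of images and one-hot labels
        out = mix_up({"images": flip(images), "labels": labels})
        return out["images"], out["labels"]

    schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=BASE_LR, decay_steps=DECAY_STEPS
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=schedule)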

If this means that we can train all ViTs in-house using existing scripts, that's great news :)

@bhack (Contributor) commented Dec 15, 2022

Please note that, generally, the more augmentations we add, the slower our training scripts/experiments will get, due to our well-known accumulating fallbacks.

@tanzhenyu requested a review from ianstenbit on December 15, 2022 14:43
@tanzhenyu (Contributor) commented:

Nice! Are those coming from the checkpoint conversion tool?

@DavidLandup0 (Contributor, Author) commented:

The ones on GDrive - yes :)
All produced with the script under /tools here.
The plot is from a training job from scratch.

@tanzhenyu (Contributor) commented:

> If you mean whether the original models used rescaling or not - no, they didn't. To use these weights, set include_rescaling=False when initializing a ViT.

What input range does the original ViT use?

@DavidLandup0 (Contributor, Author) commented:

I remember reading somewhere that it should be -1 to 1, but since the images aren't rescaled for inference and the weights work that way, I assume the input range is 0-255, since tfds returns int-based images.

@tanzhenyu (Contributor) commented:

LGTM. @ianstenbit how do we want to proceed with the weights here?

@ianstenbit (Contributor) commented:

> LGTM. @ianstenbit how do we want to proceed with the weights here?

I will upload these to GCS and post their hashes here (sometime today)

@ianstenbit (Contributor) commented:

Hey @DavidLandup0 -- FYI, your link:
B16: https://drive.google.com/file/d/1M_zALLuEuoZ1fZNpWPqlD2Ahc9UPB9G4
points to a file called vit_b32.py that has the same size as the B32 weights.

For now I am processing / uploading just the other 5 sets of weights

@ianstenbit (Contributor) commented:

@DavidLandup0 weights are uploaded, here's the hash info:

    "vittiny16": {
        "imagenet/classification-v0": "c8227fde16ec8c2e7ab886169b11b4f0ca9af2696df6d16767db20acc9f6e0dd",
        "imagenet/classification-v0-notop": "aa4d727e3c6bd30b20f49d3fa294fb4bbef97365c7dcb5cee9c527e4e83c8f5b",
    },
    "vits16": {
        "imagenet/classification-v0": "4a66a1a70a879ff33a3ca6ca30633b9eadafea84b421c92174557eee83e088b5",
        "imagenet/classification-v0-notop": "8d0111eda6692096676a5453abfec5d04c79e2de184b04627b295f10b1949745",
    },
    "vitl16": {
        "imagenet/classification-v0": "5a98000f848f2e813ea896b2528983d8d956f8c4b76ceed0b656219d5b34f7fb",
        "imagenet/classification-v0-notop": "40d237c44f14d20337266fce6192c00c2f9b890a463fd7f4cb17e8e35b3f5448",
    },
    "vits32": {
        "imagenet/classification-v0": "f5836e3aff2bab202eaee01d98337a08258159d3b718e0421834e98b3665e10a",
        "imagenet/classification-v0-notop": "f3907845eff780a4d29c1c56e0ae053411f02fff6fdce1147c4c3bb2124698cd",
    },
    "vitb32": {
        "imagenet/classification-v0": "73025caa78459dc8f9b1de7b58f1d64e24a823f170d17e25fcc8eb6179bea179",
        "imagenet/classification-v0-notop": "f07b80c03336d731a2a3a02af5cac1e9fc9aa62659cd29e2e7e5c7474150cc71",
    },

Also -- let's update the weights docstring to mention that the pre-trained weights only work with input_shape=(224, 224, 3).
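For example, something along these lines could go in the docstring (illustrative usage; alias and argument names as used elsewhere in this PR):

    import keras_cv

    model = keras_cv.models.ViTS16(
        include_rescaling=True,
        include_top=True,
        classes=1000,
        weights="imagenet",          # the pre-trained weights...
        input_shape=(224, 224, 3),   # ...only work at this resolution
    )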

@DavidLandup0 (Contributor, Author) commented:

Whoops - it's the same link, by accident. Here's the correct link for B16: https://drive.google.com/file/d/128G4Euk2A6ns-6QaCBp7wHH42lgKs-PA

@DavidLandup0 (Contributor, Author) commented Dec 17, 2022

Awesome, thank you for handling this @ianstenbit!
Updated the hashes and versions, and added back the erroneously removed configs. I'll be on standby for B16 as well.
Thanks!
Can I add you as the co-author of the PR? :)

@ianstenbit (Contributor) commented:

@DavidLandup0 B16 weights are uploaded

"vitb16": {
"imagenet/classification-v0": "6ab4e08c773e08de42023d963a97e905ccba710e2c05ef60c0971978d4a8c41b",
"imagenet/classification-v0-notop": "4a1bdd32889298471cb4f30882632e5744fd519bf1a1525b1fa312fe4ea775ed",
},

Thanks for your effort on this!

@DavidLandup0 (Contributor, Author) commented:

Updated - thanks for getting the weights uploaded and sliced for this!
Can't wait to fine-tune a ViT on something now :D

@LukeWood (Contributor) commented:

> If you mean whether the original models used rescaling or not - no, they didn't. To use these weights, set include_rescaling=False when initializing a ViT.

Can we test these weights with the proper normalization strategy, and actually enforce that users don't pass include_rescaling=True while providing some of these weights? This seems like a way we could confuse a LOT of people if:

include_rescaling=True, weights="imagenet/classification-v0"

actually gives bad scores (e.g. 60% accuracy)

@DavidLandup0 (Contributor, Author) commented Dec 20, 2022

I can add a check and raise a ValueError if weights='imagenet' and include_rescaling=True, if you think it would help. I'd expect the weights not to work if the input is rescaled, since the model would receive what it thinks are fully black images (all pixel values <= 1)

They do work with rescaled inputs, though the scores are somewhat lower than without rescaling. I can post the results when all of the inference runs are finished :)
Maybe we raise a warning, but not a ValueError, for rescaling+weights? @LukeWood
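Something like this minimal sketch (names hypothetical; the exact condition depends on how the benchmarks turn out):

    import warnings

    def _validate_weights_and_rescaling(weights, include_rescaling):
        # Flag the suspect combination instead of hard-failing
        if weights == "imagenet" and include_rescaling:
            warnings.warn(
                "The ported ViT weights were converted for a specific input "
                "range; double-check include_rescaling.",
                stacklevel=2,
            )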

@DavidLandup0 (Contributor, Author) commented:

So, this is the script that produced the weight files I shared. These runs produced the metrics in the PR description: https://colab.research.google.com/drive/1cnZfZ8jx7sk0zUbfBiF-nLMyZVbQuWc1

Takeaway - the models are created using:

model = eval(f'keras_cv.models.{model_to_convert[1][0]}(include_rescaling=False, include_top=True, classes=1000, weights=None, input_shape=(224, 224, 3))')

Focus on include_rescaling=False.
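(Equivalently, without eval - assuming model_to_convert[1][0] holds the class name as a string:)

    # Look the class up on the module instead of eval-ing a string
    model_cls = getattr(keras_cv.models, model_to_convert[1][0])
    model = model_cls(
        include_rescaling=False,
        include_top=True,
        classes=1000,
        weights=None,
        input_shape=(224, 224, 3),
    )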

Here's a notebook with the benchmarking script being run for the models with uploaded weights: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba

Here, when include_rescaling is True, they perform well, and poorly otherwise:

ViTB16 achieves: 
 Top-1 Accuracy: 0.08999999845400453 
 Top-5 Accuracy: 0.71000000461936 
 on ImageNetV2 with setup:
model_name: ViTB16
include_rescaling: False
batch_size: 32
weights: imagenet
model_kwargs: {}
ViTB16 achieves: 
 Top-1 Accuracy: 72.67000079154968 
 Top-5 Accuracy: 91.50000214576721 
 on ImageNetV2 with setup:
model_name: ViTB16
include_rescaling: True
batch_size: 32
weights: imagenet
model_kwargs: {}

This is the exact opposite of what you'd expect. Also, the reported metrics are slightly lower than in the conversion notebook (e.g. 68.8% vs 66.9% for S16).

@tanzhenyu (Contributor) commented:

/gcbrun

@DavidLandup0 (Contributor, Author) commented:

Re-ran the notebook with ViTTiny, ViTS, ResNet50V2, EfficientNetV2B0 and DenseNet121 for a performance reference. Again, they're all run against ImageNetV2, which is more difficult than the ImageNet-Val set used to report accuracies during training here, so all models will show seemingly low accuracy, but it does serve to show the relative differences between them :)
Notably, ViTs should include rescaling.

https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba

@tanzhenyu (Contributor) commented:

> I can add a check and raise a ValueError if weights='imagenet' and include_rescaling=True, if you think it would help. I'd expect the weights not to work if the input is rescaled, since the model would receive what it thinks are fully black images (all pixel values <= 1)
>
> They do work with rescaled inputs, though the scores are somewhat lower than without rescaling. I can post the results when all of the inference runs are finished :) Maybe we raise a warning, but not a ValueError, for rescaling+weights? @LukeWood

What is the expected input range when include_rescaling=False? (0, 1), (-1, 1), or (-0.5, 0.5)?
Given this backbone uses ported weights, we probably cannot retrain it with rescaling included, so erroring out completely when the flag is True sounds reasonable to me.

@DavidLandup0 (Contributor, Author) commented:

Okay - I ran a few tests, and the expected input range is [-1, 1], without rescaling. I've updated the Rescaling layer in the model to rescale to [-1, 1] instead of [0, 1] when it's included at initialization. The run here now produces exactly the expected results: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba

Might be a good idea to add the expected range to the docstrings?

> Given this backbone uses ported weights, we probably cannot retrain it with rescaling included, so erroring out completely when the flag is True sounds reasonable to me.

Though, shouldn't this be the case for all other models too? They're trained with rescaled inputs, so if someone loads imagenet weights and disables rescaling, they won't be able to train the model further, and the loaded weights will be useless. To make use of the pretrained weights, rescaling has to be included, just like with ViTs 🤔

The API ended up being the same, since all models, ViTs included, should use include_rescaling=True at initialization.

@LukeWood (Contributor) commented:

Another alternative is to still use 1/255, but use a different rescaling layer IF we use the ported weights. WDYT @tanzhenyu?

@DavidLandup0 (Contributor, Author) commented:

Oh, you mean 1/255 by default, and 1/127.5 with offset=-1 if weights are being loaded?
My main concern in both cases is that we document this very clearly, so that both the docs pages and the docstrings state the expected input range.

@tanzhenyu (Contributor) commented Dec 22, 2022

> Another alternative is to still use 1/255, but use a different rescaling layer IF we use the ported weights. WDYT @tanzhenyu?

Given these new Transformer cases, I think we could probably move to the value_range approach, because include_rescaling meaning "rescale to (-1, 1)" here is confusing.

For this particular PR, it sounds like we can move forward with include_rescaling first

@ianstenbit (Contributor) commented:

@DavidLandup0 we've discussed offline a bit and would like to take the following approach:

ViT expects inputs to be either [0, 1] (with include_rescaling=False) or [0, 255] (with include_rescaling=True), and it adds an additional rescaling layer that converts [0, 1] images to [-1, 1].

@tanzhenyu (Contributor) commented:

> @DavidLandup0 we've discussed offline a bit and would like to take the following approach:
>
> ViT expects inputs to be either [0, 1] (with include_rescaling=False) or [0, 255] (with include_rescaling=True), and it adds an additional rescaling layer that converts [0, 1] images to [-1, 1].

@DavidLandup0 can you modify the PR so that if include_rescaling=False, it will perform a conversion from [0, 1] to [-1, 1] inside the model?

@DavidLandup0 (Contributor, Author) commented Dec 22, 2022

    if include_rescaling:
        x = layers.Rescaling(1.0 / 255.0)(x)

    # The previous layer rescales [0..255] to [0..1] if applicable
    # This one rescales [0..1] to [-1..1] since ViTs expect [-1..1]
    x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x)

The benchmarks run the same as before: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba

Now - what happens if someone passes include_rescaling=False and then supplies [0..255] images instead of the expected [0..1]? The same question applies to all other models.

We can always check the range and rescale, but I don't know if we want to do that.
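If we did want it, a minimal (illustrative) check could look like:

    import tensorflow as tf

    def assert_unit_range(images):
        # Catches [0..255] inputs passed where [0..1] is expected
        tf.debugging.assert_less_equal(
            tf.reduce_max(images),
            1.0,
            message="include_rescaling=False expects images in [0, 1].",
        )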

@tanzhenyu (Contributor) commented:

>     if include_rescaling:
>         x = layers.Rescaling(1.0 / 255.0)(x)
>
>     # The previous layer rescales [0..255] to [0..1] if applicable
>     # This one rescales [0..1] to [-1..1] since ViTs expect [-1..1]
>     x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x)
>
> The benchmarks run the same as before: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba
>
> Now - what happens if someone passes include_rescaling=False and then supplies [0..255] images instead of the expected [0..1]? The same question applies to all other models.
>
> We can always check the range and rescale, but I don't know if we want to do that.

That's the reason include_rescaling is not optional -- users have to think about the input range before using the model.

@DavidLandup0 (Contributor, Author) commented:

Makes sense.
Should I add anything else?

@tanzhenyu (Contributor) commented:

/gcbrun

@tanzhenyu (Contributor) commented:

> Makes sense. Should I add anything else?

Looks like the tests have been failing

@ianstenbit (Contributor) left a review comment:

LGTM -- thank you David!

@DavidLandup0 (Contributor, Author) commented:

> Makes sense. Should I add anything else?
>
> Looks like the tests have been failing

There's a test that does get_layer('rescaling'), but we now have two Rescaling layers, so they're named rescaling_1 and rescaling_2.
Adding explicit names to the layers should be enough - update coming in.
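i.e. something like this (layer names illustrative):

    if include_rescaling:
        x = layers.Rescaling(1.0 / 255.0, name="rescaling")(x)

    x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0, name="rescaling_vit")(x)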

@DavidLandup0 (Contributor, Author) commented:

Can you run /gcbrun again?

@tanzhenyu (Contributor) commented:

/gcbrun

@tanzhenyu (Contributor) commented:

The remaining failure is not related to this PR.

@tanzhenyu merged commit 79f0821 into keras-team:master on Dec 28, 2022.
ghost pushed a commit to y-vectorfield/keras-cv referencing this pull request on Nov 16, 2023:
* added weight loading

* formatting

* update weights.py

* update

* b16

* update benchmarking script

* tuned rescaling

* improved formatting for benchmark script

* improved formatting for benchmark script

* standardized rescaling

* updated docs

* layer names