Adding ViT weights #1145
Conversation
just to be sure, the originals use our same rescaling process, right? |
Edit: They use a |
BTW, as far as retraining goes, ViT Tiny16 on a script very similar to
Ended at 32% top-1 accuracy and stopped improving because it only uses MixUp + horizontal flips (edited: I misread the number initially). As per https://arxiv.org/abs/2205.01580, you can get 76.8% top-1 accuracy on S16 in just 90 epochs (7h on a TPUv3-8) with RandAugment, horizontal flips, and MixUp. Here is their config: https://github.com/google-research/big_vision/blob/main/big_vision/configs/vit_s16_i1k.py
It can be emulated perfectly using our
If this means that we can train all ViTs in-house using existing scripts, that's great news :) |
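For reference, a rough sketch of what such an augmentation pipeline could look like with KerasCV preprocessing layers - the hyperparameters and wiring below are illustrative assumptions, not the exact big_vision recipe:

```python
import tensorflow as tf
import keras_cv

# Illustrative hyperparameters - not the exact big_vision S16 values.
rand_flip = keras_cv.layers.RandomFlip(mode="horizontal")
rand_augment = keras_cv.layers.RandAugment(value_range=(0, 255), magnitude=0.2)
mix_up = keras_cv.layers.MixUp(alpha=0.2)

def augment(images, labels):
    # KerasCV augmentation layers accept a dict of images and labels;
    # MixUp additionally requires batched inputs and one-hot labels.
    inputs = {"images": images, "labels": tf.one_hot(labels, 1000)}
    for layer in (rand_flip, rand_augment, mix_up):
        inputs = layer(inputs)
    return inputs["images"], inputs["labels"]

# train_ds = train_ds.batch(128).map(augment, num_parallel_calls=tf.data.AUTOTUNE)
```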
Please note that, generally, the more augmentations we add, the slower our training scripts/experiments will be, because of our well-known accumulating fallbacks |
Nice! Are those coming from the checkpoint conversion tool? |
The ones on GDrive - yes :) |
What input range does the original ViT use? |
I remember reading somewhere that it should be -1 to 1, but since the images aren't rescaled for inference and the weights work that way, I assume the input range is 0-255, since tfds returns an int-based image |
LGTM. @ianstenbit how do we want to proceed with the weights here? |
I will upload these to GCS and post their hashes here (sometime today) |
Hey @DavidLandup0 -- FYI, your B16 link looks like a duplicate of another one. For now I am processing / uploading just the other 5 sets of weights |
@DavidLandup0 weights are uploaded, here's the hash info:
Also -- let's update the |
Whoops - it's the same link, by accident. Here's the correct link for B16: https://drive.google.com/file/d/128G4Euk2A6ns-6QaCBp7wHH42lgKs-PA |
Awesome, thank you for handling this @ianstenbit! |
@DavidLandup0 B16 weights are uploaded:
"vitb16": {
Thanks for your effort on this! |
Updated, thanks for getting the weights uploaded and sliced for this! |
Can we test these weights with the proper normalization strategy, and actually enforce that users don't pass include_rescaling=True and provide some of these weights? This seems like a way we could confuse a LOT of people if:
actually gives bad scores (e.g. 60% accuracy) |
I can add a check and raise a ValueError if
They work with rescaled inputs, but the scores are somewhat lower than when the inputs are not rescaled. I can post the results when all of the inference runs are finished :) |
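As a rough illustration of the kind of guard being discussed - a hypothetical helper, assuming it keys on combining the ported weights with the default 1/255 rescaling (not actual keras_cv code):

```python
def validate_vit_weights(weights, include_rescaling):
    # Hypothetical guard (names and condition are illustrative): the ported ViT
    # weights expect [-1, 1] inputs, so combining them with the default 1/255
    # rescaling would silently degrade accuracy.
    if weights is not None and include_rescaling:
        raise ValueError(
            "The ported ViT weights expect inputs in the [-1, 1] range. "
            "Either pass `include_rescaling=False` and rescale the inputs "
            "yourself, or drop the pretrained weights. Received: "
            f"weights={weights!r}, include_rescaling={include_rescaling}."
        )
```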
So, this is the script that produced the weight files I shared. These runs produced the metrics from the description of the PR: https://colab.research.google.com/drive/1cnZfZ8jx7sk0zUbfBiF-nLMyZVbQuWc1 Takeaway - the models are created using:
Here's a notebook with the benchmarking script being run for the models with uploaded weights: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba Here, when
This is the exact opposite of what you'd expect. Also, the reported metrics are slightly lower than in the conversion notebook (e.g. 68.8% vs 66.9% for S16). |
/gcbrun |
Re-ran the notebook with ViTTiny, ViTS, ResNet50V2, EfficientNetV2B0 and DenseNet121 for performance reference. Again, they're all run against ImageNetV2, which is more difficult than the ImageNet val set used to report accuracies during training here, so all models will have a seemingly low accuracy metric, but it does serve to show the relative difference between them :) https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba |
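For context, roughly what such an ImageNetV2 run looks like with tfds - the model constructor arguments and the weights alias below are assumptions for illustration, not the exact benchmark script:

```python
import tensorflow as tf
import tensorflow_datasets as tfds
import keras_cv

# Constructor arguments and the "imagenet" weights alias are assumptions;
# the actual script lives under examples/benchmarking/.
model = keras_cv.models.ViTS16(
    include_rescaling=False,
    include_top=True,
    classes=1000,
    weights="imagenet",
)

def preprocess(image, label):
    image = tf.image.resize(image, (224, 224))
    image = tf.cast(image, tf.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return image, label

ds = tfds.load("imagenet_v2", split="test", as_supervised=True)
ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE).batch(64)

model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)],
)
model.evaluate(ds)
```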
What is the expected input range when |
Okay - ran a few tests, and the expected input range is [-1, 1], without rescaling. I've updated the
Might be a good idea to add the range in the docstrings?
Though, shouldn't this be the case with all other models too? They're trained with rescaled inputs, so if someone tries to load imagenet weights, and disable rescaling, they won't be able to train it further, and the loaded weights will be obsolete. To make use of the pretrained weights, rescaling has to be included, just like with ViTs 🤔 The API ended up being the same, since all models, ViTs included, should do |
Another alternative is to use 1/255 still, but IF we use the ported weights use a different rescaling layer. WDYT @tanzhenyu ? |
Oh, you mean, 1/255 by default and 1/127 with offset=-1 if weights are being loaded? |
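For reference, the two Rescaling configurations being compared would look like this in Keras - note that the exact factor for a [-1, 1] mapping is 1/127.5 rather than 1/127:

```python
from tensorflow import keras

# Default behaviour: map [0, 255] -> [0, 1].
default_rescale = keras.layers.Rescaling(scale=1.0 / 255.0)

# Alternative when the ported weights are loaded: map [0, 255] -> [-1, 1].
# 255 / 127.5 - 1 = 1, so scale=1/127.5 with offset=-1 lands exactly on [-1, 1].
pretrained_rescale = keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)
```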
Due to the new Transformer use cases, I think we could probably move to the
For this particular PR, it sounds like we can move forward with |
@DavidLandup0 we've discussed offline a bit and would like to take the following approach: ViT expects inputs to be either in [0, 1] (with include_rescaling=False) or [0, 255] (with include_rescaling=True), and it adds an additional rescaling layer that converts [0, 1] images to [-1, 1] |
@DavidLandup0 can you modify the PR so that if |
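A sketch of the input-scaling behaviour described above, assuming it is expressed with stacked keras.layers.Rescaling layers (not the exact keras_cv ViT code):

```python
from tensorflow import keras

def build_vit_input_scaling(include_rescaling: bool) -> keras.Sequential:
    layers = []
    if include_rescaling:
        # include_rescaling=True: inputs arrive in [0, 255], bring them to [0, 1].
        layers.append(keras.layers.Rescaling(1.0 / 255.0))
    # In both cases, map the [0, 1] images to the [-1, 1] range the ported
    # JAX weights expect: x -> 2x - 1.
    layers.append(keras.layers.Rescaling(scale=2.0, offset=-1.0))
    return keras.Sequential(layers)
```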
The benchmarks run the same as before: https://colab.research.google.com/drive/1KWSfkoHrZl30NMEJ7vKxAtAgRwIrT8ba
Now - what happens if someone passes
We can always check the range and rescale, but I don't know if we want to do that. |
That's the reason |
Makes sense. |
/gcbrun |
Looks like the tests have been failing |
LGTM -- thank you David!
There's a test that does |
Can you run /gcbrun again? |
/gcbrun |
The remaining failure is not related to the PR |
* added weight loading
* formatting
* update weights.py
* update
* b16
* update benchmarking script
* tuned rescaling
* improved formatting for benchmark script
* improved formatting for benchmark script
* standardized rescaling
* updated docs
* layer names
What does this PR do?
Since the last branch had a stroke during rebasing, I'm opening a new one.
Since we've got ViTs now, here are the official JAX weights, ported for the KCV ViTs, as H5 files:
Inference on the ImageNetV2 validation set (step time, loss, top-1 and top-5 accuracy):

| Model | Step time | Loss | Top-1 accuracy | Top-5 accuracy |
| --- | --- | --- | --- | --- |
| Tiny16 | 87 ms | 1.6895 | 0.6067 | 0.8285 |
| S16 | 178 ms | 1.2600 | 0.6884 | 0.8916 |
| B16 | 464 ms | 1.0631 | 0.7293 | 0.9199 |
| L16 | 1 s | 0.9622 | 0.7545 | 0.9321 |
| S32 | 54 ms | 1.6650 | 0.6090 | 0.8335 |
| B32 | 122 ms | 1.3684 | 0.6650 | 0.8781 |
Note: This is on ImageNetV2, which is much harder than the regular validation set. The reported accuracies should be ~10-13% higher on the validation set that KCV models are usually reported against. For example, 68% for ViT/S on ImageNetV2 is roughly equivalent to ~80% on the regular ImageNet validation set.
The new `examples/benchmarking/imagenet_v2.py` script reports accuracy for V2. Since I don't have access to the GCS bucket where KCV keeps ImageNet - could we update it to also report accuracy for Val, V2 and Real? I find that the discrepancy in reporting for these different sets in papers tends to be misleading and hard to keep track of. We'd probably want to report for all three rather than just one.

The PR also updates `vit.py` to use `parse_weights()` with their aliases - though, there are no hashes yet. Can you help with uploading them @ianstenbit?

@LukeWood @tanzhenyu