Adding ViT weights #1145

Merged (16 commits, Dec 28, 2022)
10 changes: 8 additions & 2 deletions examples/benchmarking/imagenet_v2.py
@@ -98,6 +98,12 @@ def preprocess_image(img, label):
# model size, etc.
loss, acc, top_5 = model.evaluate(test_set, verbose=0)
print(
f"{FLAGS.model_name} achieves {acc} Top-1 Accuracy and {top_5} Top-5 Accuracy on ImageNetV2 with setup:"
f"Benchmark results:\n{'='*25}\n{FLAGS.model_name} achieves: \n - Top-1 Accuracy: {acc*100} \n - Top-5 Accuracy: {top_5*100} \non ImageNetV2 with setup:"
)
print(
f"- model_name: {FLAGS.model_name}\n"
f"- include_rescaling: {FLAGS.include_rescaling}\n"
f"- batch_size: {FLAGS.batch_size}\n"
f"- weights: {FLAGS.weights}\n"
f"- model_kwargs: {FLAGS.model_kwargs}\n"
)
print(FLAGS)
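For reference, a minimal sketch of the reformatted summary this script now prints, with made-up accuracy numbers standing in for the values `model.evaluate` returns:

```python
# Illustrative values only; the script fills these in from model.evaluate.
model_name, acc, top_5 = "ViTS16", 0.71, 0.90

print(
    f"Benchmark results:\n{'='*25}\n{model_name} achieves: \n"
    f" - Top-1 Accuracy: {acc*100} \n - Top-5 Accuracy: {top_5*100} \non ImageNetV2 with setup:"
)
# Benchmark results:
# =========================
# ViTS16 achieves:
#  - Top-1 Accuracy: 71.0
#  - Top-5 Accuracy: 90.0
# on ImageNetV2 with setup:
```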
34 changes: 24 additions & 10 deletions keras_cv/models/vit.py
@@ -24,6 +24,7 @@
from keras_cv.layers import TransformerEncoder
from keras_cv.layers.vit_layers import PatchingAndEmbedding
from keras_cv.models import utils
from keras_cv.models.weights import parse_weights

MODEL_CONFIGS = {
"ViTTiny16": {
@@ -138,15 +139,20 @@
For transfer learning use cases, make sure to read the [guide to transfer
learning & fine-tuning](https://keras.io/guides/transfer_learning/).
Args:
include_rescaling: whether or not to Rescale the inputs.If set to True,
inputs will be passed through a `Rescaling(1/255.0)` layer.
include_rescaling: whether or not to Rescale the inputs. If set to True,
inputs will be passed through a `Rescaling(scale=1./255.0)` layer. Note that ViTs
expect an input range of `[0..1]` if rescaling isn't used. Regardless of whether
you supply `[0..1]` or the input is rescaled to `[0..1]`, the inputs will further be
rescaled to `[-1..1]`.
include_top: whether to include the fully-connected layer at the top of the
network. If provided, classes must be provided.
classes: optional number of classes to classify images into, only to be
specified if `include_top` is True.
weights: one of `None` (random initialization), a pretrained weight file
path, or a reference to pre-trained weights (e.g. 'imagenet/classification')
(see available pre-trained weights in weights.py)
(see available pre-trained weights in weights.py). Note that the 'imagenet'
weights only work on an input shape of (224, 224, 3) due to the input shape dependent
patching and flattening logic.
input_shape: optional shape tuple, defaults to (None, None, 3).
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
to use as image input for the model.
@@ -265,7 +271,11 @@ def ViT(
x = inputs

if include_rescaling:
x = layers.Rescaling(1 / 255.0)(x)
x = layers.Rescaling(1.0 / 255.0)(x)

# The previous layer rescales [0..255] to [0..1] if applicable
# This one rescales [0..1] to [-1..1] since ViTs expect [-1..1]
x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x)

encoded_patches = PatchingAndEmbedding(project_dim, patch_size)(x)
encoded_patches = layers.Dropout(mlp_dropout)(encoded_patches)
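As a standalone sketch of the value ranges those two `Rescaling` layers produce (`Rescaling` computes `x * scale + offset`):

```python
import numpy as np
from tensorflow.keras import layers

x = np.array([[0.0, 127.5, 255.0]], dtype="float32")   # raw pixel values

x = layers.Rescaling(1.0 / 255.0)(x)                    # [0..255] -> [0..1]:  [0.0, 0.5, 1.0]
x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x)   # [0..1]  -> [-1..1]: [-1.0, 0.0, 1.0]
```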
@@ -292,6 +302,10 @@ def ViT(
output = layers.GlobalAveragePooling1D()(output)

model = keras.Model(inputs=inputs, outputs=output)

if weights is not None:
model.load_weights(weights)

return model
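A minimal usage sketch of the resulting API, assuming the usual `keras_cv.models` exports and keyword names; note that the `'imagenet'` alias only works with the `(224, 224, 3)` input shape mentioned in the docstring, while `weights` may also be a local weights file path or `None` for random initialization:

```python
import keras_cv

model = keras_cv.models.ViTS16(
    include_rescaling=True,   # accept raw [0..255] pixels; the model rescales internally
    include_top=True,
    classes=1000,
    weights="imagenet",       # resolved via parse_weights below
    input_shape=(224, 224, 3),
)
```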


@@ -314,7 +328,7 @@ def ViTTiny16(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vittiny16"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
@@ -351,7 +365,7 @@ def ViTS16(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vits16"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
@@ -388,7 +402,7 @@ def ViTB16(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vitb16"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
@@ -425,7 +439,7 @@ def ViTL16(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vitl16"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
@@ -536,7 +550,7 @@ def ViTS32(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vits32"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
@@ -573,7 +587,7 @@ def ViTB32(
include_rescaling,
include_top,
name=name,
weights=weights,
weights=parse_weights(weights, include_top, "vitb32"),
input_shape=input_shape,
input_tensor=input_tensor,
pooling=pooling,
48 changes: 48 additions & 0 deletions keras_cv/models/weights.py
@@ -82,6 +82,30 @@ def parse_weights(weights, include_top, model_type):
"imagenet": "imagenet/classification-v2",
"imagenet/classification": "imagenet/classification-v2",
},
"vittiny16": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"vits16": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"vitb16": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"vitl16": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"vits32": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
"vitb32": {
"imagenet": "imagenet/classification-v0",
"imagenet/classification": "imagenet/classification-v0",
},
}

WEIGHTS_CONFIG = {
@@ -132,4 +156,28 @@ def parse_weights(weights, include_top, model_type):
"imagenet/classification-v2": "5ee5a8ac650aaa59342bc48ffe770e6797a5550bcc35961e1d06685292c15921",
"imagenet/classification-v2-notop": "e711c83d6db7034871f6d345a476c8184eab99dbf3ffcec0c1d8445684890ad9",
},
"vittiny16": {
"imagenet/classification-v0": "c8227fde16ec8c2e7ab886169b11b4f0ca9af2696df6d16767db20acc9f6e0dd",
"imagenet/classification-v0-notop": "aa4d727e3c6bd30b20f49d3fa294fb4bbef97365c7dcb5cee9c527e4e83c8f5b",
},
"vits16": {
"imagenet/classification-v0": "4a66a1a70a879ff33a3ca6ca30633b9eadafea84b421c92174557eee83e088b5",
"imagenet/classification-v0-notop": "8d0111eda6692096676a5453abfec5d04c79e2de184b04627b295f10b1949745",
},
"vitb16": {
"imagenet/classification-v0": "6ab4e08c773e08de42023d963a97e905ccba710e2c05ef60c0971978d4a8c41b",
"imagenet/classification-v0-notop": "4a1bdd32889298471cb4f30882632e5744fd519bf1a1525b1fa312fe4ea775ed",
},
"vitl16": {
"imagenet/classification-v0": "5a98000f848f2e813ea896b2528983d8d956f8c4b76ceed0b656219d5b34f7fb",
"imagenet/classification-v0-notop": "40d237c44f14d20337266fce6192c00c2f9b890a463fd7f4cb17e8e35b3f5448",
},
"vits32": {
"imagenet/classification-v0": "f5836e3aff2bab202eaee01d98337a08258159d3b718e0421834e98b3665e10a",
"imagenet/classification-v0-notop": "f3907845eff780a4d29c1c56e0ae053411f02fff6fdce1147c4c3bb2124698cd",
},
"vitb32": {
"imagenet/classification-v0": "73025caa78459dc8f9b1de7b58f1d64e24a823f170d17e25fcc8eb6179bea179",
"imagenet/classification-v0-notop": "f07b80c03336d731a2a3a02af5cac1e9fc9aa62659cd29e2e7e5c7474150cc71",
},
}
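Loosely, the first table maps user-facing aliases to a versioned weights key, and `WEIGHTS_CONFIG` maps each versioned key (plus its `-notop` variant for `include_top=False`) to the file hash used to verify the download. A rough sketch of that lookup, with table and variable names assumed here for illustration and hashes truncated:

```python
# Assumed names for illustration; the real resolution logic lives in parse_weights above.
alias_map = {
    "imagenet": "imagenet/classification-v0",
    "imagenet/classification": "imagenet/classification-v0",
}
hashes = {
    "imagenet/classification-v0": "4a66a1a70a879ff3...",        # full model (vits16)
    "imagenet/classification-v0-notop": "8d0111eda6692096...",  # without the classification head
}

weights = alias_map.get("imagenet", "imagenet")  # -> "imagenet/classification-v0"
include_top = False
if not include_top:
    weights += "-notop"                          # -> "imagenet/classification-v0-notop"
expected_hash = hashes[weights]                  # checked against the downloaded file
```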