diff --git a/examples/benchmarking/imagenet_v2.py b/examples/benchmarking/imagenet_v2.py
index d2308960b7..865c74340b 100644
--- a/examples/benchmarking/imagenet_v2.py
+++ b/examples/benchmarking/imagenet_v2.py
@@ -98,6 +98,12 @@ def preprocess_image(img, label):
 # model size, etc.
 loss, acc, top_5 = model.evaluate(test_set, verbose=0)
 print(
-    f"{FLAGS.model_name} achieves {acc} Top-1 Accuracy and {top_5} Top-5 Accuracy on ImageNetV2 with setup:"
+    f"Benchmark results:\n{'='*25}\n{FLAGS.model_name} achieves: \n - Top-1 Accuracy: {acc*100} \n - Top-5 Accuracy: {top_5*100} \non ImageNetV2 with setup:"
+)
+print(
+    f"- model_name: {FLAGS.model_name}\n"
+    f"- include_rescaling: {FLAGS.include_rescaling}\n"
+    f"- batch_size: {FLAGS.batch_size}\n"
+    f"- weights: {FLAGS.weights}\n"
+    f"- model_kwargs: {FLAGS.model_kwargs}\n"
 )
-print(FLAGS)
diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py
index 3b43d8b0f2..24e36f5952 100644
--- a/keras_cv/models/vit.py
+++ b/keras_cv/models/vit.py
@@ -24,6 +24,7 @@
 from keras_cv.layers import TransformerEncoder
 from keras_cv.layers.vit_layers import PatchingAndEmbedding
 from keras_cv.models import utils
+from keras_cv.models.weights import parse_weights
 
 MODEL_CONFIGS = {
     "ViTTiny16": {
@@ -138,15 +139,20 @@
     For transfer learning use cases, make sure to read the
     [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/).
     Args:
-        include_rescaling: whether or not to Rescale the inputs.If set to True,
-            inputs will be passed through a `Rescaling(1/255.0)` layer.
+        include_rescaling: whether or not to Rescale the inputs. If set to True,
+            inputs will be passed through a `Rescaling(scale=1./255.0)` layer. Note that ViTs
+            expect an input range of `[0..1]` if rescaling isn't used. Regardless of whether
+            you supply `[0..1]` or the input is rescaled to `[0..1]`, the inputs will further be
+            rescaled to `[-1..1]`.
         include_top: whether to include the fully-connected layer at the top of the
             network. If provided, classes must be provided.
         classes: optional number of classes to classify images into, only to be
             specified if `include_top` is True.
         weights: one of `None` (random initialization), a pretrained weight file
             path, or a reference to pre-trained weights (e.g. 'imagenet/classification')
-            (see available pre-trained weights in weights.py)
+            (see available pre-trained weights in weights.py). Note that the 'imagenet'
+            weights only work on an input shape of (224, 224, 3) due to the input shape dependent
+            patching and flattening logic.
         input_shape: optional shape tuple, defaults to (None, None, 3).
         input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
             to use as image input for the model.
@@ -265,7 +271,11 @@ def ViT(
     x = inputs
 
     if include_rescaling:
-        x = layers.Rescaling(1 / 255.0)(x)
+        x = layers.Rescaling(1.0 / 255.0, name="rescaling")(x)
+
+    # The previous layer rescales [0..255] to [0..1] if applicable
+    # This one rescales [0..1] to [-1..1] since ViTs expect [-1..1]
+    x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0, name="rescaling_2")(x)
 
     encoded_patches = PatchingAndEmbedding(project_dim, patch_size)(x)
     encoded_patches = layers.Dropout(mlp_dropout)(encoded_patches)
@@ -292,6 +302,10 @@ def ViT(
         output = layers.GlobalAveragePooling1D()(output)
 
     model = keras.Model(inputs=inputs, outputs=output)
+
+    if weights is not None:
+        model.load_weights(weights)
+
     return model
 
 
@@ -314,7 +328,7 @@ def ViTTiny16(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vittiny16"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
@@ -351,7 +365,7 @@ def ViTS16(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vits16"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
@@ -388,7 +402,7 @@ def ViTB16(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vitb16"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
@@ -425,7 +439,7 @@ def ViTL16(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vitl16"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
@@ -536,7 +550,7 @@ def ViTS32(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vits32"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
@@ -573,7 +587,7 @@ def ViTB32(
         include_rescaling,
         include_top,
         name=name,
-        weights=weights,
+        weights=parse_weights(weights, include_top, "vitb32"),
         input_shape=input_shape,
         input_tensor=input_tensor,
         pooling=pooling,
diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py
index 5b458bf78c..b1a2b3725f 100644
--- a/keras_cv/models/weights.py
+++ b/keras_cv/models/weights.py
@@ -82,6 +82,30 @@ def parse_weights(weights, include_top, model_type):
         "imagenet": "imagenet/classification-v2",
         "imagenet/classification": "imagenet/classification-v2",
     },
+    "vittiny16": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
+    "vits16": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
+    "vitb16": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
+    "vitl16": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
+    "vits32": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
+    "vitb32": {
+        "imagenet": "imagenet/classification-v0",
+        "imagenet/classification": "imagenet/classification-v0",
+    },
 }
 
 WEIGHTS_CONFIG = {
@@ -132,4 +156,28 @@ def parse_weights(weights, include_top, model_type):
         "imagenet/classification-v2": "5ee5a8ac650aaa59342bc48ffe770e6797a5550bcc35961e1d06685292c15921",
         "imagenet/classification-v2-notop": "e711c83d6db7034871f6d345a476c8184eab99dbf3ffcec0c1d8445684890ad9",
     },
+    "vittiny16": {
+        "imagenet/classification-v0": "c8227fde16ec8c2e7ab886169b11b4f0ca9af2696df6d16767db20acc9f6e0dd",
+        "imagenet/classification-v0-notop": "aa4d727e3c6bd30b20f49d3fa294fb4bbef97365c7dcb5cee9c527e4e83c8f5b",
+    },
+    "vits16": {
+        "imagenet/classification-v0": "4a66a1a70a879ff33a3ca6ca30633b9eadafea84b421c92174557eee83e088b5",
+        "imagenet/classification-v0-notop": "8d0111eda6692096676a5453abfec5d04c79e2de184b04627b295f10b1949745",
+    },
+    "vitb16": {
+        "imagenet/classification-v0": "6ab4e08c773e08de42023d963a97e905ccba710e2c05ef60c0971978d4a8c41b",
+        "imagenet/classification-v0-notop": "4a1bdd32889298471cb4f30882632e5744fd519bf1a1525b1fa312fe4ea775ed",
+    },
+    "vitl16": {
+        "imagenet/classification-v0": "5a98000f848f2e813ea896b2528983d8d956f8c4b76ceed0b656219d5b34f7fb",
+        "imagenet/classification-v0-notop": "40d237c44f14d20337266fce6192c00c2f9b890a463fd7f4cb17e8e35b3f5448",
+    },
+    "vits32": {
+        "imagenet/classification-v0": "f5836e3aff2bab202eaee01d98337a08258159d3b718e0421834e98b3665e10a",
+        "imagenet/classification-v0-notop": "f3907845eff780a4d29c1c56e0ae053411f02fff6fdce1147c4c3bb2124698cd",
+    },
+    "vitb32": {
+        "imagenet/classification-v0": "73025caa78459dc8f9b1de7b58f1d64e24a823f170d17e25fcc8eb6179bea179",
+        "imagenet/classification-v0-notop": "f07b80c03336d731a2a3a02af5cac1e9fc9aa62659cd29e2e7e5c7474150cc71",
+    },
 }