From be3f9fed58ce1d4cf69966926912bc3cb9d07040 Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 15 Dec 2022 13:39:19 +0900 Subject: [PATCH 01/12] added weight loading --- keras_cv/models/vit.py | 18 ++++++---- keras_cv/models/weights.py | 73 ++++++++++++++++++++++++++------------ 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 3b43d8b0f2..80a27aa026 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -25,6 +25,8 @@ from keras_cv.layers.vit_layers import PatchingAndEmbedding from keras_cv.models import utils +from keras_cv.models.weights import parse_weights + MODEL_CONFIGS = { "ViTTiny16": { "patch_size": 16, @@ -292,6 +294,10 @@ def ViT( output = layers.GlobalAveragePooling1D()(output) model = keras.Model(inputs=inputs, outputs=output) + + if weights is not None: + model.load_weights(weights) + return model @@ -314,7 +320,7 @@ def ViTTiny16( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vittiny16"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, @@ -351,7 +357,7 @@ def ViTS16( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vits16"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, @@ -388,7 +394,7 @@ def ViTB16( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vitb16"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, @@ -425,7 +431,7 @@ def ViTL16( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vitl16"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, @@ -536,7 +542,7 @@ def ViTS32( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vits32"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, @@ -573,7 +579,7 @@ def ViTB32( include_rescaling, include_top, name=name, - weights=weights, + weights=parse_weights(weights, include_top, "vitb32"), input_shape=input_shape, input_tensor=input_tensor, pooling=pooling, diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py index 45d53298ee..df9950fd8a 100644 --- a/keras_cv/models/weights.py +++ b/keras_cv/models/weights.py @@ -39,17 +39,6 @@ def parse_weights(weights, include_top, model_type): BASE_PATH = "https://storage.googleapis.com/keras-cv/models" ALIASES = { - "cspdarknet": { - "imagenet": "imagenet/classification-v0", - "imagenet/classification": "imagenet/classification-v0", - }, - "darknet53": { - "imagenet": "imagenet/classification-v0", - "imagenet/classification": "imagenet/classification-v0", - }, - "deeplabv3": { - "voc": "voc/segmentation-v0", - }, "densenet121": { "imagenet": "imagenet/classification-v0", "imagenet/classification": "imagenet/classification-v0", @@ -82,20 +71,33 @@ def parse_weights(weights, include_top, model_type): "imagenet": "imagenet/classification-v2", "imagenet/classification": "imagenet/classification-v2", }, -} - -WEIGHTS_CONFIG = { - "cspdarknet": { - "imagenet/classification-v0": "8bdc3359222f0d26f77aa42c4e97d67a05a1431fe6c448ceeab9a9c5a34ff804", - "imagenet/classification-v0-notop": "9303aabfadffbff8447171fce1e941f96d230d8f3cef30d3f05a9c85097f8f1e", + "vittiny16": { + "imagenet": "imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", + }, + "vits16": { + "imagenet": 
"imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", + }, + "vitb16": { + "imagenet": "imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", }, - "darknet53": { - "imagenet/classification-v0": "7bc5589f7f7f7ee3878e61ab9323a71682bfb617eb57f530ca8757c742f00c77", - "imagenet/classification-v0-notop": "8dcce43163e4b4a63e74330ba1902e520211db72d895b0b090b6bfe103e7a8a5", + "vitl16": { + "imagenet": "imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", }, - "deeplabv3": { - "voc/segmentation-v0": "732042e8b6c9ddba3d51c861f26dc41865187e9f85a0e5d43dfef75a405cca18", + "vits32": { + "imagenet": "imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", }, + "vitb32": { + "imagenet": "imagenet/classification-v2", + "imagenet/classification": "imagenet/classification-v2", + }, +} + +WEIGHTS_CONFIG = { "densenet121": { "imagenet/classification-v0": "13de3d077ad9d9816b9a0acc78215201d9b6e216c7ed8e71d69cc914f8f0775b", "imagenet/classification-v0-notop": "709afe0321d9f2b2562e562ff9d0dc44cca10ed09e0e2cfba08d783ff4dab6bf", @@ -132,4 +134,29 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification-v2": "5ee5a8ac650aaa59342bc48ffe770e6797a5550bcc35961e1d06685292c15921", "imagenet/classification-v2-notop": "e711c83d6db7034871f6d345a476c8184eab99dbf3ffcec0c1d8445684890ad9", }, -} + + "vittiny16": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, + "vits16": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, + "vitb16": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, + "vitl16": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, + "vits32": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, + "vitb32": { + "imagenet/classification-v0": "", + "imagenet/classification-v0-notop": "", + }, +} \ No newline at end of file From 999212d4fa6b09f38e873ea8b0a8bb837ea24905 Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 15 Dec 2022 13:39:42 +0900 Subject: [PATCH 02/12] formatting --- keras_cv/models/vit.py | 1 - keras_cv/models/weights.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 80a27aa026..e702400dc2 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -24,7 +24,6 @@ from keras_cv.layers import TransformerEncoder from keras_cv.layers.vit_layers import PatchingAndEmbedding from keras_cv.models import utils - from keras_cv.models.weights import parse_weights MODEL_CONFIGS = { diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py index df9950fd8a..714a4f95f2 100644 --- a/keras_cv/models/weights.py +++ b/keras_cv/models/weights.py @@ -134,7 +134,6 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification-v2": "5ee5a8ac650aaa59342bc48ffe770e6797a5550bcc35961e1d06685292c15921", "imagenet/classification-v2-notop": "e711c83d6db7034871f6d345a476c8184eab99dbf3ffcec0c1d8445684890ad9", }, - "vittiny16": { "imagenet/classification-v0": "", "imagenet/classification-v0-notop": "", @@ -159,4 +158,4 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification-v0": "", "imagenet/classification-v0-notop": "", }, -} \ No newline at end of file +} From d292401a922b12499aa6f2d9f693e656db7415c0 Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 15 Dec 2022 
13:50:26 +0900 Subject: [PATCH 03/12] update weights.py --- keras_cv/models/weights.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py index 714a4f95f2..f42b97dfeb 100644 --- a/keras_cv/models/weights.py +++ b/keras_cv/models/weights.py @@ -98,6 +98,17 @@ def parse_weights(weights, include_top, model_type): } WEIGHTS_CONFIG = { + "cspdarknet": { + "imagenet/classification-v0": "8bdc3359222f0d26f77aa42c4e97d67a05a1431fe6c448ceeab9a9c5a34ff804", + "imagenet/classification-v0-notop": "9303aabfadffbff8447171fce1e941f96d230d8f3cef30d3f05a9c85097f8f1e", + }, + "darknet53": { + "imagenet/classification-v0": "7bc5589f7f7f7ee3878e61ab9323a71682bfb617eb57f530ca8757c742f00c77", + "imagenet/classification-v0-notop": "8dcce43163e4b4a63e74330ba1902e520211db72d895b0b090b6bfe103e7a8a5", + }, + "deeplabv3": { + "voc/segmentation-v0": "732042e8b6c9ddba3d51c861f26dc41865187e9f85a0e5d43dfef75a405cca18", + }, "densenet121": { "imagenet/classification-v0": "13de3d077ad9d9816b9a0acc78215201d9b6e216c7ed8e71d69cc914f8f0775b", "imagenet/classification-v0-notop": "709afe0321d9f2b2562e562ff9d0dc44cca10ed09e0e2cfba08d783ff4dab6bf", }, From 17056114a041bc5df5147fbf47eb95a71a64f32b Mon Sep 17 00:00:00 2001 From: David Landup Date: Sat, 17 Dec 2022 10:29:24 +0900 Subject: [PATCH 04/12] update --- keras_cv/models/vit.py | 4 ++- keras_cv/models/weights.py | 55 +++++++++++++++++++++++--------------- 2 files changed, 36 insertions(+), 23 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index e702400dc2..44dd60c987 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -147,7 +147,9 @@ specified if `include_top` is True. weights: one of `None` (random initialization), a pretrained weight file path, or a reference to pre-trained weights (e.g. 'imagenet/classification') - (see available pre-trained weights in weights.py) + (see available pre-trained weights in weights.py). Note that the 'imagenet' + weights only work with an input shape of (224, 224, 3), because the patching and + flattening logic depends on the input shape. input_shape: optional shape tuple, defaults to (None, None, 3). input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to use as image input for the model.
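To make the shape constraint above concrete, a minimal usage sketch (assuming the `keras_cv.models.ViTTiny16` constructor wired up in this series and the `classes` argument described in the surrounding docstring):

    import keras_cv

    # The 'imagenet' weights are tied to a (224, 224, 3) input: PatchingAndEmbedding
    # produces a fixed number of patches per input size, so the pretrained positional
    # embeddings only match this shape.
    model = keras_cv.models.ViTTiny16(
        include_rescaling=True,      # accepts raw [0..255] images
        include_top=True,
        classes=1000,                # ImageNet-1k head, per the classification weights
        weights="imagenet",          # resolved by parse_weights() to a hosted file
        input_shape=(224, 224, 3),
    )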
diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py index f42b97dfeb..24e299cfab 100644 --- a/keras_cv/models/weights.py +++ b/keras_cv/models/weights.py @@ -39,6 +39,17 @@ def parse_weights(weights, include_top, model_type): BASE_PATH = "https://storage.googleapis.com/keras-cv/models" ALIASES = { + "cspdarknet": { + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", + }, + "darknet53": { + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", + }, + "deeplabv3": { + "voc": "voc/segmentation-v0", + }, "densenet121": { "imagenet": "imagenet/classification-v0", "imagenet/classification": "imagenet/classification-v0", @@ -72,28 +83,28 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification": "imagenet/classification-v2", }, "vittiny16": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, "vits16": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, "vitb16": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, "vitl16": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, "vits32": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, "vitb32": { - "imagenet": "imagenet/classification-v2", - "imagenet/classification": "imagenet/classification-v2", + "imagenet": "imagenet/classification-v0", + "imagenet/classification": "imagenet/classification-v0", }, } @@ -146,27 +157,27 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification-v2-notop": "e711c83d6db7034871f6d345a476c8184eab99dbf3ffcec0c1d8445684890ad9", }, "vittiny16": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": "c8227fde16ec8c2e7ab886169b11b4f0ca9af2696df6d16767db20acc9f6e0dd", + "imagenet/classification-v0-notop": "aa4d727e3c6bd30b20f49d3fa294fb4bbef97365c7dcb5cee9c527e4e83c8f5b", }, "vits16": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": "4a66a1a70a879ff33a3ca6ca30633b9eadafea84b421c92174557eee83e088b5", + "imagenet/classification-v0-notop": "8d0111eda6692096676a5453abfec5d04c79e2de184b04627b295f10b1949745", }, "vitb16": { "imagenet/classification-v0": "", "imagenet/classification-v0-notop": "", }, "vitl16": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": "5a98000f848f2e813ea896b2528983d8d956f8c4b76ceed0b656219d5b34f7fb", + "imagenet/classification-v0-notop": "40d237c44f14d20337266fce6192c00c2f9b890a463fd7f4cb17e8e35b3f5448", }, "vits32": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": 
"f5836e3aff2bab202eaee01d98337a08258159d3b718e0421834e98b3665e10a", + "imagenet/classification-v0-notop": "f3907845eff780a4d29c1c56e0ae053411f02fff6fdce1147c4c3bb2124698cd", }, "vitb32": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": "73025caa78459dc8f9b1de7b58f1d64e24a823f170d17e25fcc8eb6179bea179", + "imagenet/classification-v0-notop": "f07b80c03336d731a2a3a02af5cac1e9fc9aa62659cd29e2e7e5c7474150cc71", }, } From 401d3bc1c8ef4d34aa1a028076c363e094e7bc54 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 20 Dec 2022 12:12:10 +0900 Subject: [PATCH 05/12] b16 --- keras_cv/models/weights.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/weights.py b/keras_cv/models/weights.py index 24e299cfab..ba20beaa59 100644 --- a/keras_cv/models/weights.py +++ b/keras_cv/models/weights.py @@ -165,8 +165,8 @@ def parse_weights(weights, include_top, model_type): "imagenet/classification-v0-notop": "8d0111eda6692096676a5453abfec5d04c79e2de184b04627b295f10b1949745", }, "vitb16": { - "imagenet/classification-v0": "", - "imagenet/classification-v0-notop": "", + "imagenet/classification-v0": "6ab4e08c773e08de42023d963a97e905ccba710e2c05ef60c0971978d4a8c41b", + "imagenet/classification-v0-notop": "4a1bdd32889298471cb4f30882632e5744fd519bf1a1525b1fa312fe4ea775ed", }, "vitl16": { "imagenet/classification-v0": "5a98000f848f2e813ea896b2528983d8d956f8c4b76ceed0b656219d5b34f7fb", From 18c046124c9872fb73a26ce4a5c4a09d031ca3f3 Mon Sep 17 00:00:00 2001 From: David Landup Date: Tue, 20 Dec 2022 14:26:43 +0900 Subject: [PATCH 06/12] update benchmarking script --- examples/benchmarking/imagenet_v2.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/benchmarking/imagenet_v2.py b/examples/benchmarking/imagenet_v2.py index d2308960b7..15123d15d5 100644 --- a/examples/benchmarking/imagenet_v2.py +++ b/examples/benchmarking/imagenet_v2.py @@ -98,6 +98,12 @@ def preprocess_image(img, label): # model size, etc. loss, acc, top_5 = model.evaluate(test_set, verbose=0) print( - f"{FLAGS.model_name} achieves {acc} Top-1 Accuracy and {top_5} Top-5 Accuracy on ImageNetV2 with setup:" + f"{FLAGS.model_name} achieves: \n Top-1 Accuracy: {acc*100} \n Top-5 Accuracy: {top_5*100} \n on ImageNetV2 with setup:" +) +print( + f"model_name: {FLAGS.model_name}\n" + f"include_rescaling: {FLAGS.include_rescaling}\n" + f"batch_size: {FLAGS.batch_size}\n" + f"weights: {FLAGS.weights}\n" + f"model_kwargs: {FLAGS.model_kwargs}\n" ) -print(FLAGS) From 3dd59471d860778a8de638d19e7787f41a5b23b7 Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 22 Dec 2022 11:56:28 +0900 Subject: [PATCH 07/12] tuned rescaling --- keras_cv/models/vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 44dd60c987..801f4e3edb 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -140,7 +140,7 @@ learning & fine-tuning](https://keras.io/guides/transfer_learning/). Args: include_rescaling: whether or not to Rescale the inputs.If set to True, - inputs will be passed through a `Rescaling(1/255.0)` layer. + inputs will be passed through a `Rescaling(scale=1./127.5, offset=-1)` layer. include_top: whether to include the fully-connected layer at the top of the network. If provided, classes must be provided. 
classes: optional number of classes to classify images into, only to be @@ -268,7 +268,7 @@ def ViT( x = inputs if include_rescaling: - x = layers.Rescaling(1 / 255.0)(x) + x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(x) encoded_patches = PatchingAndEmbedding(project_dim, patch_size)(x) encoded_patches = layers.Dropout(mlp_dropout)(encoded_patches) From f5fd1cab3167f0e3ad6655539956072aecfbe31e Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 22 Dec 2022 12:09:26 +0900 Subject: [PATCH 08/12] improved formatting for benchmark script --- examples/benchmarking/imagenet_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmarking/imagenet_v2.py b/examples/benchmarking/imagenet_v2.py index 15123d15d5..c53d85ae20 100644 --- a/examples/benchmarking/imagenet_v2.py +++ b/examples/benchmarking/imagenet_v2.py @@ -98,7 +98,7 @@ def preprocess_image(img, label): # model size, etc. loss, acc, top_5 = model.evaluate(test_set, verbose=0) print( - f"{FLAGS.model_name} achieves: \n Top-1 Accuracy: {acc*100} \n Top-5 Accuracy: {top_5*100} \n on ImageNetV2 with setup:" + f"Benchmark results:\n{'='*25}\n{FLAGS.model_name} achieves: \n - Top-1 Accuracy: {acc*100} \n - Top-5 Accuracy: {top_5*100} \non ImageNetV2 with setup:" ) print( f"model_name: {FLAGS.model_name}\n" From daeb40431d6fc1960a524efcf4b3247a60643ec2 Mon Sep 17 00:00:00 2001 From: David Landup Date: Thu, 22 Dec 2022 12:10:27 +0900 Subject: [PATCH 09/12] improved formatting for benchmark script --- examples/benchmarking/imagenet_v2.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/benchmarking/imagenet_v2.py b/examples/benchmarking/imagenet_v2.py index c53d85ae20..865c74340b 100644 --- a/examples/benchmarking/imagenet_v2.py +++ b/examples/benchmarking/imagenet_v2.py @@ -101,9 +101,9 @@ def preprocess_image(img, label): f"Benchmark results:\n{'='*25}\n{FLAGS.model_name} achieves: \n - Top-1 Accuracy: {acc*100} \n - Top-5 Accuracy: {top_5*100} \non ImageNetV2 with setup:" ) print( - f"model_name: {FLAGS.model_name}\n" - f"include_rescaling: {FLAGS.include_rescaling}\n" - f"batch_size: {FLAGS.batch_size}\n" - f"weights: {FLAGS.weights}\n" - f"model_kwargs: {FLAGS.model_kwargs}\n" + f"- model_name: {FLAGS.model_name}\n" + f"- include_rescaling: {FLAGS.include_rescaling}\n" + f"- batch_size: {FLAGS.batch_size}\n" + f"- weights: {FLAGS.weights}\n" + f"- model_kwargs: {FLAGS.model_kwargs}\n" ) From f85387da39db239250489cf73125454640fafeec Mon Sep 17 00:00:00 2001 From: David Landup Date: Fri, 23 Dec 2022 08:19:12 +0900 Subject: [PATCH 10/12] standardized rescaling --- keras_cv/models/vit.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 801f4e3edb..1cb58988fe 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -268,7 +268,11 @@ def ViT( x = inputs if include_rescaling: - x = layers.Rescaling(scale=1.0 / 127.5, offset=-1)(x) + x = layers.Rescaling(1.0 / 255.0)(x) + + # The previous layer rescales [0..255] to [0..1] if applicable + # This one rescales [0..1] to [-1..1] since ViTs expect [-1..1] + x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x) encoded_patches = PatchingAndEmbedding(project_dim, patch_size)(x) encoded_patches = layers.Dropout(mlp_dropout)(encoded_patches) From 61fa575e9509d25bf0b4b2d4d473540fec9a17da Mon Sep 17 00:00:00 2001 From: David Landup Date: Fri, 23 Dec 2022 08:28:32 +0900 Subject: [PATCH 11/12] updated docs --- keras_cv/models/vit.py | 7 +++++-- 1 file 
changed, 5 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 1cb58988fe..054ed1d759 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -139,8 +139,11 @@ For transfer learning use cases, make sure to read the [guide to transfer learning & fine-tuning](https://keras.io/guides/transfer_learning/). Args: - include_rescaling: whether or not to Rescale the inputs.If set to True, - inputs will be passed through a `Rescaling(scale=1./127.5, offset=-1)` layer. + include_rescaling: whether or not to rescale the inputs. If set to True, + inputs will be passed through a `Rescaling(scale=1./255.0)` layer. Note that ViTs + expect an input range of `[0..1]` if rescaling isn't used. Whether you supply + inputs in `[0..1]` directly or they are rescaled here, the inputs are then + further rescaled to `[-1..1]`. include_top: whether to include the fully-connected layer at the top of the network. If provided, classes must be provided. classes: optional number of classes to classify images into, only to be specified if `include_top` is True. From 2894dac81b02a404284ce0cb99b458960f9ad725 Mon Sep 17 00:00:00 2001 From: David Landup Date: Wed, 28 Dec 2022 11:28:56 +0900 Subject: [PATCH 12/12] layer names --- keras_cv/models/vit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras_cv/models/vit.py b/keras_cv/models/vit.py index 054ed1d759..24e36f5952 100644 --- a/keras_cv/models/vit.py +++ b/keras_cv/models/vit.py @@ -271,11 +271,11 @@ def ViT( x = inputs if include_rescaling: - x = layers.Rescaling(1.0 / 255.0)(x) + x = layers.Rescaling(1.0 / 255.0, name="rescaling")(x) # The previous layer rescales [0..255] to [0..1] if applicable # This one rescales [0..1] to [-1..1] since ViTs expect [-1..1] - x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0)(x) + x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0, name="rescaling_2")(x) encoded_patches = PatchingAndEmbedding(project_dim, patch_size)(x) encoded_patches = layers.Dropout(mlp_dropout)(encoded_patches)
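As a quick sanity check on the two rescaling stages above, a standalone sketch using `tf.keras` layers (not part of the patch itself):

    import numpy as np
    from tensorflow.keras import layers

    x = np.array([[0.0, 127.5, 255.0]], dtype="float32")
    x = layers.Rescaling(1.0 / 255.0, name="rescaling")(x)                     # [0..255] -> [0..1]
    x = layers.Rescaling(scale=1.0 / 0.5, offset=-1.0, name="rescaling_2")(x)  # [0..1] -> [-1..1]
    print(x.numpy())  # [[-1.  0.  1.]]

The composed transform is equivalent to the single `Rescaling(scale=1./127.5, offset=-1)` layer from patch 07; splitting it in two keeps the `[0..1]` intermediate range available when `include_rescaling=False`.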