From dc70a698b7024ddcb3cdaf62431099be43c2a6e2 Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Tue, 24 Jan 2023 19:36:01 -0500 Subject: [PATCH] Implement RepeatedAugmentation as a KerasCV API (#1293) * Implement RepeatedAugmentation as a KerasCV API more reading and fixes https://github.com/keras-team/keras-cv/issues/372 * add test case * fix formatting * fix formatting * fix formatting * fix serialization test * add repeated augmentation usage docstring * Update component for repeated augment * Repeated augmentations fix * Test MixUp explicitly * update docstring * update docstring * Reformat * keras_cv/layers/preprocessing/repeated_augmentation.py --- keras_cv/layers/__init__.py | 1 + keras_cv/layers/preprocessing/__init__.py | 1 + .../preprocessing/repeated_augmentation.py | 113 ++++++++++++++++++ .../repeated_augmentation_test.py | 50 ++++++++ keras_cv/layers/serialization_test.py | 10 ++ 5 files changed, 175 insertions(+) create mode 100644 keras_cv/layers/preprocessing/repeated_augmentation.py create mode 100644 keras_cv/layers/preprocessing/repeated_augmentation_test.py diff --git a/keras_cv/layers/__init__.py b/keras_cv/layers/__init__.py index 0350616c56..910cfd9373 100644 --- a/keras_cv/layers/__init__.py +++ b/keras_cv/layers/__init__.py @@ -71,6 +71,7 @@ from keras_cv.layers.preprocessing.random_shear import RandomShear from keras_cv.layers.preprocessing.random_zoom import RandomZoom from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop +from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation from keras_cv.layers.preprocessing.rescaling import Rescaling from keras_cv.layers.preprocessing.resizing import Resizing from keras_cv.layers.preprocessing.solarization import Solarization diff --git a/keras_cv/layers/preprocessing/__init__.py b/keras_cv/layers/preprocessing/__init__.py index 9a060ec359..2c25bd815c 100644 --- a/keras_cv/layers/preprocessing/__init__.py +++ 
b/keras_cv/layers/preprocessing/__init__.py @@ -61,6 +61,7 @@ from keras_cv.layers.preprocessing.random_shear import RandomShear from keras_cv.layers.preprocessing.random_zoom import RandomZoom from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop +from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation from keras_cv.layers.preprocessing.rescaling import Rescaling from keras_cv.layers.preprocessing.resizing import Resizing from keras_cv.layers.preprocessing.solarization import Solarization diff --git a/keras_cv/layers/preprocessing/repeated_augmentation.py b/keras_cv/layers/preprocessing/repeated_augmentation.py new file mode 100644 index 0000000000..25fefc727a --- /dev/null +++ b/keras_cv/layers/preprocessing/repeated_augmentation.py @@ -0,0 +1,113 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tensorflow as tf + +from keras_cv.layers.preprocessing.base_image_augmentation_layer import ( + BaseImageAugmentationLayer, +) + + +@tf.keras.utils.register_keras_serializable(package="keras_cv") +class RepeatedAugmentation(BaseImageAugmentationLayer): + """RepeatedAugmentation augments each image in a batch multiple times. + + This technique exists to emulate the behavior of stochastic gradient descent within + the context of mini-batch gradient descent. 
When training large vision models, + choosing a large batch size can introduce too much noise into aggregated gradients + causing the overall batch's gradients to be less effective than gradients produced + using smaller batches. RepeatedAugmentation handles this by re-using the same + image multiple times within a batch creating correlated samples. + + This layer increases your batch size by a factor of `len(augmenters)`. + + Args: + augmenters: the augmenters to use to augment the image + shuffle: whether or not to shuffle the result. Essential when using an + asynchronous distribution strategy such as ParameterServerStrategy. + + Usage: + + List of identical augmenters: + ```python + repeated_augment = cv_layers.RepeatedAugmentation( + augmenters=[cv_layers.RandAugment(value_range=(0, 255))] * 8 + ) + inputs = { + "images": tf.ones((8, 512, 512, 3)), + "labels": tf.ones((8,)), + } + outputs = repeated_augment(inputs) + # outputs now has a batch size of 64 because there are 8 augmenters + ``` + + List of distinct augmenters: + ```python + repeated_augment = cv_layers.RepeatedAugmentation( + augmenters=[ + cv_layers.RandAugment(value_range=(0, 255)), + cv_layers.RandomFlip(), + ] + ) + inputs = { + "images": tf.ones((8, 512, 512, 3)), + "labels": tf.ones((8,)), + } + outputs = repeated_augment(inputs) + ``` + + References: + - [DEIT implementation](https://github.com/facebookresearch/deit/blob/ee8893c8063f6937fec7096e47ba324c206e22b9/samplers.py#L8) + - [Original publication](https://openaccess.thecvf.com/content_CVPR_2020/papers/Hoffer_Augment_Your_Batch_Improving_Generalization_Through_Instance_Repetition_CVPR_2020_paper.pdf) + + """ + + def __init__(self, augmenters, shuffle=True, **kwargs): + super().__init__(**kwargs) + self.augmenters = augmenters + self.shuffle = shuffle + + def _batch_augment(self, inputs): + if "bounding_boxes" in inputs: + raise ValueError( + "RepeatedAugmentation() does not yet support bounding box labels." 
+ ) + + augmenter_outputs = [augmenter(inputs) for augmenter in self.augmenters] + + outputs = {} + for k in inputs.keys(): + outputs[k] = tf.concat([output[k] for output in augmenter_outputs], axis=0) + + if not self.shuffle: + return outputs + return self.shuffle_outputs(outputs) + + def shuffle_outputs(self, result): + indices = tf.range(start=0, limit=tf.shape(result["images"])[0], dtype=tf.int32) + indices = tf.random.shuffle(indices) + for key in result: + result[key] = tf.gather(result[key], indices) + return result + + def _augment(self, inputs): + raise ValueError( + "RepeatedAugmentation() only works in batched mode. If " + "you would like to create batches from a single image, use " + "`x = tf.expand_dims(x, axis=0)` on your input images and labels." + ) + + def get_config(self): + config = super().get_config() + config.update({"augmenters": self.augmenters, "shuffle": self.shuffle}) + return config diff --git a/keras_cv/layers/preprocessing/repeated_augmentation_test.py b/keras_cv/layers/preprocessing/repeated_augmentation_test.py new file mode 100644 index 0000000000..120e1a8afa --- /dev/null +++ b/keras_cv/layers/preprocessing/repeated_augmentation_test.py @@ -0,0 +1,50 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import tensorflow as tf + +import keras_cv.layers as cv_layers + + +class RepeatedAugmentationTest(tf.test.TestCase): + def test_output_shapes(self): + repeated_augment = cv_layers.RepeatedAugmentation( + augmenters=[ + cv_layers.RandAugment(value_range=(0, 255)), + cv_layers.RandomFlip(), + ] + ) + inputs = { + "images": tf.ones((8, 512, 512, 3)), + "labels": tf.ones((8,)), + } + outputs = repeated_augment(inputs) + + self.assertEqual(outputs["images"].shape, (16, 512, 512, 3)) + self.assertEqual(outputs["labels"].shape, (16,)) + + def test_with_mix_up(self): + repeated_augment = cv_layers.RepeatedAugmentation( + augmenters=[ + cv_layers.RandAugment(value_range=(0, 255)), + cv_layers.MixUp(), + ] + ) + inputs = { + "images": tf.ones((8, 512, 512, 3)), + "labels": tf.ones((8, 10)), + } + outputs = repeated_augment(inputs) + + self.assertEqual(outputs["images"].shape, (16, 512, 512, 3)) + self.assertEqual(outputs["labels"].shape, (16, 10)) diff --git a/keras_cv/layers/serialization_test.py b/keras_cv/layers/serialization_test.py index e60d40822c..7640a8ba51 100644 --- a/keras_cv/layers/serialization_test.py +++ b/keras_cv/layers/serialization_test.py @@ -96,6 +96,16 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase): ("GridMask", cv_layers.GridMask, {"seed": 1}), ("MixUp", cv_layers.MixUp, {"seed": 1}), ("Mosaic", cv_layers.Mosaic, {"seed": 1}), + ( + "RepeatedAugmentation", + cv_layers.RepeatedAugmentation, + { + "augmenters": [ + cv_layers.RandAugment(value_range=(0, 1)), + cv_layers.RandomFlip(), + ] + }, + ), ( "RandomChannelShift", cv_layers.RandomChannelShift,