Skip to content

Commit

Permalink
Implement RepeatedAugmentation as a KerasCV API (keras-team#1293)
Browse files Browse the repository at this point in the history
* Implement RepeatedAugmentation as a KerasCV API

more reading and fixes keras-team#372

* add test case

* fix formatting

* fix formatting

* fix formatting

* fix serialization test

* add repeated augmentation usage docstring

* Update component for repeated augment

* Repeated augmentations fix

* Test MixUp explicitly

* update docstring

* update docstring

* Reformat

* keras_cv/layers/preprocessing/repeated_augmentation.py
  • Loading branch information
LukeWood authored Jan 25, 2023
1 parent 42aafee commit dc70a69
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from keras_cv.layers.preprocessing.random_shear import RandomShear
from keras_cv.layers.preprocessing.random_zoom import RandomZoom
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation
from keras_cv.layers.preprocessing.rescaling import Rescaling
from keras_cv.layers.preprocessing.resizing import Resizing
from keras_cv.layers.preprocessing.solarization import Solarization
Expand Down
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from keras_cv.layers.preprocessing.random_shear import RandomShear
from keras_cv.layers.preprocessing.random_zoom import RandomZoom
from keras_cv.layers.preprocessing.randomly_zoomed_crop import RandomlyZoomedCrop
from keras_cv.layers.preprocessing.repeated_augmentation import RepeatedAugmentation
from keras_cv.layers.preprocessing.rescaling import Rescaling
from keras_cv.layers.preprocessing.resizing import Resizing
from keras_cv.layers.preprocessing.solarization import Solarization
113 changes: 113 additions & 0 deletions keras_cv/layers/preprocessing/repeated_augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
BaseImageAugmentationLayer,
)


@tf.keras.utils.register_keras_serializable(package="keras_cv")
class RepeatedAugmentation(BaseImageAugmentationLayer):
    """RepeatedAugmentation augments each image in a batch multiple times.

    This technique exists to emulate the behavior of stochastic gradient
    descent within the context of mini-batch gradient descent. When training
    large vision models, choosing a large batch size can introduce too much
    noise into aggregated gradients causing the overall batch's gradients to
    be less effective than gradients produced using smaller batches.
    RepeatedAugmentation handles this by re-using the same image multiple
    times within a batch creating correlated samples.

    This layer increases your batch size by a factor of `len(augmenters)`.

    Args:
        augmenters: non-empty list of augmenters to use to augment the
            image. Every augmenter is called on the full input batch and the
            results are concatenated along the batch axis.
        shuffle: whether or not to shuffle the result. Essential when using
            an asynchronous distribution strategy such as
            ParameterServerStrategy.
        seed: optional integer. Forwarded to the base layer constructor and
            used as the op-level seed of the `tf.random.shuffle` call that
            shuffles the concatenated batch. Defaults to `None`.

    Raises:
        ValueError: if `augmenters` is empty.

    Usage:

    List of identical augmenters:
    ```python
    repeated_augment = cv_layers.RepeatedAugmentation(
        augmenters=[cv_layers.RandAugment(value_range=(0, 255))] * 8
    )
    inputs = {
        "images": tf.ones((8, 512, 512, 3)),
        "labels": tf.ones((8,)),
    }
    outputs = repeated_augment(inputs)
    # outputs now has a batch size of 64 because there are 8 augmenters
    ```

    List of distinct augmenters:
    ```python
    repeated_augment = cv_layers.RepeatedAugmentation(
        augmenters=[
            cv_layers.RandAugment(value_range=(0, 255)),
            cv_layers.RandomFlip(),
        ]
    )
    inputs = {
        "images": tf.ones((8, 512, 512, 3)),
        "labels": tf.ones((8,)),
    }
    outputs = repeated_augment(inputs)
    ```

    References:
        - [DEIT implementation](https://github.com/facebookresearch/deit/blob/ee8893c8063f6937fec7096e47ba324c206e22b9/samplers.py#L8)
        - [Original publication](https://openaccess.thecvf.com/content_CVPR_2020/papers/Hoffer_Augment_Your_Batch_Improving_Generalization_Through_Instance_Repetition_CVPR_2020_paper.pdf)
    """

    def __init__(self, augmenters, shuffle=True, seed=None, **kwargs):
        # `seed` used to reach the base constructor through **kwargs; keep
        # forwarding it so existing callers see the same base-layer setup.
        # NOTE(review): assumes the base constructor accepts `seed` - confirm.
        super().__init__(seed=seed, **kwargs)
        if not augmenters:
            # Fail here with a clear message instead of letting `tf.concat`
            # fail on an empty list the first time the layer is called.
            raise ValueError(
                "RepeatedAugmentation() requires at least one augmenter, "
                f"received augmenters={augmenters}."
            )
        self.augmenters = augmenters
        self.shuffle = shuffle
        self.seed = seed

    def _batch_augment(self, inputs):
        """Run every augmenter on `inputs` and concatenate the results.

        Args:
            inputs: dictionary of batched inputs (the usage examples pass
                `"images"` and `"labels"`).

        Returns:
            A dictionary with the same keys as `inputs`, where each value is
            the concatenation along axis 0 of the corresponding outputs of
            all augmenters; shuffled along axis 0 when `self.shuffle` is set.

        Raises:
            ValueError: if `inputs` contains a `"bounding_boxes"` key.
        """
        if "bounding_boxes" in inputs:
            raise ValueError(
                "RepeatedAugmentation() does not yet support bounding box labels."
            )

        augmenter_outputs = [augmenter(inputs) for augmenter in self.augmenters]

        outputs = {
            key: tf.concat([output[key] for output in augmenter_outputs], axis=0)
            for key in inputs
        }

        if not self.shuffle:
            return outputs
        return self.shuffle_outputs(outputs)

    def shuffle_outputs(self, result):
        """Apply one random permutation along axis 0 to every entry of `result`.

        The same permutation is used for every key so that entries that
        belong together (e.g. an image and its label) stay aligned. `result`
        is updated in place and also returned.
        """
        indices = tf.range(
            start=0, limit=tf.shape(result["images"])[0], dtype=tf.int32
        )
        indices = tf.random.shuffle(indices, seed=self.seed)
        for key in result:
            result[key] = tf.gather(result[key], indices)
        return result

    def _augment(self, inputs):
        """Reject unbatched inputs; this layer only defines batched behavior."""
        raise ValueError(
            "RepeatedAugmentation() only works in batched mode. If "
            "you would like to create batches from a single image, use "
            "`x = tf.expand_dims(x, axis=0)` on your input images and labels."
        )

    def get_config(self):
        """Return the layer config, including every constructor argument."""
        config = super().get_config()
        config.update(
            {
                "augmenters": self.augmenters,
                "shuffle": self.shuffle,
                "seed": self.seed,
            }
        )
        return config
50 changes: 50 additions & 0 deletions keras_cv/layers/preprocessing/repeated_augmentation_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf

import keras_cv.layers as cv_layers


class RepeatedAugmentationTest(tf.test.TestCase):
    """Output-shape checks for `cv_layers.RepeatedAugmentation`."""

    def test_output_shapes(self):
        """Two augmenters over a batch of 8 produce a batch of 16."""
        augmenters = [
            cv_layers.RandAugment(value_range=(0, 255)),
            cv_layers.RandomFlip(),
        ]
        layer = cv_layers.RepeatedAugmentation(augmenters=augmenters)
        batch = {"images": tf.ones((8, 512, 512, 3)), "labels": tf.ones((8,))}

        result = layer(batch)

        self.assertEqual(result["images"].shape, (16, 512, 512, 3))
        self.assertEqual(result["labels"].shape, (16,))

    def test_with_mix_up(self):
        """MixUp can be one of the augmenters; labels here are rank-2."""
        augmenters = [
            cv_layers.RandAugment(value_range=(0, 255)),
            cv_layers.MixUp(),
        ]
        layer = cv_layers.RepeatedAugmentation(augmenters=augmenters)
        batch = {"images": tf.ones((8, 512, 512, 3)), "labels": tf.ones((8, 10))}

        result = layer(batch)

        self.assertEqual(result["images"].shape, (16, 512, 512, 3))
        self.assertEqual(result["labels"].shape, (16, 10))
10 changes: 10 additions & 0 deletions keras_cv/layers/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,16 @@ class SerializationTest(tf.test.TestCase, parameterized.TestCase):
("GridMask", cv_layers.GridMask, {"seed": 1}),
("MixUp", cv_layers.MixUp, {"seed": 1}),
("Mosaic", cv_layers.Mosaic, {"seed": 1}),
(
"RepeatedAugmentation",
cv_layers.RepeatedAugmentation,
{
"augmenters": [
cv_layers.RandAugment(value_range=(0, 1)),
cv_layers.RandomFlip(),
]
},
),
(
"RandomChannelShift",
cv_layers.RandomChannelShift,
Expand Down

0 comments on commit dc70a69

Please sign in to comment.