Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduces Posterization preprocessing layer. #136

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras_cv/layers/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from keras_cv.layers.preprocessing.grayscale import Grayscale
from keras_cv.layers.preprocessing.grid_mask import GridMask
from keras_cv.layers.preprocessing.mix_up import MixUp
from keras_cv.layers.preprocessing.posterization import Posterization
from keras_cv.layers.preprocessing.random_cutout import RandomCutout
from keras_cv.layers.preprocessing.random_shear import RandomShear
from keras_cv.layers.preprocessing.solarization import Solarization
103 changes: 103 additions & 0 deletions keras_cv/layers/preprocessing/posterization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from tensorflow.keras.__internal__.layers import BaseImageAugmentationLayer

from keras_cv.utils.preprocessing import transform_value_range


class Posterization(BaseImageAugmentationLayer):
"""Reduces the number of bits for each color channel.

References:
- [AutoAugment: Learning Augmentation Policies from Data](
https://arxiv.org/abs/1805.09501
)
- [RandAugment: Practical automated data augmentation with a reduced search space](
https://arxiv.org/abs/1909.13719
)

Args:
bits: integer. The number of bits to keep for each channel. Must be a value
between 1-8.
value_range: a tuple or a list of two elements. The first value represents
the lower bound for values in passed images, the second represents the
upper bound. Images passed to the layer should have values within
`value_range`. Defaults to `(0, 255)`.

Usage:
```python
(images, labels), _ = tf.keras.datasets.cifar10.load_data()
print(images[0, 0, 0])
# [59 62 63]
# Note that images are Tensors with values in the range [0, 255] and uint8 dtype
posterization = Posterization(bits=4, value_range=[0, 255])
images = posterization(images)
print(images[0, 0, 0])
# [48., 48., 48.]
# NOTE: the layer will output values in tf.float32, regardless of input dtype.
```

Call arguments:
inputs: input tensor in two possible formats:
1. single 3D (HWC) image or 4D (NHWC) batch of images.
2. A dict of tensors where the images are under `"images"` key.
"""

def __init__(self, bits: int, value_range=(0, 255), **kwargs):
super().__init__(**kwargs)

if not len(value_range) == 2:
raise ValueError(
"value_range must be a sequence of two elements. "
f"Received: {value_range}"
)

if not (0 < bits < 9):
raise ValueError(f"Bits value must be between 1-8. Received bits: {bits}.")

self._shift = 8 - bits
self._value_range = value_range

def augment_image(self, image, transformation=None):
image = transform_value_range(
images=image,
original_range=self._value_range,
target_range=[0, 255],
)
image = tf.cast(image, tf.uint8)

image = self._posterize(image)

image = tf.cast(image, self.compute_dtype)
return transform_value_range(
images=image,
original_range=[0, 255],
target_range=self._value_range,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets cast back to compute_dtype before transform_value_range. That way we are certain it is the correct dtype. transform_value_range can be skipped if value_range is [0, 255]

)

def _batch_augment(self, inputs):
# Skip the use of vectorized_map or map_fn as the implementation is already
# vectorized
return self._augment(inputs)

def _posterize(self, image):
return tf.bitwise.left_shift(
tf.bitwise.right_shift(image, self._shift), self._shift
)

def get_config(self):
config = {"bits": 8 - self.shift, "value_range": self._value_range}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
101 changes: 101 additions & 0 deletions keras_cv/layers/preprocessing/posterization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import tensorflow as tf

from keras_cv.layers.preprocessing.posterization import Posterization


class PosterizationTest(tf.test.TestCase):
rng = tf.random.Generator.from_non_deterministic_state()

def test_raises_error_on_invalid_bits_parameter(self):
invalid_values = [-1, 0, 9, 24]
for value in invalid_values:
with self.assertRaises(ValueError):
Posterization(bits=value, value_range=[0, 1])

def test_raises_error_on_invalid_value_range(self):
invalid_ranges = [(1,), [1, 2, 3]]
for value_range in invalid_ranges:
with self.assertRaises(ValueError):
Posterization(bits=1, value_range=value_range)

def test_single_image(self):
bits = self._get_random_bits()
dummy_input = self.rng.uniform(shape=(224, 224, 3), maxval=256)
expected_output = self._calc_expected_output(dummy_input, bits=bits)

layer = Posterization(bits=bits, value_range=[0, 255])
output = layer(dummy_input)

self.assertAllEqual(output, expected_output)

def _get_random_bits(self):
return int(self.rng.uniform(shape=(), minval=1, maxval=9, dtype=tf.int32))

def test_single_image_rescaled(self):
bits = self._get_random_bits()
dummy_input = self.rng.uniform(shape=(224, 224, 3), maxval=1.0)
expected_output = self._calc_expected_output(dummy_input * 255, bits=bits) / 255

layer = Posterization(bits=bits, value_range=[0, 1])
output = layer(dummy_input)

self.assertAllClose(output, expected_output)

def test_batched_input(self):
bits = self._get_random_bits()
dummy_input = self.rng.uniform(shape=(2, 224, 224, 3), maxval=256)

expected_output = []
for image in dummy_input:
expected_output.append(self._calc_expected_output(image, bits=bits))
expected_output = tf.stack(expected_output)

layer = Posterization(bits=bits, value_range=[0, 255])
output = layer(dummy_input)

self.assertAllEqual(output, expected_output)

def test_works_with_xla(self):
dummy_input = self.rng.uniform(shape=(2, 224, 224, 3))
layer = Posterization(bits=4, value_range=[0, 1])

@tf.function(jit_compile=True)
def apply(x):
return layer(x)

apply(dummy_input)

@staticmethod
def _calc_expected_output(image, bits):
"""Posterization in numpy, based on Albumentations:

The algorithm is basically:
1. create a lookup table of all possible input pixel values to pixel values
after posterize
2. map each pixel in the input to created lookup table.

Source:
https://github.com/albumentations-team/albumentations/blob/89a675cbfb2b76f6be90e7049cd5211cb08169a5/albumentations/augmentations/functional.py#L407
"""
dtype = image.dtype
image = tf.cast(image, tf.uint8)

lookup_table = np.arange(0, 256, dtype=np.uint8)
mask = ~np.uint8(2 ** (8 - bits) - 1)
lookup_table &= mask

return tf.cast(lookup_table[image], dtype)