
Introduces Posterization preprocessing layer. #136

Conversation

sebastian-sz
Contributor

Posterization is one of the operations used by AutoAugment and RandAugment.

Linked Issue


self.assertEqual(output.shape, dummy_input.shape)

def test_output_dtype_unchanged(self):

Probably the same applies to this one as well.

[batch, height, width, channels] or [height, width, channels].
"""

def __init__(self, bits: int):

Hey there - please include value_range as an argument and scale images up to [0, 255] according to value_range, then scale back down to value_range before returning. Thanks.
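For illustration, a framework-agnostic sketch of the requested behavior, using NumPy as a stand-in (the function name and exact scaling are hypothetical, not the layer's actual implementation):

```python
import numpy as np

def posterize(images, bits, value_range=(0, 255)):
    """Keep only the top `bits` bits of each channel value.

    Hypothetical sketch of the reviewer's request: rescale from
    `value_range` to [0, 255], posterize, then rescale back.
    """
    low, high = value_range
    scaled = (np.asarray(images, dtype=np.float64) - low) / (high - low) * 255.0
    shift = 8 - bits
    # Dropping the low-order bits is the posterization step itself.
    quantized = (scaled.astype(np.uint8) >> shift) << shift
    return quantized.astype(np.float64) / 255.0 * (high - low) + low
```

With bits=1 and value_range=(0, 1), every pixel collapses to one of two levels, 0 or 128/255, which is the expected posterization behavior.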

@LukeWood
Contributor

Thank you for the PR. Left a starting comment, will do a thorough review. Please be aware we are doing some major design work around KPLs in KerasCV and the API may shift as a result. Apologies if this delays reviews.

@sebastian-sz
Contributor Author

Thank you for the PR. Left a starting comment, will do a thorough review. Please be aware we are doing some major design work around KPLs in KerasCV and the API may shift as a result. Apologies if this delays reviews.

Hi, no worries. Let me know if you're done with refactoring and design changes, I will update this PR accordingly.

@sebastian-sz
Contributor Author

@LukeWood I wanted to update the layer to subclass BaseImageAugmentationLayer, but it's not available in TF 2.8. Should I wait for 2.9?

@bhack
Contributor

bhack commented Mar 7, 2022

@LukeWood I wanted to update the layer to subclass BaseImageAugmentationLayer, but it's not available in TF 2.8. Should I wait for 2.9?

It seems that we rely on Tensorflow stable in this repo but in Keras repo on master we track TF Nightly:
https://github.com/keras-team/keras-cv/blob/master/setup.py#L27

@LukeWood
Contributor

LukeWood commented Mar 7, 2022

@LukeWood I wanted to update the layer to subclass BaseImageAugmentationLayer, but it's not available in TF 2.8. Should I wait for 2.9?

It seems that we rely on Tensorflow stable in this repo but in Keras repo on master we track TF Nightly: https://github.com/keras-team/keras-cv/blob/master/setup.py#L27

I'll update the repo to rely on nightly until 2.9. We're in pre-release so this is fine for now.

@LukeWood
Contributor

BaseImageAugmentationLayer

Yes, sorry for the delayed response.

@sebastian-sz
Contributor Author

@LukeWood I'm starting to use BaseImageAugmentationLayer. I noticed it force-casts all inputs to tf.float32 - is this intended?

import tensorflow as tf
from tensorflow.keras.__internal__.layers import BaseImageAugmentationLayer


class DummyLayer(BaseImageAugmentationLayer):
    def augment_image(self, image, transformation=None):
        return image


layer = DummyLayer()
# Note: maxval is required when sampling an integer dtype.
image = tf.random.uniform((1, 224, 224, 3), maxval=256, dtype=tf.int32)
output = layer(image)
print(output.dtype)  # tf.float32, but input was tf.int32

@LukeWood
Contributor

The layer will cast your inputs to self.compute_dtype. This is pretty standard for Keras.

@LukeWood
Contributor

Basically @sebastian-sz to ensure consistency across dtypes we will just always assume that the inputs are float, and make no assumptions about their value range.

@LukeWood
Contributor

Hey @sebastian-sz a sample usage of BaseImageAugmentationLayer is available in #180.

Feel free to base yours on this. Upload when a copy is ready and I will review :)

@sebastian-sz
Contributor Author

Thanks @LukeWood I'm almost done with updating the code but I'm still unhappy with how the tests look.

This is in progress and should be done by the end of the week, probably sooner.

@LukeWood
Contributor

Awesome - yeah, feel free to upload as you go. I may begin lending a hand with some of these components as the list of components required for RandAugment shrinks.

@sebastian-sz sebastian-sz force-pushed the feature-84/introduce-posterization-layer branch from 91b83ce to e598294 Compare March 23, 2022 06:35
@sebastian-sz
Contributor Author

@LukeWood I updated the PR. Two questions:

  1. What is the tf.function policy? Should augment_image or any other function be decorated?
  2. Is it OK to also override _batch_augment, if the implementation is already vectorized and does not need to rely on vectorized_map / map_fn? See the code for details; if not, I will remove this part.

@LukeWood
Contributor

@LukeWood I updated the PR. Two questions:

  1. What is the tf.function policy? Should augment_image or any other function be decorated?
  2. Is it OK to also override _batch_augment, if the implementation is already vectorized and does not need to rely on vectorized_map / map_fn? See the code for details; if not, I will remove this part.

1.) no need for tf.function
2.) I think it is actually fine. I will check with @qlzh727 too.

return transform_value_range(
    images=image,
    original_range=[0, 255],
    target_range=self._value_range,
)

Let's cast back to compute_dtype before transform_value_range; that way we are certain it is the correct dtype. transform_value_range can be skipped if value_range is [0, 255].
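A minimal sketch of this suggestion, with a hypothetical NumPy stand-in for keras_cv's transform_value_range utility (names and signatures here are illustrative, not the actual API):

```python
import numpy as np

def transform_value_range(images, original_range, target_range):
    # Stand-in for the keras_cv utility: linear rescale between ranges.
    o_low, o_high = original_range
    t_low, t_high = target_range
    images = (images - o_low) / (o_high - o_low)
    return images * (t_high - t_low) + t_low

def restore_range(image, value_range, compute_dtype=np.float32):
    # Cast back to the compute dtype first, so the output dtype is
    # guaranteed; skip the rescale when value_range is already [0, 255].
    image = image.astype(compute_dtype)
    if tuple(value_range) == (0, 255):
        return image
    return transform_value_range(image, [0, 255], value_range)
```

The early return keeps the common [0, 255] case a pure cast, which is the optimization the review comment asks for.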

@LukeWood
Contributor

@LukeWood I updated the PR. Two questions:

  1. What is the tf.function policy? Should augment_image or any other function be decorated?
  2. Is it OK to also override _batch_augment, if the implementation is already vectorized and does not need to rely on vectorized_map / map_fn? See the code for details; if not, I will remove this part.

1.) no need for tf.function 2.) I think it is actually fine. I will check with @qlzh727 too.

Discussed with Scott - it is fine for now. Thanks! Good idea.
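For illustration, a framework-free sketch of the pattern being discussed (all class and method names below are hypothetical stand-ins, not the actual keras_cv API): the base class maps augment_image over the batch by default, while a subclass whose op is already vectorized overrides _batch_augment to skip the map.

```python
import numpy as np

class BaseAugmentation:
    """Minimal hypothetical stand-in for BaseImageAugmentationLayer."""

    def __call__(self, batch):
        return self._batch_augment(np.asarray(batch))

    def _batch_augment(self, batch):
        # Default path: map augment_image over the batch, standing in
        # for tf.vectorized_map / tf.map_fn.
        return np.stack([self.augment_image(img) for img in batch])

    def augment_image(self, image):
        raise NotImplementedError


class VectorizedPosterize(BaseAugmentation):
    def __init__(self, bits):
        self.shift = 8 - bits

    def augment_image(self, image):
        img = np.asarray(image, dtype=np.uint8)
        return (img >> self.shift) << self.shift

    def _batch_augment(self, batch):
        # The bitwise ops broadcast over the leading batch axis already,
        # so the per-image map can be skipped entirely.
        return self.augment_image(batch)
```

Both paths produce the same result; overriding _batch_augment just avoids the per-image dispatch when the op is batch-aware.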

@bhack
Contributor

bhack commented Mar 23, 2022

1.) no need for tf.function

The model logic is wrapped in a tf.function by default, unless you explicitly enforce eager mode, e.g. at model compilation.

jit_compile could be a different story, as we may need to limit the compiled region to a specific function. See more at #165

@LukeWood
Contributor

Thanks a ton @sebastian-sz - great contribution! Looks good to go!

@LukeWood LukeWood merged commit 1f7bde4 into keras-team:master Mar 23, 2022
@sebastian-sz sebastian-sz deleted the feature-84/introduce-posterization-layer branch March 30, 2022 08:59
ianstenbit pushed a commit to ianstenbit/keras-cv that referenced this pull request Aug 6, 2022
* Added Posterization preprocessing layer.

* Changed Posterization's AssertionError to ValueError.

* Refactored Posterization to use BaseImageAugmentationLayer.

* Copyedits

Co-authored-by: Luke Wood <lukewoodcs@gmail.com>
adhadse pushed a commit to adhadse/keras-cv that referenced this pull request Sep 17, 2022
* Added Posterization preprocessing layer.

* Changed Posterization's AssertionError to ValueError.

* Refactored Posterization to use BaseImageAugmentationLayer.

* Copyedits

Co-authored-by: Luke Wood <lukewoodcs@gmail.com>