-
Notifications
You must be signed in to change notification settings - Fork 331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduces Posterization preprocessing layer. #136
Introduces Posterization preprocessing layer. #136
Conversation
|
||
self.assertEqual(output.shape, dummy_input.shape) | ||
|
||
def test_output_dtype_unchanged(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably the same also on this one.
[batch, height, width, channels] or [height, width, channels]. | ||
""" | ||
|
||
def __init__(self, bits: int): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there - please include value_range
as an argument and scale images up to [0, 255] according to value_range, then scale back down to value_range
before returning. Thank.
Thank you for the PR. Left a starting comment, will do a thorough review. Please be aware we are doing some major design work around KPLs in KerasCV and the API may shift as a result. Apologies if this delays reviews. |
Hi, no worries. Let me know if you're done with refactoring and design changes, I will update this PR accordingly. |
@LukeWood I wanted to update the layer to subclass |
It seems that we rely on Tensorflow stable in this repo but in Keras repo on master we track TF Nightly: |
I'll update the repo to rely on nightly until 2.9. We're in pre-release so this is fine for now. |
Yes, sorry for the delayed response. |
@LukeWood I'm starting to use import tensorflow as tf
from tensorflow.keras.__internal__.layers import BaseImageAugmentationLayer
class DummyLayer(BaseImageAugmentationLayer):
def augment_image(self, image, transformation=None):
return image
layer = DummyLayer()
image = tf.random.uniform((1, 224, 224, 3),dtype=tf.int32)
output = layer(image)
print(output.dtype) # tf.float32, but input was tf.int32 |
The layer will cast your inputs to |
Basically @sebastian-sz to ensure consistency across dtypes we will just always assume that the inputs are float, and make no assumptions about their value range. |
Hey @sebastian-sz a sample usage of BaseImageAugmentationLayer is available in #180. Feel free to base yours on this. Upload when a copy is ready and I will review :) |
Thanks @LukeWood I'm almost done with updating the code but I'm still unhappy with how the tests look. This is in progress and will probably be done by the end of the week (probably sooner). |
awesome - yeah feel free to upload as you go. I may begin lending a hand in some of these components as the list of components required for RandAugment shrinks. |
91b83ce
to
e598294
Compare
@LukeWood I updated the PR. Two questions:
|
1.) no need for tf.function |
return transform_value_range( | ||
images=image, | ||
original_range=[0, 255], | ||
target_range=self._value_range, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets cast back to compute_dtype before transform_value_range
. That way we are certain it is the correct dtype. transform_value_range can be skipped if value_range is [0, 255]
Discussed with Scott - it is fine for now. Thanks! Good idea. |
The model logic is wrapped by default with a
|
Thanks a ton @sebastian-sz - great contribution! Looks good to go! |
* Added Posterization preprocessing layer. * Changed Posterization's AssertionError to ValueError. * Refactored Posterization to use BaseImageAugmentationLayer. * Copyeditst Co-authored-by: Luke Wood <lukewoodcs@gmail.com>
* Added Posterization preprocessing layer. * Changed Posterization's AssertionError to ValueError. * Refactored Posterization to use BaseImageAugmentationLayer. * Copyeditst Co-authored-by: Luke Wood <lukewoodcs@gmail.com>
Posterization is one of the operations used by AutoAugment and RandAugment.
Linked Issue