Skip to content

Commit

Permalink
ref(examples): Refactored example helper functions (keras-team#405)
Browse files Browse the repository at this point in the history
* ref(examples): Refactored example helper functions

* ref(examples): remove apply_aug

* ref(examples): revert back to map in hue example

* Update demo utilities

* Format

Co-authored-by: Luke Wood <lukewoodcs@gmail.com>
  • Loading branch information
(Ian Stenbit) and LukeWood committed Jun 11, 2022
1 parent c437041 commit e2795b9
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 133 deletions.
42 changes: 6 additions & 36 deletions examples/layers/preprocessing/channel_shuffle_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,17 @@
Finally, they are shown using matplotlib.
"""

import matplotlib.pyplot as plt
import demo_utils
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_cv.layers import preprocessing

IMG_SIZE = (224, 224)
BATCH_SIZE = 64


def resize(image, label, num_classes=10):
image = tf.image.resize(image, IMG_SIZE)
label = tf.one_hot(label, num_classes)
return image, label
from keras_cv import layers


def main():
data, ds_info = tfds.load("oxford_flowers102", with_info=True, as_supervised=True)
train_ds = data["train"]

num_classes = ds_info.features["label"].num_classes

train_ds = (
train_ds.map(lambda x, y: resize(x, y, num_classes=num_classes))
.shuffle(10 * BATCH_SIZE)
.batch(BATCH_SIZE)
)

channel_shuffle = preprocessing.ChannelShuffle()
train_ds = train_ds.map(
lambda x, y: (channel_shuffle(x, training=True), y),
num_parallel_calls=tf.data.AUTOTUNE,
)

for images, labels in train_ds.take(1):
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()
channel_shuffle = layers.ChannelShuffle()
ds = demo_utils.load_oxford_dataset()
ds = ds.map(channel_shuffle, num_parallel_calls=tf.data.AUTOTUNE)
demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
Expand Down
43 changes: 7 additions & 36 deletions examples/layers/preprocessing/cut_mix_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,22 @@
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_cv.layers import preprocessing

IMG_SIZE = (224, 224)
BATCH_SIZE = 64

import demo_utils
import tensorflow as tf

def resize(image, label, num_classes=10):
image = tf.image.resize(image, IMG_SIZE)
label = tf.one_hot(label, num_classes)
return image, label
from keras_cv import layers


def to_dict(images, labels):
return {"images": images, "labels": labels}


def main():
data, ds_info = tfds.load("oxford_flowers102", with_info=True, as_supervised=True)
train_ds = data["train"]

num_classes = ds_info.features["label"].num_classes

train_ds = (
train_ds.map(lambda x, y: resize(x, y, num_classes=num_classes))
.shuffle(10 * BATCH_SIZE)
.batch(BATCH_SIZE)
.map(to_dict)
)
cutmix = preprocessing.CutMix()
train_ds = train_ds.map(
lambda image, label: cutmix({"images": image, "labels": label}),
num_parallel_calls=tf.data.AUTOTUNE,
)

for data in train_ds.take(1):
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(data["images"][i].numpy().astype("uint8"))
plt.axis("off")
plt.show()
cutmix = layers.CutMix()
ds = demo_utils.load_oxford_dataset()
ds = ds.map(cutmix, num_parallel_calls=tf.data.AUTOTUNE)
demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
Expand Down
52 changes: 52 additions & 0 deletions examples/layers/preprocessing/demo_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2022 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for preprocessing demos."""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds


def resize(image, label, img_size=(224, 224), num_classes=10):
image = tf.image.resize(image, img_size)
label = tf.one_hot(label, num_classes)
return {"images": image, "labels": label}


def load_oxford_dataset(
name="oxford_flowers102",
batch_size=64,
img_size=(224, 224),
as_supervised=True,
):
# Load dataset.
data, ds_info = tfds.load(name, as_supervised=as_supervised, with_info=True)
train_ds = data["train"]
num_classes = ds_info.features["label"].num_classes

# Get tf dataset.
train_ds = train_ds.map(
lambda x, y: resize(x, y, img_size=img_size, num_classes=num_classes)
).batch(batch_size)
return train_ds


def visualize_dataset(ds):
outputs = next(iter(ds.take(1)))
images = outputs["images"]
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()
43 changes: 7 additions & 36 deletions examples/layers/preprocessing/fourier_mix_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,18 @@
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from keras_cv.layers import preprocessing

IMG_SIZE = (224, 224)
BATCH_SIZE = 64

import demo_utils
import tensorflow as tf

def resize(image, label, num_classes=10):
image = tf.image.resize(image, IMG_SIZE)
label = tf.one_hot(label, num_classes)
return image, label
from keras_cv import layers


def main():
data, ds_info = tfds.load("oxford_flowers102", with_info=True, as_supervised=True)
train_ds = data["train"]

num_classes = ds_info.features["label"].num_classes

train_ds = (
train_ds.map(lambda x, y: resize(x, y, num_classes=num_classes))
.shuffle(10 * BATCH_SIZE)
.batch(BATCH_SIZE)
)
fourier_mix = preprocessing.FourierMix(alpha=0.5)
train_ds = train_ds.map(
lambda x, y: fourier_mix({"images": x, "labels": y}),
num_parallel_calls=tf.data.AUTOTUNE,
)

for batch in train_ds.take(1):
images = batch["images"]
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()
fourier_mix = layers.FourierMix(alpha=0.5)
ds = demo_utils.load_oxford_dataset()
ds = ds.map(fourier_mix, num_parallel_calls=tf.data.AUTOTUNE)
demo_utils.visualize_dataset(ds)


if __name__ == "__main__":
Expand Down
35 changes: 10 additions & 25 deletions examples/layers/preprocessing/random_hue_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,23 @@
are loaded, then are passed through the preprocessing layers.
Finally, they are shown using matplotlib.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import utils

from keras_cv.layers import preprocessing

IMG_SIZE = (224, 224)
BATCH_SIZE = 64


def resize(image, label):
image = tf.image.resize(image, IMG_SIZE)
return image, label


def main():
data, ds_info = tfds.load("oxford_flowers102", with_info=True, as_supervised=True)
train_ds = data["train"]
# Prepare flower dataset dataset.
train_ds = utils.prepare_dataset()

train_ds = train_ds.map(lambda x, y: resize(x, y)).batch(BATCH_SIZE)
# Prepare augmentation layer.
random_hue = preprocessing.RandomHue(factor=(0.0, 1.0), value_range=(0, 255))
train_ds = train_ds.map(
lambda x, y: (random_hue(x), y), num_parallel_calls=tf.data.AUTOTUNE
)

for images, labels in train_ds.take(1):
plt.figure(figsize=(8, 8))
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.imshow(images[i].numpy().astype("uint8"))
plt.axis("off")
plt.show()

# Apply augmentation.
train_ds = train_ds.map(lambda x, y: (random_hue(x), y))

# visualize.
utils.visualize_dataset(train_ds)


if __name__ == "__main__":
Expand Down

0 comments on commit e2795b9

Please sign in to comment.