vectorized_map causes tf.function retracing. #241
Improving the performance of … I wanted to point out that multiple …
I don't think we are going to be impacted by the auto-vectorization/retracing in real use cases:

```python
import tensorflow as tf
from tensorflow.keras.models import Sequential

from keras_cv.layers.preprocessing import Solarization

layer = Solarization()  # or Equalization()
rng = tf.random.Generator.from_seed(1234)

model = Sequential()
model.add(layer)
model.build([24, 224, 224, 3])

for _ in range(50):
    x = rng.uniform(
        shape=(24, 224, 224, 3), minval=0, maxval=255, dtype=tf.float32)
    _ = model.predict(x)
```

See my comments in tensorflow/tensorflow#42441
Thanks for the detailed report @sebastian-sz. FYI @qlzh727
@bhack fair point - using Sequential and predict works. It is, however, a bit slower than calling the layer directly. Also, here is the same measurement run through a tf.data pipeline:

```python
import time

import tensorflow as tf

from keras_cv.layers import Solarization

model = tf.keras.Sequential()
model.add(Solarization())
model.build([24, 224, 224, 3])

rng = tf.random.Generator.from_seed(1234)
ds = tf.data.Dataset.from_tensor_slices(
    [rng.uniform(shape=(24, 224, 224, 3), maxval=256)]).repeat(100)
ds = ds.map(lambda x: model(x))

# Warm-up pass so graph building is not included in the measurement.
for _ in ds:
    continue

start = time.perf_counter()
for _ in ds:
    continue
stop = time.perf_counter()

print((stop - start) / 100)
```
It seems like wrapping the entire layer call in tf.function with jit_compile=True:

```python
@tf.function(jit_compile=True)
def apply(x):
    return layer(x)
```

gives roughly 0.0015 ms. This issue can be closed from my end. If no further comments appear I will close this issue starting next week.
Generally it is not the best approach to benchmark in a loop with predict: https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call

For controlling the XLA compilation, see my proposal at:
I wonder if we should @tf.function our call methods by default to: a.) mute warnings
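A minimal sketch of what that proposal could look like; the layer name and the augmentation used here are placeholders for illustration, not KerasCV code:

```python
import tensorflow as tf

class ExampleAugmentation(tf.keras.layers.Layer):
    """Hypothetical layer illustrating a tf.function-wrapped call method."""

    @tf.function
    def call(self, images):
        # Because call is traced into a graph, tf.vectorized_map no longer runs
        # under eager mode, which is where the retracing warnings show up.
        return tf.vectorized_map(tf.image.flip_left_right, images)
```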
@sebastian-sz Can you try your initial example with the latest tf-nightly version?
@bhack Running with tf-nightly the retracing warnings still appear. I am however happy with the performance from the jit_compile=True approach.
vectorized_map is going to internally trace the function if we are in the default eager mode, but models are tf.function-wrapped by default. If we want to keep the critical section eager-compatible, we need to automate a conditional call to the standard map_fn in the base-class overload we have done (when we are in eager mode).
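A minimal sketch of that conditional dispatch, written as a free function rather than the actual BaseImageAugmentationLayer internals (the names here are illustrative only):

```python
import tensorflow as tf

def map_over_batch(augment_single, images):
    # Illustrative dispatch: fall back to tf.map_fn when executing eagerly,
    # where tf.vectorized_map would trace augment_single and can trigger
    # retracing warnings; use tf.vectorized_map inside graphs for speed.
    if tf.executing_eagerly():
        return tf.map_fn(augment_single, images)
    return tf.vectorized_map(augment_single, images)

# Example usage with a trivial per-image transform.
images = tf.random.uniform((8, 64, 64, 3))
flipped = map_over_batch(tf.image.flip_left_right, images)
```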
More generally, I think that this use of "layer as op" is still a bit confusing:
I don't want to close it yet because I feel we need to figure out how to effectively communicate this recommendation to users 🤔
We could also @tf.function the base layer's call method if needed. Or the batch augment method.
It really depends... do you want to silently be in graph mode with some functions? As the end user/developer doesn't control the vectorization in the API, it is something you are going to do behind the scenes without any notification. At least
/cc @mdanatg
I think we now have better mechanisms to protect against excessive retracing. Is the error coming from a standard Keras layer, or is it a custom one?
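The comment does not say which mechanism is meant; assuming it refers to tf.function's reduce_retracing option, a minimal sketch:

```python
import tensorflow as tf

# Assumption: "better mechanisms" refers to options such as reduce_retracing,
# which relaxes input signatures so minor shape changes reuse an existing trace.
@tf.function(reduce_retracing=True)
def augment(images):
    return tf.vectorized_map(tf.image.flip_left_right, images)
```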
As we don't have an object here exposing that, I still believe it is better to automatically call map_fn when we are in eager mode.
Because there the graph creation is done implicitly by API design, like in https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=en#map
Ok, after lots more digging and time I agree with you @bhack: we should apply map_fn in eager mode and vectorized_map in graph mode. We can tackle this after @divyashreepathihalli migrates BaseImageAugmentationLayer to KerasCV. It will be easier to update once it is in KerasCV.
Problem description

It seems like applying some layers that use BaseImageAugmentationLayer with self.auto_vectorize=True over batched input causes tf.function retracing: calling such a layer repeatedly raises retracing warnings.
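A minimal sketch of the pattern being described, assuming the Solarization layer used in the comments above with its default auto_vectorize=True:

```python
import tensorflow as tf
from keras_cv.layers.preprocessing import Solarization

# Assumed setup: Solarization subclasses BaseImageAugmentationLayer and keeps
# its default self.auto_vectorize=True.
layer = Solarization()

rng = tf.random.Generator.from_seed(1234)
for _ in range(10):
    images = rng.uniform(shape=(24, 224, 224, 3), maxval=255)
    # Per this report, repeated calls over batched input emit tf.function
    # retracing warnings coming from the vectorized_map used internally.
    _ = layer(images)
```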
Benchmarks

Running simple benchmarks confirms the performance degradation with tf.function and batched input.

Case 1: auto_vectorize=True

Without tf.function: 0.067 ms. With: 0.079 ms.

Case 2: auto_vectorize=False

The issue doesn't pop up with non-batched input, e.g. (224, 224, 3), or if one changes self.auto_vectorize=False in the layer. Setting self.auto_vectorize=False will yield:

Without tf.function: 0.017 ms. With: 0.013 ms.

Case 3: override _batch_augment (if possible)

In the case of vectorized operations, the fastest option is still overriding _batch_augment to return self._augment(inputs). This will yield:

Without tf.function: 0.0059 ms. With: 0.0016 ms.