Add steps_per_execution to jax backend (keras-team#165)
* Add `steps_per_execution` to jax backend

* update code

* special case funcs

* add docstring

* simplify code

---------

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
haifeng-jin authored May 14, 2023
1 parent 00b28ec commit e410249
Showing 5 changed files with 86 additions and 33 deletions.
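For readers new to the feature: `steps_per_execution` makes the backend process several batches per call into the compiled train/eval/predict function, cutting per-step dispatch overhead. A minimal usage sketch, mirroring the test added below (the toy model and data are illustrative, not part of this commit):

import numpy as np
import keras_core

# Illustrative toy model; any keras_core model behaves the same way.
model = keras_core.Sequential([keras_core.layers.Dense(1)])

# With steps_per_execution=3, three batches run per compiled call, and
# batch-level callbacks fire once per 3-batch chunk.
model.compile(loss="mse", optimizer="adam", steps_per_execution=3)

x = np.ones((100, 4))
y = np.ones((100, 1))
model.fit(x, y, batch_size=16)  # 7 batches -> 3 compiled executions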
85 changes: 63 additions & 22 deletions keras_core/backend/jax/trainer.py
@@ -74,6 +74,7 @@ def fit(
             steps_per_epoch=steps_per_epoch,
             shuffle=shuffle,
             class_weight=class_weight,
+            steps_per_execution=self.steps_per_execution,
         )
 
         compile_metrics_unbuilt = (
@@ -83,6 +84,7 @@
         if not self.built or compile_metrics_unbuilt:
             # Build the model on one batch of data.
             for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
+                data = data[0]
                 (
                     x,
                     y,
@@ -168,14 +170,21 @@ def _train_step(state, data):
             )
             return logs, state
 
+        def _train_multi_step(state, data):
+            for single_step_data in data:
+                logs, state = _train_step(state, single_step_data)
+            return logs, state
+
         if not self.run_eagerly and self.jit_compile:
 
             @jax.jit
             def train_step(state, data):
-                return _train_step(state, data)
+                if self.steps_per_execution > 1:
+                    return _train_multi_step(state, data)
+                return _train_step(state, data[0])
 
         else:
-            train_step = _train_step
+            train_step = _train_multi_step
 
         self.stop_training = False
         callbacks.on_train_begin()
@@ -239,6 +248,7 @@ def train_step(state, data):
                     y=val_y,
                     sample_weight=val_sample_weight,
                     batch_size=validation_batch_size or batch_size,
+                    steps_per_execution=self.steps_per_execution,
                 )
                 val_logs = self.evaluate(
                     x=val_x,
@@ -300,11 +310,13 @@ def evaluate(
             batch_size=batch_size,
             steps_per_epoch=steps,
             shuffle=False,
+            steps_per_execution=self.steps_per_execution,
         )
 
         if not self.built:
             # Build the model on one batch of data.
             for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
+                data = data[0]
                 (
                     x,
                     y,
@@ -374,14 +386,21 @@ def _test_step(state, data):
             )
             return logs, state
 
+        def _test_multi_step(state, data):
+            for single_step_data in data:
+                logs, state = _test_step(state, single_step_data)
+            return logs, state
+
         if not self.run_eagerly and self.jit_compile:
 
             @jax.jit
             def test_step(state, data):
-                return _test_step(state, data)
+                if self.steps_per_execution > 1:
+                    return _test_multi_step(state, data)
+                return _test_step(state, data[0])
 
         else:
-            test_step = _test_step
+            test_step = _test_multi_step
 
         callbacks.on_test_begin()
         logs = None
@@ -430,13 +449,14 @@ def predict(
             batch_size=batch_size,
             steps_per_epoch=steps,
             shuffle=False,
+            steps_per_execution=self.steps_per_execution,
         )
 
         if not self.built:
             # Build the model on one batch of data.
             for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
                 # Build model
-                self(data)
+                self(data[0])
                 break
 
         # Container that configures and calls callbacks.
@@ -451,18 +471,36 @@ def predict(
             model=self,
         )
 
+        def _predict_multi_step(
+            trainable_variables, non_trainable_variables, data
+        ):
+            return [
+                self.stateless_call(
+                    trainable_variables,
+                    non_trainable_variables,
+                    single_step_data,
+                )
+                for single_step_data in data
+            ]
+
         if not self.run_eagerly and self.jit_compile:
 
             @jax.jit
             def predict_step(
                 trainable_variables, non_trainable_variables, data
             ):
-                return self.stateless_call(
-                    trainable_variables, non_trainable_variables, data
-                )
+                if self.steps_per_execution > 1:
+                    return _predict_multi_step(
+                        trainable_variables, non_trainable_variables, data
+                    )
+                return [
+                    self.stateless_call(
+                        trainable_variables, non_trainable_variables, data[0]
+                    )
+                ]
 
         else:
-            predict_step = self.stateless_call
+            predict_step = _predict_multi_step
 
         callbacks.on_predict_begin()
 
@@ -471,21 +509,24 @@ def predict_step(
         outputs = None
         for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
             callbacks.on_predict_batch_begin(step)
-            batch_outputs, non_trainable_variables = predict_step(
+            multi_step_return_values = predict_step(
                 trainable_variables, non_trainable_variables, x
             )
-            if outputs is None:
-                outputs = tf.nest.map_structure(
-                    lambda batch_output: [batch_output],
-                    batch_outputs,
-                )
-            else:
-                tf.__internal__.nest.map_structure_up_to(
-                    batch_outputs,
-                    lambda output, batch_output: output.append(batch_output),
-                    outputs,
-                    batch_outputs,
-                )
+            for batch_outputs, _ in multi_step_return_values:
+                if outputs is None:
+                    outputs = tf.nest.map_structure(
+                        lambda batch_output: [batch_output],
+                        batch_outputs,
+                    )
+                else:
+                    tf.__internal__.nest.map_structure_up_to(
+                        batch_outputs,
+                        lambda output, batch_output: output.append(
+                            batch_output
+                        ),
+                        outputs,
+                        batch_outputs,
+                    )
             callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
         callbacks.on_predict_end()
         return tf.__internal__.nest.map_structure_up_to(
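The recurring pattern in this file (buffer `steps_per_execution` batches, then loop over them inside one `@jax.jit`-compiled function) is what amortizes Python dispatch cost: the loop unrolls at trace time into a single compiled program. A self-contained sketch of the idea, with a toy parameter and loss standing in for the trainer's real state handling:

import jax
import jax.numpy as jnp

def single_step(params, batch):
    # Toy "train step": one SGD update on a scalar quadratic loss.
    x, y = batch
    grads = jax.grad(lambda p: jnp.mean((x * p - y) ** 2))(params)
    return params - 0.1 * grads

@jax.jit
def multi_step(params, batches):
    # Unrolled at trace time: all N steps compile into one program,
    # so Python dispatch happens once per N batches. A different
    # number of batches would trigger a retrace.
    for batch in batches:
        params = single_step(params, batch)
    return params

params = jnp.array(0.0)
batches = [(jnp.ones(8), jnp.full(8, 2.0))] * 3  # steps_per_execution = 3
params = multi_step(params, batches)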
16 changes: 14 additions & 2 deletions keras_core/trainers/epoch_iterator.py
@@ -142,16 +142,21 @@ def _get_iterator(self, return_type):
         return iterator
 
     def enumerate_epoch(self, return_type="np"):
+        buffer = []
         if self.steps_per_epoch:
             if not self._current_iterator:
                 self._current_iterator = self._get_iterator(return_type)
                 self._insufficient_data = False
 
             for step in range(self.steps_per_epoch):
                 if self._insufficient_data:
                     break
                 try:
                     data = next(self._current_iterator)
-                    yield step, data
+                    buffer.append(data)
+                    if len(buffer) == self.steps_per_execution:
+                        yield step - len(buffer) + 1, buffer
+                        buffer = []
                 except (StopIteration, tf.errors.OutOfRangeError):
                     warnings.warn(
                         "Your input ran out of data; interrupting epoch. "
@@ -163,9 +168,16 @@ def enumerate_epoch(self, return_type="np"):
                     )
                     self._current_iterator = None
                     self._insufficient_data = True
+            if buffer:
+                yield step - len(buffer) + 1, buffer
         else:
             for step, data in enumerate(self._get_iterator(return_type)):
-                yield step, data
+                buffer.append(data)
+                if len(buffer) == self.steps_per_execution:
+                    yield step - len(buffer) + 1, buffer
+                    buffer = []
+            if buffer:
+                yield step - len(buffer) + 1, buffer
         if not self._num_batches:
             # Infer the number of batches returned by the data_adapter.
             # Assumed static.
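The net effect: `enumerate_epoch` now yields `(index_of_first_batch, list_of_batches)` chunks instead of single batches, flushing any partial buffer at the end of the epoch. A standalone sketch of just the chunking logic (the function name here is illustrative):

def chunked_enumerate(iterator, steps_per_execution):
    # Mirror of the buffering above: emit (first step index, buffer)
    # every steps_per_execution batches, then flush the remainder.
    buffer = []
    for step, data in enumerate(iterator):
        buffer.append(data)
        if len(buffer) == steps_per_execution:
            yield step - len(buffer) + 1, buffer
            buffer = []
    if buffer:
        yield step - len(buffer) + 1, buffer

# Seven batches in chunks of three: yields steps 0, 3, and 6
# with chunk sizes 3, 3, and 1.
for first_step, chunk in chunked_enumerate(range(7), 3):
    print(first_step, len(chunk))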
1 change: 1 addition & 0 deletions keras_core/trainers/epoch_iterator_test.py
@@ -22,6 +22,7 @@ def _test_basic_flow(self, return_type):
         )
         steps_seen = []
         for step, batch in iterator.enumerate_epoch(return_type=return_type):
+            batch = batch[0]
             steps_seen.append(step)
             self.assertEqual(len(batch), 3)
             if return_type == "np":
1 change: 1 addition & 0 deletions keras_core/trainers/trainer.py
@@ -16,6 +16,7 @@ def __init__(self):
         self._run_eagerly = False
         self._jit_compile = True
         self.compiled = False
+        self.steps_per_execution = 1
 
     @tracking.no_automatic_dependency_tracking
     def compile(
16 changes: 7 additions & 9 deletions keras_core/trainers/trainer_test.py
@@ -1,5 +1,4 @@
 import numpy as np
-import pytest
 
 from keras_core import backend
 from keras_core import initializers
@@ -216,11 +215,6 @@ def test_predict_flow_graph_fn(self):
     def test_predict_flow_jit(self):
         self._test_predict_flow(run_eagerly=False, jit_compile=True)
 
-    # TODO: Remove the skipif when implemented steps_per_execution for JAX.
-    @pytest.mark.skipif(
-        backend.backend() != "tensorflow",
-        reason="JAX does not support steps_per_execution yet",
-    )
     def test_steps_per_execution_steps_count(self):
         class StepCount(Callback):
             def __init__(self):
@@ -235,12 +229,16 @@ def on_batch_begin(self, batch, logs=None):
         x = np.ones((100, 4))
         y = np.ones((100, 1))
         model = ExampleModel(units=1)
-        model.compile(loss="mse", optimizer="adam", steps_per_execution=3)
+        model.compile(
+            loss="mse",
+            optimizer="adam",
+            steps_per_execution=3,
+        )
         step_count = StepCount()
-        model.fit(x=x, y=y, batch_size=16, callbacks=[step_count])
+        model.fit(x=x, y=y, batch_size=16, callbacks=[step_count], verbose=0)
         self.assertEqual(step_count.count, 3)
 
         model_2 = ExampleModel(units=1)
         model_2.compile(loss="mse", optimizer="adam", steps_per_execution=1)
-        model_2.fit(x=x, y=y, batch_size=16)
+        model_2.fit(x=x, y=y, batch_size=16, verbose=0)
         self.assertAllClose(model.get_weights(), model_2.get_weights())

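A note on the assertion above: the expected count of 3 follows from the chunking. 100 samples at `batch_size=16` yield 7 batches, and with `steps_per_execution=3` the `on_batch_begin` callback fires once per chunk, at steps 0, 3, and 6.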