Skip to content

Commit

Permalink
fix: tf backend execute_with_gradients when xs contains KerasVariable
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jul 26, 2024
1 parent 6b454e6 commit ceb5735
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions ivy/functional/backends/tensorflow/gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def execute_with_gradients(
_get_required_float_variables(xs, xs_grad_idxs)
)

# Conversion of KerasVariable to tf.Variable within xs_required container, so they can be watched
if ivy.is_ivy_container(xs_required):
ivy.nested_map(
lambda x: x._value
if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__)
else x,
xs_required,
include_derived=True,
)

# Creating a tape to record operations
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
tape.watch(xs_required)
Expand Down

0 comments on commit ceb5735

Please sign in to comment.