Skip to content

Commit

Permalink
fix: getting device of KerasVariable
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jul 19, 2024
1 parent ce5f797 commit 3805db6
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ivy/functional/backends/tensorflow/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def dev(
*,
as_native: bool = False,
) -> Union[ivy.Device, str]:
if "keras.src.backend.tensorflow.core.Variable" in str(x.__class__):
# Read the underlying tensor of a KerasVariable to get the device
x = x.value
if isinstance(x, tf.TensorArray):
# Read the underlying tensor being wrapped to get the device.
x = x.stack()
Expand Down

0 comments on commit 3805db6

Please sign in to comment.