Skip to content

Commit

Permalink
Add workaround for scalar indices with get_at and torch
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Jan 18, 2024
1 parent c449b76 commit 88006a7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion einx/backend/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ def minimum(a, b):
logsumexp = torch_.logsumexp

def get_at(tensor, coordinates):
return tensor[coordinates]
if coordinates[0].ndim == 0:
# Fix for https://github.com/pytorch/functorch/issues/747
# Scalar coordinates cause problems with torch.vmap and throw an error:
# "RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item() on a Tensor ..."
# As a workaround, we add a dummy dimension and remove it after the indexing operation.
return tensor[tuple(c[None] for c in coordinates)][0]
else:
return tensor[coordinates]
def set_at(tensor, coordinates, updates):
tensor[coordinates] = updates
return tensor
Expand Down

0 comments on commit 88006a7

Please sign in to comment.