Skip to content

Commit

Permalink
fix (torch frontend)(tensor.py): fixing torch.Tensor.masked_scatter
Browse files Browse the repository at this point in the history
… where we were incorrectly broadcasting the input tensor as well.
  • Loading branch information
YushaArif99 committed Oct 10, 2024
1 parent e7c2e4d commit 11a1c1a
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,6 @@ def masked_select(self, mask):
return torch_frontend.masked_select(self, mask)

def masked_scatter(self, mask, source):
self = torch_frontend.broadcast_to(self, source.shape)
mask = torch_frontend.broadcast_to(mask, self.shape)
flat_self = torch_frontend.flatten(self.clone())
flat_mask = torch_frontend.flatten(mask)
Expand All @@ -1150,7 +1149,6 @@ def masked_scatter(self, mask, source):
return flat_self.reshape(self.shape)

def masked_scatter_(self, mask, source):
self = torch_frontend.broadcast_to(self, source.shape)
mask = torch_frontend.broadcast_to(mask, self.shape)
flat_self = torch_frontend.flatten(self.clone())
flat_mask = torch_frontend.flatten(mask)
Expand Down

0 comments on commit 11a1c1a

Please sign in to comment.