
Commit

fix: use compositional implementation for torch backend ivy.put_along_axis
Sam-Armstrong committed Jul 10, 2024
1 parent 1d16dfd commit e4539a2
Showing 1 changed file with 0 additions and 40 deletions.
40 changes: 0 additions & 40 deletions in ivy/functional/backends/torch/experimental/manipulation.py
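With the torch-specific `put_along_axis` handler removed, calls to `ivy.put_along_axis` under the torch backend fall through to Ivy's framework-agnostic compositional implementation instead of going directly to `torch.scatter` / `torch.scatter_reduce`. A minimal usage sketch of the behaviour after this change (assuming ivy with the torch backend is installed; the sample values are illustrative only):

```python
import ivy

ivy.set_backend("torch")

arr = ivy.array([[10.0, 30.0, 20.0], [60.0, 40.0, 50.0]])
indices = ivy.array([[0], [1]])
values = ivy.array([[99.0], [99.0]])

# After this commit the call below is served by the compositional
# implementation rather than the removed torch-specific handler.
result = ivy.put_along_axis(arr, indices, values, 1, mode="replace")
print(result)  # expected (numpy.put_along_axis semantics): [[99., 30., 20.], [60., 99., 50.]]
```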
@@ -443,46 +443,6 @@ def column_stack(
return torch.column_stack(arrays)


@with_supported_dtypes({"2.2 and below": ("float32", "float64")}, backend_version)
def put_along_axis(
arr: torch.Tensor,
indices: torch.Tensor,
values: Union[int, torch.Tensor],
axis: int,
/,
*,
mode: Literal["sum", "min", "max", "mul", "replace"] = "replace",
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
arr, values = ivy.promote_types_of_inputs(arr, values)

mode_mappings = {
"sum": "sum",
"min": "amin",
"max": "amax",
"mul": "prod",
"replace": "replace",
}
mode = mode_mappings.get(mode, mode)
indices = indices.to(torch.int64)
if isinstance(values, torch.Tensor) and values.dim() == 0:
values = values.item()
if mode == "replace":
return torch.scatter(arr, axis, indices, values, out=out)
else:
return torch.scatter_reduce(arr, axis, indices, values, reduce=mode, out=out)


put_along_axis.partial_mixed_handler = lambda *args, mode=None, **kwargs: mode in [
"replace",
"sum",
"mul",
"mean",
"max",
"min",
]


def concat_from_sequence(
input_sequence: Union[Tuple[torch.Tensor], List[torch.Tensor]],
/,
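The deleted `partial_mixed_handler` attribute is the hook Ivy's mixed-function machinery uses to decide, per call, whether the backend-specific version can handle the given arguments or whether to fall back to the compositional one. Below is a simplified, hypothetical illustration of that kind of dispatch, not Ivy's actual handler code (`dispatch_mixed` is an invented name):

```python
from typing import Any, Callable


def dispatch_mixed(backend_fn: Callable[..., Any],
                   compositional_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Toy sketch: prefer the backend implementation only when its
    `partial_mixed_handler` accepts the arguments, otherwise fall back
    to the compositional implementation."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        handler = getattr(backend_fn, "partial_mixed_handler", None)
        if handler is None or handler(*args, **kwargs):
            return backend_fn(*args, **kwargs)
        return compositional_fn(*args, **kwargs)

    return wrapper
```

Since this commit removes both the torch-specific function and its handler, the compositional path is always taken for `put_along_axis` under the torch backend.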
