Skip to content

Commit

Permalink
Fix bug brought by new changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
pdx1989 committed Jan 18, 2024
1 parent f1e0139 commit d574f10
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,12 +644,22 @@ def new_ones(self, x, shape, dtype=torch.int64, layout=None, device='cpu', pin_m

def index_base(self, x, dim, index):
dim = [dim] if not isinstance(dim, list) else dim
if isinstance(index, list):
assert len(index) == 1
index = index[0]
assert dim[0] == 0
dim_op = self.get_proxy(
ascend_op.Const, (dim, torch.int32, [len(dim)]))
return self.get_proxy(ascend_op.GatherV2, (x, index, dim_op))

@register_conversion(torch.ops.aten.index.Tensor)
def index(self, x, index):
if isinstance(index, list):
return self.unsafe_index(x, index)
return self.index_base(x, 0, index)

@register_conversion(torch.ops.aten._unsafe_index.Tensor)
def unsafe_index(self, x, index):
if isinstance(index, list):
if len(index) == 1:
index = index[0]
Expand Down Expand Up @@ -724,7 +734,6 @@ def index(self, x, index):
if status > 0:
return self.get_proxy(ascend_op.Transpose, (gather, perm))
return gather

return self.index_base(x, 0, index)

@register_conversion(torch.ops.aten.index_select.default)
Expand Down

0 comments on commit d574f10

Please sign in to comment.