Skip to content

Commit

Permalink
BUGFIX: issues with mixed int/int32/int64 types for fancy indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
pspillai committed Jan 17, 2024
1 parent ad74213 commit 0373d4c
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions ramba/ramba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6089,7 +6089,6 @@ def setitem_array_executor(cls, temp_array, self, index, value):
# This is the advanced indexing case.
dim_sizes, preindex, postindind, arrayind, dist = dim_sizes_from_index(index, self.shape)
dst_arr = manual_idx_ndarray(self[preindex])
print(f"HERE: {dst_arr.shape} {dim_sizes} {dst_arr.distribution}")
if isinstance(value, np.ndarray):
value = array(value)
if isinstance(value, ndarray):
Expand Down Expand Up @@ -6119,11 +6118,11 @@ def setitem_array_executor(cls, temp_array, self, index, value):
if isinstance(i, numbers.Integral):
indop[-1]+=f"index[{i}]+global_start[{i}], "
elif isinstance(i, ndarray):
indop[-1]+=f"int("
indop[-1]+=f"np.int64("
indop.append(i)
indop.append("), ")
else:
indop[-1]+=f"int("
indop[-1]+=f"np.int64("
indop.append(i)
indop.append(f"[({arrindstr})]), ")
indop[-1]+=")"
Expand All @@ -6145,7 +6144,7 @@ def setitem_array_executor(cls, temp_array, self, index, value):
precode=["for i in range(num_workers): ", need_comm, "[0,i]=False"] )
comm = deferred_op.get_temp_var()
deferred_op.add_op([" ", comm, "[", nodeid, "].append(list(", ind_arr, "))"], valueind,
precode=["", comm,"=[[[", dst_arr,".shape[0]]]*0 for _ in range(num_workers)]"])
precode=["", comm,"=[[[np.int64(0)]]*0 for _ in range(num_workers)]"])
vals = deferred_op.get_temp_var()
deferred_op.add_op([" ", vals, "[", nodeid, "].append(", value, ")"], valueind,
precode=["", vals,"=[[", valuetype,"]*0 for _ in range(num_workers)]"])
Expand Down Expand Up @@ -6365,11 +6364,11 @@ def getitem_array_executor(cls, temp_array, self, index):
if isinstance(i, numbers.Integral):
indop[-1]+=f"index[{i}]+global_start[{i}], "
elif isinstance(i, ndarray):
indop[-1]+=f"int("
indop[-1]+=f"np.int64("
indop.append(i)
indop.append("), ")
else:
indop[-1]+=f"int("
indop[-1]+=f"np.int64("
indop.append(i)
indop.append(f"[({arrindstr})]), ")
indop[-1]+=")"
Expand All @@ -6391,7 +6390,7 @@ def getitem_array_executor(cls, temp_array, self, index):
precode=["for i in range(num_workers): ", need_comm, "[0,i]=False"] )
comm = deferred_op.get_temp_var()
deferred_op.add_op([" ", comm, "[", nodeid, "].append(list(", ind_arr, "+index))"], dst_arr,
precode=["", comm,"=[[[", src_arr,".shape[0]]]*0 for _ in range(num_workers)]"])
precode=["", comm,"=[[[np.int64(0)]]*0 for _ in range(num_workers)]"])
deferred_op.add_op(["#"], dst_arr, postcode=["for i in range(num_workers):"] )
tmp_arr = deferred_op.get_temp_var()
indextype = 'int64' if ramba_big_data else 'int32'
Expand Down

0 comments on commit 0373d4c

Please sign in to comment.