Skip to content

Commit

Permalink
BUGFIX: arrays created directly with ramba.ndarray() may not get crea…
Browse files Browse the repository at this point in the history
…ted correctly when their first use is as a view/slice
  • Loading branch information
pspillai committed Jan 10, 2024
1 parent fb4da87 commit 6e22b8d
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions ramba/ramba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3495,12 +3495,13 @@ def run_deferred_ops(
# info tuple is (size, distribution, local_border, from_border, to_border, dtype)
# TODO: Need to check array construction -- this code may break for views/slices; assumes first view of gid is the canonical, full array
# [ self.create_array(g, subspace, info[0], None, info[2], info[1], None if info[3] is None else info[3][self.worker_num], None if info[4] is None else info[4][self.worker_num]) for (g,(v,info,bdist,pad)) in arrays.items() if g not in self.numpy_map ]
for (g, (l, bdist, pad, _)) in arrays.items():
for (g, (l, _, bdist, pad, _)) in arrays.items():
if g in self.numpy_map:
continue
v = l[0][0]
info = l[0][1]
ss = shardview.clean_range(info.distribution[self.worker_num])
#ss = shardview.clean_range(info.distribution[self.worker_num])
ss = shardview.clean_range(bdist[self.worker_num])
self.create_array(
g,
ss,
Expand Down Expand Up @@ -3540,7 +3541,7 @@ def run_deferred_ops(
getview_time = 0.0
arrview_time = 0.0
copy_time = 0.0
for (g, (l, bdist, pad, _)) in arrays.items():
for (g, (l, _, bdist, pad, _)) in arrays.items():
for (v, info) in l:
arr_parts[v] = []
if info.manual_idx:
Expand Down Expand Up @@ -7763,7 +7764,7 @@ def __init__(self, shape, distribution, fdist):
self.write_arrs = [] # (gid,dist) list of write arrays; dist is None if flex_dist
self.use_gids = (
{}
) # map gid to tuple ([(tmp var name, array details)*], bdarray dist, pad)
) # map gid to tuple ([(tmp var name, array details)*], bdarray shape, bdarray dist, pad, flex)
self.preconstructed_gids = (
{}
) # subset of use_gids that already are remote_constructed
Expand All @@ -7784,10 +7785,10 @@ def get_var_name(self):
self.varcount += 1
return nm

def add_gid(self, gid, arr_info, bd_info, pad, flex_dist):
def add_gid(self, gid, arr_info, bd_shape, bd_info, pad, flex_dist):
if gid not in self.use_gids:
# self.use_gids[gid] = tuple([self.get_var_name(), arr_info, bd_info, pad])
self.use_gids[gid] = ([], bd_info, pad, flex_dist)
self.use_gids[gid] = ([], bd_shape, bd_info, pad, flex_dist)
self.keepalives.add(bdarray.get_by_gid(gid)) # keep bdarrays alive
for (v, ai) in self.use_gids[gid][0]:
if shardview.dist_is_eq(arr_info.distribution, ai.distribution):
Expand Down Expand Up @@ -7833,12 +7834,9 @@ def execute(self):
}
dprint(3, "use_gids:", self.use_gids.keys(), "\nlive_gids", live_gids.keys(), "\npreconstructed gids", self.preconstructed_gids, "\ndistribution", self.distribution)
# Change distributions for any flexible arrays -- we should not have slices here
for (_, (_, d, _, flex)) in live_gids.items():
if flex:
for (_, (_, s, d, _, flex)) in live_gids.items():
if flex and self.shape==s:
dcopy = shardview.clean_dist(self.distribution)
#dcopy = libcopy.deepcopy(shardview.clean_dist(self.distribution)) # version from dag branch
#dcopy = libcopy.deepcopy(self.distribution)
# d.clear()
for i, v in enumerate(
dcopy
): # deep copy distribution, but keep reference same as other arrays may be pointing to the same one
Expand Down Expand Up @@ -7979,6 +7977,7 @@ def __add_prepost( cls, oplist, post=False ):
oplist[1 + 2 * i] = cls.ramba_deferred_ops.add_gid(
x.gid,
x.get_details(),
bd.shape,
bd.distribution,
bd.pad,
bd.flex_dist and not bd.remote_constructed,
Expand Down Expand Up @@ -8110,6 +8109,7 @@ def add_op(
oplist[1 + 2 * i] = cls.ramba_deferred_ops.add_gid(
x.gid,
x.get_details(),
bd.shape,
bd.distribution,
bd.pad,
bd.flex_dist and not bd.remote_constructed,
Expand All @@ -8129,6 +8129,7 @@ def add_op(
v= cls.ramba_deferred_ops.add_gid(
v.gid,
v.get_details(),
bd.shape,
bd.distribution,
bd.pad,
bd.flex_dist and not bd.remote_constructed,
Expand Down

0 comments on commit 6e22b8d

Please sign in to comment.