Skip to content

Commit

Permalink
Merge branch 'main' into fix_indices_types
Browse files Browse the repository at this point in the history
  • Loading branch information
pspillai committed Apr 26, 2024
2 parents c30562a + b75e6f8 commit 4e6d766
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 35 deletions.
6 changes: 5 additions & 1 deletion ramba/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

#import llvmlite.binding as llvm
#llvm.set_option('','--debug-only=loop-vectorize')
#llvm.set_option('','--debug-only=loop-unroll')

import os
import sys
import numpy as np
Expand All @@ -21,7 +25,7 @@

distribute_min_size = 100
NUM_WORKERS_FOR_BCAST = 100
fastmath = False
fastmath = True

force_gpu = int(os.environ.get("RAMBA_FORCE_GPU", "-1"))

Expand Down
210 changes: 177 additions & 33 deletions ramba/ramba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,9 @@ def shape(self):
def getglobal(self):
    # Return this object's global index space (its subspace descriptor).
    # NOTE(review): presumably a shardview-style subspace — confirm against
    # the enclosing class, whose header is outside this diff fragment.
    return self.subspace

def get_partial_view(self, slice_index, shard, global_index=True, remap_view=True):
def get_partial_view(self, slice_index, shard=None, global_index=True, remap_view=True):
if shard is None:
shard = self.subspace
local_slice = (
shardview.slice_to_local(shard, slice_index)
if global_index
Expand Down Expand Up @@ -2165,7 +2167,7 @@ def get_view(self, gid, shard):
# slice_index may be globally indexed (same num dims as view), or locally indexed (same num dims as bcontainer)
# output dimensionality is always same as bcontainer
def get_partial_view(
self, gid, slice_index, shard, global_index=True, remap_view=True
self, gid, slice_index, shard=None, global_index=True, remap_view=True
):
lnd = self.numpy_map[gid]
#print("get_partial_view:", slice_index, shard, global_index, lnd)
Expand Down Expand Up @@ -3489,7 +3491,7 @@ def spmd(self, func, args):
sys.stdout.flush()

def run_deferred_ops(
self, uuid, arrays, delete_gids, pickledvars, exec_dist, fname, code, imports
self, uuid, arrays, delete_gids, pickledvars, exec_dist, fname, code, fname2, code2, imports
):
times = [timer()]
dprint(4, "HERE - deferredops; arrays:", arrays.keys())
Expand Down Expand Up @@ -3524,6 +3526,7 @@ def run_deferred_ops(
# check if function exists, else exec to create it
ldict = {}
gdict = globals()
if "min" in gdict: del(gdict["min"])
for imp in imports:
the_module = __import__(imp)
gdict[imp.split(".")[0]] = the_module
Expand All @@ -3534,6 +3537,11 @@ def run_deferred_ops(
ffunc = ramba_exec(fname, code, gdict, ldict)
gdict[fname] = FunctionMetadata(ffunc, [], {}, code=code)
func = gdict[fname]
if fname2 not in gdict:
dprint(2, "function does not exist, creating it\n", code2)
ffunc2 = ramba_exec(fname2, code2, gdict, ldict)
gdict[fname2] = FunctionMetadata(ffunc2, [], {})
func2 = gdict[fname2]
times.append(timer())

# Send data to other workers as needed, count how many items to receive later
Expand All @@ -3558,7 +3566,7 @@ def run_deferred_ops(
subspace, info.distribution[self.worker_num]
): # whole compute range is local
sl = self.get_view(g, info.distribution[self.worker_num])
arr_parts[v].append((subspace, sl))
arr_parts[v].append((subspace, sl, True)) # True indicates local
continue
overlap_time -= timer()
overlap_workers = shardview.get_overlaps(
Expand Down Expand Up @@ -3592,7 +3600,7 @@ def run_deferred_ops(
# arr_parts[v].append( (part, shardview.array_to_view(part, sl)) )
arrview_time -= timer()
sl = self.get_partial_view( g, shardview.to_slice(part), info.distribution[self.worker_num], remap_view=False,)
arr_parts[v].append( ( shardview.clean_range(part), shardview.array_to_view(part, sl),) )
arr_parts[v].append( ( shardview.clean_range(part), shardview.array_to_view(part, sl), True) ) # True indicates local
arrview_time += timer()
dprint(
2,
Expand Down Expand Up @@ -3680,7 +3688,7 @@ def run_deferred_ops(
x -= shardview.get_start(full_part)
sl2 = sl[ shardview.as_slice(base_part) ]
arr_parts[v].append(
(shardview.clean_range(part), shardview.array_to_view(part, sl2))
(shardview.clean_range(part), shardview.array_to_view(part, sl2), False) # False means not local
)
arrview_time += timer()

Expand All @@ -3689,25 +3697,57 @@ def run_deferred_ops(

vparts = numba.typed.List()
for _, pl in arr_parts.items():
for vpart, _ in pl:
for vpart, _, _ in pl:
vparts.append(vpart)
if len(vparts)==0:
ranges = []
else:
ranges = shardview.get_range_splits_list(vparts)
times.append(timer())
rangedvars = [{} for _ in ranges]
localranges = [True for _ in ranges]
for i, rpart in enumerate(ranges):
varlist = rangedvars[i]
for (varname, data) in manual_index_arrays.items():
varlist[varname] = data
for (varname, partlist) in arr_parts.items():
for (vpart, data) in partlist:
for (vpart, data, local) in partlist:
#print("TODD:", len(arr_parts), len(partlist), len(ranges))
if not shardview.overlaps(rpart, vpart): # skip
continue
tmp = shardview.mapsv(vpart, rpart)
varlist[varname] = shardview.get_base_slice(tmp, data)
localranges[i] = localranges[i] and local
for j, r in enumerate(ranges):
if not shardview.is_empty(r) and localranges[j]:
arrvars = {}
for i, (k, v) in enumerate(arrays.items()):
did_once=False
did_more_than_once=False
vtr = ""
unionview = None
for (vn, ai) in v[0]:
if (vn in manual_index_arrays or # don't mess with manual arrays
shardview.dist_has_neg_step(ai.distribution)): # use separate argument for views with negative steps
#arrvars[vn] = self.get_partial_view( k, shardview.to_slice(r), ai[1][self.worker_num] )
arrvars[vn] = rangedvars[j][vn]
else:
if not did_once:
vtr = vn
did_once=True
unionview = shardview.mapslice( ai.distribution[self.worker_num], shardview.to_slice(r) )
else:
view = shardview.mapslice( ai.distribution[self.worker_num], shardview.to_slice(r) )
unionview = shardview.view_union( view, unionview )
did_more_than_once=True
if did_more_than_once:
#print ("HERE: ",vtr, unionview, shardview.to_base_slice(unionview))
arrvars[vtr] = self.get_partial_view( k, shardview.to_base_slice(unionview)) #, ai[1][self.worker_num] )
elif did_once:
#print ("HERE: ",vtr, unionview, shardview.to_base_slice(unionview))
#arrvars[vtr] = self.get_partial_view( k, shardview.to_slice(r), ai[1][self.worker_num] )
arrvars[vtr] = rangedvars[j][vtr]
rangedvars[j] = arrvars

times.append(timer())
if len(ranges) > 1:
Expand All @@ -3716,16 +3756,28 @@ def run_deferred_ops(
total_elements = 0
# execute function in each range
for i, r in enumerate(ranges):
arrvars = rangedvars[i]
if not shardview.is_empty(r):
if ndebug>3:
for k, v in arrvars.items():
dprint(4, "inputs:", k, v, type(v))
for k, v in othervars.items():
dprint(4, "others:", k, v, type(v))

total_elements += sum([x.size for x in arrvars.values()])
func(shardview._index_start(r), self.worker_num, num_workers, **arrvars, **othervars)
arrvars = rangedvars[i]
#print ("EXECUTION: itershape ",shardview._size(r), " is local?", localranges[i])
if localranges[i]: # and False:
#print ("run: ")
#for k, v in arrvars.items():
# dprint(0, "inputs:", k, v, type(v))
#for k, v in othervars.items():
# dprint(0, "others:", k, v, type(v))
func2(shardview._index_start(r), tuple(shardview._size(r).astype(np.int64)),
self.worker_num, num_workers, **arrvars, **othervars)
#print ("run done : ")
else:
if ndebug>3:
for k, v in arrvars.items():
dprint(4, "inputs:", k, v, type(v))
for k, v in othervars.items():
dprint(4, "others:", k, v, type(v))

total_elements += sum([x.size for x in arrvars.values()])
func(shardview._index_start(r), tuple(shardview._size(r).astype(np.int64)),
self.worker_num, num_workers, **arrvars, **othervars)
if ndebug>3:
for k, v in arrvars.items():
dprint(4, "results:", k, v, type(v))
Expand Down Expand Up @@ -8056,6 +8108,10 @@ def get_temp_var(cls):
return cls.temp_var()

# Execute what we have now
# New version constructs two functions
# First one takes a separate argument for each view. This works for all edge cases, halos, etc.
# Second one has one argument per GID; this adjusts indices to account for different views.
# This should be faster, but works only for contiguous local array segments.
def execute(self):
#gc.collect() # Need to experiment if this is good or not.
times = [timer()]
Expand All @@ -8079,58 +8135,134 @@ def execute(self):
): # deep copy distribution, but keep reference same as other arrays may be pointing to the same one
d[i] = v
times.append(timer())
# substitute var_name with var_name[index] for ndarrays
index_text = "[index]"
if len(self.axis_reductions)>0:
index_text += "[axisindex]"
for (k, v) in live_gids.items():
for (vn, ai) in v[0]:
if ai.manual_idx: continue # skip for manually indexed arrays
for i in range(len(self.codelines)):
self.codelines[i] = self.codelines[i].replace(vn, vn + index_text)
times.append(timer())
# compile into function, using unpickled variables (non-ndarrays)
# Prepare args lists
precode = []
precode2 = []
args = []
args2 = []
args2tran = {}
args2uv = {}
offsets = {}
for i, (k, v) in enumerate(live_gids.items()):
count=0
vtr = ""
vnlist = []
unionview = None
# precode.append(" "+v+" = arrays["+str(i)+"]")
for (vn, ai) in v[0]:
args.append(vn)
vnlist.append(vn)
if shardview.dist_has_neg_step(ai.distribution): # use separate argument for views with negative steps
args2.append(vn)
else:
view = shardview.dist_to_view(ai.shape,ai.distribution,bdarray.gid_map[k].distribution)
if count==0:
args2.append(vn)
vtr = vn
unionview = view
offsets[vtr] = [[] for _ in range(shardview._base_offset(view).shape[0])]
else:
unionview = shardview.view_union( view, unionview )
args2tran[vn] = (vtr, view)
am = shardview._axis_map(view)
ub = shardview._base_offset(view)
st = shardview._steps(view)
for i,j in enumerate(am):
if j>=0:
offsets[vtr][j].append( (f"{vn}_{j}", f"{ub[j]}+global_start[{i}]*{st[i]}") )
for j in range(ub.shape[0]):
if j not in am:
offsets[vtr][j].append( (f"{vn}_{j}", f"{ub[j]}") )
count+=1
if count>1:
args2uv[vtr] = unionview
for l in offsets[vtr]:
if len(l)==0: continue
if len(l)==1:
precode2.append( l[0][0] + " = 0" )
continue
for i,j in l:
precode2.append( i + " = " + j )
precode2.append( "tmp = min(" + ",".join([i for i,_ in l]) +")" )
for i,_ in l:
precode2.append( i + " -= tmp" )

for i, (v, b) in enumerate(self.use_other.items()):
# precode.append(" "+v+" = vars["+str(i)+"]")
args.append(v)
args2.append(v)
# precode.append(" import numpy as np")
# precode.append("\n".join(self.imports))
times.append(timer())
# substitute var_name with var_name[index] for ndarrays
codelines2 = [v for v in self.codelines]
index_text = "[index]"
#if len(self.axis_reductions)>0:
# index_text += "[axisindex]"
def make_index_text( v, u, vn, a ):
    # Build the textual indexing expression "[e0,e1,...]" used to rewrite a
    # variable reference so that the shared iteration index tuple `index` is
    # mapped onto view v's base-array coordinates.  Each base axis d starts
    # from a symbolic per-axis offset named "{vn}_{d}" (emitted separately as
    # precode), and axes covered by v's axis map additionally add the
    # corresponding iteration index, scaled by the view's step when > 1.
    # `u` and `a` are accepted for interface compatibility; only u's base
    # offset is fetched (kept in case the accessor matters), and `a` is unused.
    view_base = shardview._base_offset(v)
    _union_base = shardview._base_offset(u)  # fetched as in original; value unused
    axis_map = shardview._axis_map(v)
    view_steps = shardview._steps(v)
    # One symbolic offset term per base-array dimension.
    terms = [f"{vn}_{d}" for d in range(view_base.shape[0])]
    for iter_dim, base_dim in enumerate(axis_map):
        if base_dim < 0:
            continue  # iteration dim not mapped to any base axis
        expr = terms[base_dim] + "+index[" + str(iter_dim) + "]"
        if view_steps[iter_dim] > 1:
            # Text "*step" binds to index[i] by precedence, as intended.
            expr += "*" + str(view_steps[iter_dim])
        terms[base_dim] = expr
    return "[" + ",".join(terms) + "]"
for (k, v) in live_gids.items():
for (vn, ai) in v[0]:
if ai.manual_idx: continue # skip for manually indexed arrays
vn0 = vn if vn not in args2tran else args2tran[vn][0]
index_text2 = index_text if vn0 not in args2uv else make_index_text( args2tran[vn][1], args2uv[vn0], vn, self.axis_reductions )
for i in range(len(self.codelines)):
self.codelines[i] = self.codelines[i].replace(vn, vn + index_text)
codelines2[i] = codelines2[i].replace(vn, vn0 + index_text2)
times.append(timer())
# compile into function, using unpickled variables (non-ndarrays)
if len(live_gids)==0:
dprint(0,"ramba deferred_op execute: Warning: nothing to do / everything out of scope")

if len(self.codelines)>0 and len(live_gids)>0:
an_array = list(live_gids.items())[0][1][0][0][0]
precode.append( " itershape = " + an_array + ".shape" )
#precode.append( " itershape = " + an_array + ".shape" )
if len(self.axis_reductions) > 0:
axis = self.axis_reductions[0][0]
if isinstance(axis, numbers.Integral):
axis = [axis]
tmp = "".join(["1," if i in axis else f"itershape[{i}]," for i in range(len(self.distribution[0][0]))])
precode.append( " itershape2 = ("+tmp+")" )
tmp = "".join([f"itershape[{i}]," for i in axis])
#tmp = "".join([f"itershape[{i}]," for i in axis])
tmp = "".join([f"itershape[{i}]," if i in axis else "1," for i in range(len(self.distribution[0][0]))])
precode.append( " itershape3 = ("+tmp+")" )
precode.append( " for pindex in numba.pndindex(itershape2):" )
tmp = "".join(["slice(None)," if i in axis else f"pindex[{i}]," for i in range(len(self.distribution[0][0]))])
precode.append( " index = ("+tmp+")" )
precode.append( " for axisindex in np.ndindex(itershape3):" )
#tmp = "".join(["slice(None)," if i in axis else f"pindex[{i}]," for i in range(len(self.distribution[0][0]))])
tmp = "".join([f"axisindex[{i}]," if i in axis else f"pindex[{i}]," for i in range(len(self.distribution[0][0]))])
precode.append( " index = ("+tmp+")" )
else:
precode.append( " for index in numba.pndindex(itershape):" )
code = ( ("" if len(self.precode)==0 else "\n " + "\n ".join(self.precode)) +
"\n" + "\n".join(precode) +
("" if len(self.codelines)==0 else "\n " + "\n ".join(self.codelines)) +
("" if len(self.postcode)==0 else "\n " + "\n ".join(self.postcode)) )
code2 = ( ("" if len(precode2)==0 else "\n " + "\n ".join(precode2)) +
("" if len(self.precode)==0 else "\n " + "\n ".join(self.precode)) +
"\n" + "\n".join(precode) +
("" if len(codelines2)==0 else "\n " + "\n ".join(codelines2)) +
("" if len(self.postcode)==0 else "\n " + "\n ".join(self.postcode)) )
else:
code = ""
code2 = ""
# Use hashlib here so that hash is same every time so that caching works.
code_hash = hashlib.sha256(code.encode('utf-8')).hexdigest()
fname = "ramba_deferred_ops_func_" + str(len(args)) + str(code_hash)
code = "def " + fname + "(global_start,worker_num,num_workers," + ",".join(args) + "):" + code + "\n pass"
code = "def " + fname + "(global_start,itershape,worker_num,num_workers," + ",".join(args) + "):" + code + "\n pass"
code_hash2 = hashlib.sha256(code2.encode('utf-8')).hexdigest()
fname2 = "ramba_deferred_ops_func_" + str(len(args2)) + str(code_hash2)
code2 = "def " + fname2 + "(global_start,itershape,worker_num,num_workers," + ",".join(args2) + "):" + code2 + "\n pass"
if (debug_showcode or ndebug>=2) and is_main_thread:
print("Executing code:\n" + code)
print("with")
Expand All @@ -8139,6 +8271,16 @@ def execute(self):
print (" ",t[0],t[1].shape,g)
for n,s in self.use_other.items():
print (" ",n,pickle.loads(s))
print (" itershape ", self.shape)
print()
print("Locally optimized code:\n" + code2)
print("with")
for g,l in live_gids.items():
t = l[0][0]
print (" ",t[0],shardview._size(args2uv[t[0]]) if t[0] in args2uv else t[1].shape,g)
for n,s in self.use_other.items():
print (" ",n,pickle.loads(s))
print (" itershape ", self.shape)
print()
times.append(timer())
remote_exec_all(
Expand All @@ -8150,6 +8292,8 @@ def execute(self):
self.distribution,
fname,
code,
fname2,
code2,
self.imports,
)
times.append(timer())
Expand Down
Loading

0 comments on commit 4e6d766

Please sign in to comment.