Skip to content

Commit

Permalink
Merge pull request #494 from bknueven/sync_locs
Browse files Browse the repository at this point in the history
add more granular callouts to PH
  • Loading branch information
bknueven authored Feb 28, 2025
2 parents 0a2d4d4 + 79af1ec commit 6d1f13f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
20 changes: 20 additions & 0 deletions mpisppy/cylinders/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,26 @@ def sync(self):
def sync_with_spokes(self):
self.sync()

def sync_bounds(self):
if self.has_outerbound_spokes:
self.receive_outerbounds()
if self.has_innerbound_spokes:
self.receive_innerbounds()
if self.has_bounds_only_spokes:
self.send_boundsout()

def sync_extensions(self):
if self.opt.extensions is not None:
self.opt.extobject.sync_with_spokes()

def sync_nonants(self):
if self.has_nonant_spokes:
self.send_nonants()

def sync_Ws(self):
if self.has_w_spokes:
self.send_ws()

def is_converged(self):
if self.opt.best_bound_obj_val is not None:
self.BestOuterBound = self.OuterBoundUpdate(self.opt.best_bound_obj_val)
Expand Down
18 changes: 14 additions & 4 deletions mpisppy/phbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,10 @@ def _vb(msg):
if self._can_update_best_bound():
self.best_bound_obj_val = self.trivial_bound

if self.spcomm is not None:
if hasattr(self.spcomm, "sync_nonants"):
self.spcomm.sync_nonants()
self.spcomm.sync_extensions()
elif hasattr(self.spcomm, "sync"):
self.spcomm.sync()

if have_extensions:
Expand Down Expand Up @@ -1000,7 +1003,7 @@ def iterk_loop(self):
self.conv = None

max_iterations = int(self.options["PHIterLimit"])
if self.spcomm is not None:
if hasattr(self.spcomm, "is_converged"):
# print a screen trace for iteration 0
if self.spcomm.is_converged():
global_toc("Cylinder convergence", self.cylinder_rank == 0)
Expand All @@ -1020,6 +1023,9 @@ def iterk_loop(self):
self.Update_W(verbose)
#global_toc('Rank: {} - After Update_W'.format(self.cylinder_rank), True)

if hasattr(self.spcomm, "sync_Ws"):
self.spcomm.sync_Ws()

if smoothed:
self.Update_z(verbose)

Expand Down Expand Up @@ -1064,11 +1070,15 @@ def iterk_loop(self):
if have_extensions:
self.extobject.enditer()

if self.spcomm is not None:
self.spcomm.sync()
if hasattr(self.spcomm, "sync_nonants"):
self.spcomm.sync_nonants()
self.spcomm.sync_bounds()
self.spcomm.sync_extensions()
if self.spcomm.is_converged():
global_toc("Cylinder convergence", self.cylinder_rank == 0)
break
elif hasattr(self.spcomm, "sync"):
self.spcomm.sync()

if have_extensions:
self.extobject.enditer_after_sync()
Expand Down

0 comments on commit 6d1f13f

Please sign in to comment.