Skip to content

Commit

Permalink
Merge pull request #471 from bknueven/xhatib-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
bknueven authored Jan 25, 2025
2 parents 3ed7360 + e560a4c commit 7052cbc
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 145 deletions.
54 changes: 54 additions & 0 deletions mpisppy/cylinders/xhatbase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
###############################################################################
# mpi-sppy: MPI-based Stochastic Programming in PYthon
#
# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for
# Sustainable Energy, LLC, The Regents of the University of California, et al.
# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for
# full copyright and license information.
###############################################################################

import abc
import mpisppy.cylinders.spoke as spoke
from mpisppy.utils.xhat_eval import Xhat_Eval

class XhatInnerBoundBase(spoke.InnerBoundNonantSpoke):

@abc.abstractmethod
def xhat_extension(self):
raise NotImplementedError


def xhat_prep(self):
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

## for later
self.verbose = self.opt.options["verbose"] # typing aid

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError(f"{self.__class__.__name__} must be used with Xhat_Eval.")

xhatter = self.xhat_extension()

### begin iter0 stuff
xhatter.pre_iter0()
if self.opt.extensions is not None:
self.opt.extobject.pre_iter0() # for an extension
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\
f"(E1_tolerance is {self.opt.E1_tolerance})")
### end iter0 stuff (but note: no need for iter 0 solves in an xhatter)

xhatter.post_iter0()
if self.opt.extensions is not None:
self.opt.extobject.post_iter0() # for an extension

self.opt._save_nonants() # make the cache

return xhatter
38 changes: 5 additions & 33 deletions mpisppy/cylinders/xhatlooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
# full copyright and license information.
###############################################################################
# updated April 2020
import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatlooper import XhatLooper
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase
import logging
import mpisppy.log

Expand All @@ -20,44 +19,17 @@
logger = logging.getLogger("mpisppy.cylinders.xhatlooper_bounder")


class XhatLooperInnerBound(spoke.InnerBoundNonantSpoke):
class XhatLooperInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'X'

def xhatlooper_prep(self):
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatLooper(self.opt)

### begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
if self.opt.cylinder_rank == 0:
print("ERROR")
print("Total probability of scenarios was ", self.opt.E1)
print("E1_tolerance = ", self.opt.E1_tolerance)
quit()
### end iter0 stuff

xhatter.post_iter0()
self.opt._save_nonants() # make the cache

return xhatter
def xhat_extension(self):
return XhatLooper(self.opt)

def main(self):
logger.debug(f"Entering main on xhatlooper spoke rank {self.global_rank}")

xhatter = self.xhatlooper_prep()
xhatter = self.xhat_prep()

scen_limit = self.opt.options['xhat_looper_options']['scen_limit']

Expand Down
42 changes: 8 additions & 34 deletions mpisppy/cylinders/xhatshufflelooper_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,52 +9,25 @@
import logging
import random
import mpisppy.log
import mpisppy.cylinders.spoke as spoke

from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.extensions.xhatbase import XhatBase
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

# Could also pass, e.g., sys.stdout instead of a filename
mpisppy.log.setup_logger("mpisppy.cylinders.xhatshufflelooper_bounder",
"xhatclp.log",
level=logging.CRITICAL)
logger = logging.getLogger("mpisppy.cylinders.xhatshufflelooper_bounder")

class XhatShuffleInnerBound(spoke.InnerBoundNonantSpoke):
class XhatShuffleInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'X'

def xhatbase_prep(self):
def xhat_extension(self):
return XhatBase(self.opt)

if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

## for later
self.verbose = self.opt.options["verbose"] # typing aid
self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"]

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatBase(self.opt)
self.xhatter = xhatter

### begin iter0 stuff
xhatter.pre_iter0() # for an extension
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if abs(1 - self.opt.E1) > self.opt.E1_tolerance:
raise ValueError(f"Total probability of scenarios was {self.opt.E1} "+\
f"(E1_tolerance is {self.opt.E1_tolerance})")
### end iter0 stuff (but note: no need for iter 0 solves in an xhatter)

xhatter.post_iter0()

self.opt._save_nonants() # make the cache
def xhat_prep(self):
self.xhatter = super().xhat_prep()

## option drive this? (could be dangerous)
self.random_seed = 42
Expand Down Expand Up @@ -91,7 +64,7 @@ def _vb(msg):
def main(self):
logger.debug(f"Entering main on xhatshuffle spoke rank {self.global_rank}")

self.xhatbase_prep()
self.xhat_prep()
if "reverse" in self.opt.options["xhat_looper_options"]:
self.reverse = self.opt.options["xhat_looper_options"]["reverse"]
else:
Expand All @@ -100,6 +73,7 @@ def main(self):
self.iter_step = self.opt.options["xhat_looper_options"]["iter_step"]
else:
self.iter_step = None
self.solver_options = self.opt.options["xhat_looper_options"]["xhat_solver_options"]

# give all ranks the same seed
self.random_stream.seed(self.random_seed)
Expand Down
46 changes: 5 additions & 41 deletions mpisppy/cylinders/xhatspecific_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
# udpated April 20
# specific xhat supplied (copied from xhatlooper_bounder by DLW, Dec 2019)

import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatspecific import XhatSpecific
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

import mpisppy.MPI as mpi
import logging
Expand All @@ -22,47 +21,12 @@


############################################################################
class XhatSpecificInnerBound(spoke.InnerBoundNonantSpoke):
class XhatSpecificInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'S'

def ib_prep(self):
"""
Set up the objects needed for bounding.
Returns:
xhatter (xhatspecific object): Constructed by a call to Prep
"""
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatShuffleInnerBound must be used with Xhat_Eval.")

xhatter = XhatSpecific(self.opt)
# somehow deal with the prox option .... TBD .... important for aph APH

# begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if (abs(1 - self.opt.E1) > self.opt.E1_tolerance):
if self.opt.cylinder_rank == 0:
print("ERROR")
print("Total probability of scenarios was ", self.opt.E1)
print("E1_tolerance = ", self.opt.E1_tolerance)
quit()

### end iter0 stuff

xhatter.post_iter0()
self.opt._save_nonants() # make the cache

return xhatter
def xhat_extension(self):
return XhatSpecific(self.opt)

def main(self):
"""
Expand All @@ -76,7 +40,7 @@ def main(self):
xhat_scenario_dict = self.opt.options["xhat_specific_options"]\
["xhat_scenario_dict"]

xhatter = self.ib_prep()
xhatter = self.xhat_prep()

ib_iter = 1 # ib is for inner bound
while (not self.got_kill_signal()):
Expand Down
38 changes: 8 additions & 30 deletions mpisppy/cylinders/xhatxbar_bounder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
# xbar from xhat (copied from xhat specific, DLW Feb 2023)

import pyomo.environ as pyo
import mpisppy.cylinders.spoke as spoke
from mpisppy.extensions.xhatxbar import XhatXbar
from mpisppy.utils.xhat_eval import Xhat_Eval
from mpisppy.cylinders.xhatbase import XhatInnerBoundBase

import mpisppy.MPI as mpi
import logging
Expand All @@ -34,43 +33,22 @@ def _attach_xbars(opt):


############################################################################
class XhatXbarInnerBound(spoke.InnerBoundNonantSpoke):
class XhatXbarInnerBound(XhatInnerBoundBase):

converger_spoke_char = 'B'

def ib_prep(self):
def xhat_extension(self):
return XhatXbar(self.opt)

def xhat_prep(self):
"""
Set up the objects needed for bounding.
Returns:
xhatter (xhatxbar object): Constructed by a call to Prep
"""
if "bundles_per_rank" in self.opt.options\
and self.opt.options["bundles_per_rank"] != 0:
raise RuntimeError("xhat spokes cannot have bundles (yet)")

if not isinstance(self.opt, Xhat_Eval):
raise RuntimeError("XhatXbarInnerBound must be used with Xhat_Eval.")

xhatter = XhatXbar(self.opt)
# somehow deal with the prox option .... TBD .... important for aph APH

# begin iter0 stuff
xhatter.pre_iter0()
self.opt._save_original_nonants()

self.opt._lazy_create_solvers() # no iter0 loop, but we need the solvers

self.opt._update_E1()
if (abs(1 - self.opt.E1) > self.opt.E1_tolerance):
raise RuntimeError(f"Total probability of scenarios was {self.E1}; E1_tolerance = ", self.E1_tolerance)

### end iter0 stuff

xhatter.post_iter0()
xhatter = super().xhat_prep()
_attach_xbars(self.opt)
self.opt._save_nonants() # make the cache

return xhatter

def main(self):
Expand All @@ -81,7 +59,7 @@ def main(self):
dtm = logging.getLogger(f'dtm{global_rank}')
logging.debug("Enter xhatxbar main on rank {}".format(global_rank))

xhatter = self.ib_prep()
xhatter = self.xhat_prep()

ib_iter = 1 # ib is for inner bound
while (not self.got_kill_signal()):
Expand Down
Loading

0 comments on commit 7052cbc

Please sign in to comment.