Skip to content

Commit

Permalink
Reving unit tests for parallel process param updating (#1514)
Browse files Browse the repository at this point in the history
  • Loading branch information
john-science authored Dec 4, 2023
1 parent 457ca15 commit 4f55d94
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 198 deletions.
5 changes: 2 additions & 3 deletions armi/reactor/composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2922,9 +2922,8 @@ def syncMpiState(self):
runLog.error("\n".join(msg))
raise

errors = collections.defaultdict(
list
) # key is (comp, paramName) value is conflicting nodes
# key is (comp, paramName) value is conflicting nodes
errors = collections.defaultdict(list)
syncCount = 0
compsPerNode = {len(nodeSyncData) for nodeSyncData in allSyncData}

Expand Down
253 changes: 58 additions & 195 deletions armi/reactor/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests of the Parameters class."""
from distutils.spawn import find_executable
import copy
import traceback
import unittest

from armi import context
from armi.reactor import composites
from armi.reactor import parameters

# determine if this is a parallel run, and MPI is installed
MPI_EXE = None
if find_executable("mpiexec.exe") is not None:
MPI_EXE = "mpiexec.exe"
elif find_executable("mpiexec") is not None:
MPI_EXE = "mpiexec"


class MockComposite:
def __init__(self, name):
Expand Down Expand Up @@ -213,6 +220,13 @@ class Mock(parameters.ParameterCollection):
self.assertEqual("encapsulated", mock.noSetter)

def test_setter(self):
"""Test the Parameter setter() tooling, that signifies if a Parameter has been updated.
.. test:: Tooling that allows a Parameter to signal it needs to be updated across processes.
:id: T_ARMI_PARAM_PARALLEL0
:tests: R_ARMI_PARAM_PARALLEL
"""

class Mock(parameters.ParameterCollection):
pDefs = parameters.ParameterDefinitionCollection()
with pDefs.createBuilder() as pb:
Expand Down Expand Up @@ -241,6 +255,7 @@ def nPlus1(self, value):
print(mock.n)
with self.assertRaises(parameters.ParameterError):
print(mock.nPlus1)

mock.n = 15
self.assertEqual(15, mock.n)
self.assertEqual(16, mock.nPlus1)
Expand All @@ -251,6 +266,13 @@ def nPlus1(self, value):
self.assertTrue(all(pd.assigned for pd in mock.paramDefs))

def test_setterGetterBasics(self):
"""Test the Parameter setter/getter tooling, through the lifecycle of a Parameter being updated.
.. test:: Tooling that allows a Parameter to signal it needs to be updated across processes.
:id: T_ARMI_PARAM_PARALLEL1
:tests: R_ARMI_PARAM_PARALLEL
"""

class Mock(parameters.ParameterCollection):
pDefs = parameters.ParameterDefinitionCollection()
with pDefs.createBuilder() as pb:
Expand Down Expand Up @@ -445,9 +467,7 @@ class MockPC(parameters.ParameterCollection):
self.assertEqual(set(pc.paramDefs.inCategory("bacon")), set([p2, p3]))

def test_parameterCollectionsHave__slots__(self):
"""Make sure something is implemented to prevent accidental creation of
attributes.
"""
"""Tests we prevent accidental creation of attributes."""
self.assertEqual(
set(["_hist", "_backup", "assigned", "_p_serialNum", "serialNum"]),
set(parameters.ParameterCollection._slots),
Expand All @@ -457,12 +477,6 @@ class MockPC(parameters.ParameterCollection):
pass

pc = MockPC()
# No longer protecting against __dict__ access. If someone REALLY wants to
# staple something to a parameter collection with no guarantees of anything,
# that's on them
# with self.assertRaises(AttributeError):
# pc.__dict__["foo"] = 5

with self.assertRaises(AssertionError):
pc.whatever = 22

Expand Down Expand Up @@ -507,230 +521,79 @@ class MockSyncPC(parameters.ParameterCollection):


def makeComp(name):
"""Helper method for MPI sync tests: mock up a Composite with a minimal param collections."""
c = composites.Composite(name)
c.p = MockSyncPC()
return c


class SynchronizationTests:
"""Some unit tests that must be run with mpirun instead of the standard unittest
system.
"""
class SynchronizationTests(unittest.TestCase):
"""Some tests that must be run with mpirun instead of the standard unittest system."""

def setUp(self):
self.r = makeComp("reactor")
self.r.core = makeComp("core")
self.r.add(self.r.core)
for ai in range(context.MPI_SIZE * 4):
for ai in range(context.MPI_SIZE * 3):
a = makeComp("assembly{}".format(ai))
self.r.core.add(a)
for bi in range(10):
for bi in range(3):
a.add(makeComp("block{}-{}".format(ai, bi)))
self.comps = [self.r.core] + self.r.core.getChildren(deep=True)
for pd in MockSyncPC().paramDefs:
pd.assigned = parameters.NEVER

def tearDown(self):
del self.r

def run(self, testNamePrefix="mpitest_"):
with open("mpitest{}.temp".format(context.MPI_RANK), "w") as self.l:
for methodName in sorted(dir(self)):
if methodName.startswith(testNamePrefix):
self.write("{}.{}".format(self.__class__.__name__, methodName))
try:
self.setUp()
getattr(self, methodName)()
except Exception:
self.write("failed, big time")
traceback.print_exc(file=self.l)
self.write("*** printed exception")
try:
self.tearDown()
except: # noqa: bare-except
pass

self.l.write("done.")

def write(self, msg):
self.l.write("{}\n".format(msg))
self.l.flush()

def assertRaises(self, exceptionType):
class ExceptionCatcher:
def __enter__(self):
pass

def __exit__(self, exc_type, exc_value, traceback):
if exc_type is exceptionType:
return True
raise AssertionError(
"Expected {}, but got {}".format(exceptionType, exc_type)
)

return ExceptionCatcher()
self.comps = [self.r.core] + self.r.core.getChildren(deep=True)

def assertEqual(self, expected, actual):
if expected != actual:
raise AssertionError(
"(expected) {} != {} (actual)".format(expected, actual)
)
@unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only")
def test_noConflicts(self):
"""Make sure sync works across processes.
def assertNotEqual(self, expected, actual):
if expected == actual:
raise AssertionError(
"(expected) {} == {} (actual)".format(expected, actual)
)
.. test:: Synchronize a reactor's state across processes.
:id: T_ARMI_CMP_MPI0
:tests: R_ARMI_CMP_MPI
"""
_syncCount = self.r.syncMpiState()

def mpitest_noConflicts(self):
for ci, comp in enumerate(self.comps):
if ci % context.MPI_SIZE == context.MPI_RANK:
comp.p.param1 = (context.MPI_RANK + 1) * 30.0
else:
self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param1)

self.assertEqual(len(self.comps), self.r.syncMpiState())
syncCount = self.r.syncMpiState()
self.assertEqual(len(self.comps), syncCount)

for ci, comp in enumerate(self.comps):
self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param1)

def mpitest_noConflicts_setByString(self):
"""Make sure params set by string also work with sync."""
for ci, comp in enumerate(self.comps):
if ci % context.MPI_SIZE == context.MPI_RANK:
comp.p.param2 = (context.MPI_RANK + 1) * 30.0
else:
self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param2)

self.assertEqual(len(self.comps), self.r.syncMpiState())

for ci, comp in enumerate(self.comps):
self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param2)
@unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only")
def test_withConflicts(self):
"""Test conflicts arise correctly if we force a conflict.
def mpitest_withConflicts(self):
.. test:: Raise errors when there are conflicts across processes.
:id: T_ARMI_CMP_MPI1
:tests: R_ARMI_CMP_MPI
"""
self.r.core.p.param1 = (context.MPI_RANK + 1) * 99.0
with self.assertRaises(ValueError):
self.r.syncMpiState()

def mpitest_withConflictsButSameValue(self):
@unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only")
def test_withConflictsButSameValue(self):
"""Test that conflicts are ignored if the values are the same.
.. test:: Don't raise errors when multiple processes make the same changes.
:id: T_ARMI_CMP_MPI2
:tests: R_ARMI_CMP_MPI
"""
self.r.core.p.param1 = (context.MPI_SIZE + 1) * 99.0
self.r.syncMpiState()
self.assertEqual((context.MPI_SIZE + 1) * 99.0, self.r.core.p.param1)

def mpitest_noConflictsMaintainWithStateRetainer(self):
assigned = []
with self.r.retainState(parameters.inCategory("cat1")):
for ci, comp in enumerate(self.comps):
comp.p.param2 = 99 * ci
if ci % context.MPI_SIZE == context.MPI_RANK:
comp.p.param1 = (context.MPI_RANK + 1) * 30.0
assigned.append(parameters.SINCE_ANYTHING)
else:
self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param1)
assigned.append(parameters.NEVER)

# 1st inside state retainer
self.assertEqual(
True, all(c.p.assigned == parameters.SINCE_ANYTHING for c in self.comps)
)

# confirm outside state retainer
self.assertEqual(assigned, [c.p.assigned for ci, c in enumerate(self.comps)])

# this rank's "assigned" components are not assigned on the workers, and so will
# be updated
self.assertEqual(len(self.comps), self.r.syncMpiState())

for ci, comp in enumerate(self.comps):
self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param1)

def mpitest_conflictsMaintainWithStateRetainer(self):
@unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only")
def test_conflictsMaintainWithStateRetainer(self):
"""Test that the state retainer fails correctly when it should."""
with self.r.retainState(parameters.inCategory("cat2")):
for _, comp in enumerate(self.comps):
comp.p.param2 = 99 * context.MPI_RANK

with self.assertRaises(ValueError):
self.r.syncMpiState()

def mpitest_rxCoeffsProcess(self):
"""This test mimics the process for rxCoeffs when doing distributed doppler."""

def do():
# we will do this over 4 passes (there are 4 * MPI_SIZE assemblies)
for passNum in range(4):
with self.r.retainState(parameters.inCategory("cat2")):
self.r.p.param3 = "hi"
for c in self.comps:
c.p.param1 = (
99 * context.MPI_RANK
) # this will get reset after state retainer
a = self.r.core[passNum * context.MPI_SIZE + context.MPI_RANK]
a.p.param2 = context.MPI_RANK * 20.0
for b in a:
b.p.param2 = context.MPI_RANK * 10.0

for ai, a2 in enumerate(self.r):
if ai % context.MPI_SIZE != context.MPI_RANK:
assert "param2" not in a2.p

self.assertEqual(parameters.SINCE_ANYTHING, param1.assigned)
self.assertEqual(parameters.SINCE_ANYTHING, param2.assigned)
self.assertEqual(parameters.SINCE_ANYTHING, param3.assigned)
self.assertEqual(parameters.SINCE_ANYTHING, a.p.assigned)

self.r.syncMpiState()

self.assertEqual(
parameters.SINCE_ANYTHING
& ~parameters.SINCE_LAST_DISTRIBUTE_STATE,
param1.assigned,
)
self.assertEqual(
parameters.SINCE_ANYTHING
& ~parameters.SINCE_LAST_DISTRIBUTE_STATE,
param2.assigned,
)
self.assertEqual(
parameters.SINCE_ANYTHING
& ~parameters.SINCE_LAST_DISTRIBUTE_STATE,
param3.assigned,
)
self.assertEqual(
parameters.SINCE_ANYTHING
& ~parameters.SINCE_LAST_DISTRIBUTE_STATE,
a.p.assigned,
)

self.assertEqual(parameters.NEVER, param1.assigned)
self.assertEqual(parameters.SINCE_ANYTHING, param2.assigned)
self.assertEqual(parameters.NEVER, param3.assigned)
self.assertEqual(parameters.SINCE_ANYTHING, a.p.assigned)
do_assert(passNum)

param1 = self.r.p.paramDefs["param1"]
param2 = self.r.p.paramDefs["param2"]
param3 = self.r.p.paramDefs["param3"]

def do_assert(passNum):
# ensure all assemblies and blocks set values for param2, but param1 is
# empty
for rank in range(context.MPI_SIZE):
a = self.r.core[passNum * context.MPI_SIZE + rank]
assert "param1" not in a.p
assert "param3" not in a.p
self.assertEqual(rank * 20, a.p.param2)
for b in a:
self.assertEqual(rank * 10, b.p.param2)
assert "param1" not in b.p
assert "param3" not in b.p

if context.MPI_RANK == 0:
with self.r.retainState(parameters.inCategory("cat2")):
context.MPI_COMM.bcast(self.r)
do()
[do_assert(passNum) for passNum in range(4)]
[do_assert(passNum) for passNum in range(4)]
else:
del self.r
self.r = context.MPI_COMM.bcast(None)
do()

0 comments on commit 4f55d94

Please sign in to comment.