diff --git a/armi/reactor/composites.py b/armi/reactor/composites.py index de3831104..ead79f514 100644 --- a/armi/reactor/composites.py +++ b/armi/reactor/composites.py @@ -2922,9 +2922,8 @@ def syncMpiState(self): runLog.error("\n".join(msg)) raise - errors = collections.defaultdict( - list - ) # key is (comp, paramName) value is conflicting nodes + # key is (comp, paramName) value is conflicting nodes + errors = collections.defaultdict(list) syncCount = 0 compsPerNode = {len(nodeSyncData) for nodeSyncData in allSyncData} diff --git a/armi/reactor/tests/test_parameters.py b/armi/reactor/tests/test_parameters.py index 2302d4fb5..330171c10 100644 --- a/armi/reactor/tests/test_parameters.py +++ b/armi/reactor/tests/test_parameters.py @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests of the Parameters class.""" +from distutils.spawn import find_executable import copy -import traceback import unittest from armi import context from armi.reactor import composites from armi.reactor import parameters +# determine if this is a parallel run, and MPI is installed +MPI_EXE = None +if find_executable("mpiexec.exe") is not None: + MPI_EXE = "mpiexec.exe" +elif find_executable("mpiexec") is not None: + MPI_EXE = "mpiexec" + class MockComposite: def __init__(self, name): @@ -213,6 +220,13 @@ class Mock(parameters.ParameterCollection): self.assertEqual("encapsulated", mock.noSetter) def test_setter(self): + """Test the Parameter setter() tooling, that signifies if a Parameter has been updated. + + .. test:: Tooling that allows a Parameter to signal it needs to be updated across processes. + :id: T_ARMI_PARAM_PARALLEL0 + :tests: R_ARMI_PARAM_PARALLEL + """ + class Mock(parameters.ParameterCollection): pDefs = parameters.ParameterDefinitionCollection() with pDefs.createBuilder() as pb: @@ -241,6 +255,7 @@ def nPlus1(self, value): print(mock.n) with self.assertRaises(parameters.ParameterError): print(mock.nPlus1) + mock.n = 15 self.assertEqual(15, mock.n) self.assertEqual(16, mock.nPlus1) @@ -251,6 +266,13 @@ def nPlus1(self, value): self.assertTrue(all(pd.assigned for pd in mock.paramDefs)) def test_setterGetterBasics(self): + """Test the Parameter setter/getter tooling, through the lifecycle of a Parameter being updated. + + .. test:: Tooling that allows a Parameter to signal it needs to be updated across processes. + :id: T_ARMI_PARAM_PARALLEL1 + :tests: R_ARMI_PARAM_PARALLEL + """ + class Mock(parameters.ParameterCollection): pDefs = parameters.ParameterDefinitionCollection() with pDefs.createBuilder() as pb: @@ -445,9 +467,7 @@ class MockPC(parameters.ParameterCollection): self.assertEqual(set(pc.paramDefs.inCategory("bacon")), set([p2, p3])) def test_parameterCollectionsHave__slots__(self): - """Make sure something is implemented to prevent accidental creation of - attributes. - """ + """Tests we prevent accidental creation of attributes.""" self.assertEqual( set(["_hist", "_backup", "assigned", "_p_serialNum", "serialNum"]), set(parameters.ParameterCollection._slots), @@ -457,12 +477,6 @@ class MockPC(parameters.ParameterCollection): pass pc = MockPC() - # No longer protecting against __dict__ access. If someone REALLY wants to - # staple something to a parameter collection with no guarantees of anything, - # that's on them - # with self.assertRaises(AttributeError): - # pc.__dict__["foo"] = 5 - with self.assertRaises(AssertionError): pc.whatever = 22 @@ -507,230 +521,79 @@ class MockSyncPC(parameters.ParameterCollection): def makeComp(name): + """Helper method for MPI sync tests: mock up a Composite with a minimal param collections.""" c = composites.Composite(name) c.p = MockSyncPC() return c -class SynchronizationTests: - """Some unit tests that must be run with mpirun instead of the standard unittest - system. - """ +class SynchronizationTests(unittest.TestCase): + """Some tests that must be run with mpirun instead of the standard unittest system.""" def setUp(self): self.r = makeComp("reactor") self.r.core = makeComp("core") self.r.add(self.r.core) - for ai in range(context.MPI_SIZE * 4): + for ai in range(context.MPI_SIZE * 3): a = makeComp("assembly{}".format(ai)) self.r.core.add(a) - for bi in range(10): + for bi in range(3): a.add(makeComp("block{}-{}".format(ai, bi))) - self.comps = [self.r.core] + self.r.core.getChildren(deep=True) - for pd in MockSyncPC().paramDefs: - pd.assigned = parameters.NEVER - - def tearDown(self): - del self.r - - def run(self, testNamePrefix="mpitest_"): - with open("mpitest{}.temp".format(context.MPI_RANK), "w") as self.l: - for methodName in sorted(dir(self)): - if methodName.startswith(testNamePrefix): - self.write("{}.{}".format(self.__class__.__name__, methodName)) - try: - self.setUp() - getattr(self, methodName)() - except Exception: - self.write("failed, big time") - traceback.print_exc(file=self.l) - self.write("*** printed exception") - try: - self.tearDown() - except: # noqa: bare-except - pass - - self.l.write("done.") - - def write(self, msg): - self.l.write("{}\n".format(msg)) - self.l.flush() - - def assertRaises(self, exceptionType): - class ExceptionCatcher: - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_value, traceback): - if exc_type is exceptionType: - return True - raise AssertionError( - "Expected {}, but got {}".format(exceptionType, exc_type) - ) - return ExceptionCatcher() + self.comps = [self.r.core] + self.r.core.getChildren(deep=True) - def assertEqual(self, expected, actual): - if expected != actual: - raise AssertionError( - "(expected) {} != {} (actual)".format(expected, actual) - ) + @unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only") + def test_noConflicts(self): + """Make sure sync works across processes. - def assertNotEqual(self, expected, actual): - if expected == actual: - raise AssertionError( - "(expected) {} == {} (actual)".format(expected, actual) - ) + .. test:: Synchronize a reactor's state across processes. + :id: T_ARMI_CMP_MPI0 + :tests: R_ARMI_CMP_MPI + """ + _syncCount = self.r.syncMpiState() - def mpitest_noConflicts(self): for ci, comp in enumerate(self.comps): if ci % context.MPI_SIZE == context.MPI_RANK: comp.p.param1 = (context.MPI_RANK + 1) * 30.0 else: self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param1) - self.assertEqual(len(self.comps), self.r.syncMpiState()) + syncCount = self.r.syncMpiState() + self.assertEqual(len(self.comps), syncCount) for ci, comp in enumerate(self.comps): self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param1) - def mpitest_noConflicts_setByString(self): - """Make sure params set by string also work with sync.""" - for ci, comp in enumerate(self.comps): - if ci % context.MPI_SIZE == context.MPI_RANK: - comp.p.param2 = (context.MPI_RANK + 1) * 30.0 - else: - self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param2) - - self.assertEqual(len(self.comps), self.r.syncMpiState()) - - for ci, comp in enumerate(self.comps): - self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param2) + @unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only") + def test_withConflicts(self): + """Test conflicts arise correctly if we force a conflict. - def mpitest_withConflicts(self): + .. test:: Raise errors when there are conflicts across processes. + :id: T_ARMI_CMP_MPI1 + :tests: R_ARMI_CMP_MPI + """ self.r.core.p.param1 = (context.MPI_RANK + 1) * 99.0 with self.assertRaises(ValueError): self.r.syncMpiState() - def mpitest_withConflictsButSameValue(self): + @unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only") + def test_withConflictsButSameValue(self): + """Test that conflicts are ignored if the values are the same. + + .. test:: Don't raise errors when multiple processes make the same changes. + :id: T_ARMI_CMP_MPI2 + :tests: R_ARMI_CMP_MPI + """ self.r.core.p.param1 = (context.MPI_SIZE + 1) * 99.0 self.r.syncMpiState() self.assertEqual((context.MPI_SIZE + 1) * 99.0, self.r.core.p.param1) - def mpitest_noConflictsMaintainWithStateRetainer(self): - assigned = [] - with self.r.retainState(parameters.inCategory("cat1")): - for ci, comp in enumerate(self.comps): - comp.p.param2 = 99 * ci - if ci % context.MPI_SIZE == context.MPI_RANK: - comp.p.param1 = (context.MPI_RANK + 1) * 30.0 - assigned.append(parameters.SINCE_ANYTHING) - else: - self.assertNotEqual((context.MPI_RANK + 1) * 30.0, comp.p.param1) - assigned.append(parameters.NEVER) - - # 1st inside state retainer - self.assertEqual( - True, all(c.p.assigned == parameters.SINCE_ANYTHING for c in self.comps) - ) - - # confirm outside state retainer - self.assertEqual(assigned, [c.p.assigned for ci, c in enumerate(self.comps)]) - - # this rank's "assigned" components are not assigned on the workers, and so will - # be updated - self.assertEqual(len(self.comps), self.r.syncMpiState()) - - for ci, comp in enumerate(self.comps): - self.assertEqual((ci % context.MPI_SIZE + 1) * 30.0, comp.p.param1) - - def mpitest_conflictsMaintainWithStateRetainer(self): + @unittest.skipIf(context.MPI_SIZE <= 1 or MPI_EXE is None, "Parallel test only") + def test_conflictsMaintainWithStateRetainer(self): + """Test that the state retainer fails correctly when it should.""" with self.r.retainState(parameters.inCategory("cat2")): for _, comp in enumerate(self.comps): comp.p.param2 = 99 * context.MPI_RANK with self.assertRaises(ValueError): self.r.syncMpiState() - - def mpitest_rxCoeffsProcess(self): - """This test mimics the process for rxCoeffs when doing distributed doppler.""" - - def do(): - # we will do this over 4 passes (there are 4 * MPI_SIZE assemblies) - for passNum in range(4): - with self.r.retainState(parameters.inCategory("cat2")): - self.r.p.param3 = "hi" - for c in self.comps: - c.p.param1 = ( - 99 * context.MPI_RANK - ) # this will get reset after state retainer - a = self.r.core[passNum * context.MPI_SIZE + context.MPI_RANK] - a.p.param2 = context.MPI_RANK * 20.0 - for b in a: - b.p.param2 = context.MPI_RANK * 10.0 - - for ai, a2 in enumerate(self.r): - if ai % context.MPI_SIZE != context.MPI_RANK: - assert "param2" not in a2.p - - self.assertEqual(parameters.SINCE_ANYTHING, param1.assigned) - self.assertEqual(parameters.SINCE_ANYTHING, param2.assigned) - self.assertEqual(parameters.SINCE_ANYTHING, param3.assigned) - self.assertEqual(parameters.SINCE_ANYTHING, a.p.assigned) - - self.r.syncMpiState() - - self.assertEqual( - parameters.SINCE_ANYTHING - & ~parameters.SINCE_LAST_DISTRIBUTE_STATE, - param1.assigned, - ) - self.assertEqual( - parameters.SINCE_ANYTHING - & ~parameters.SINCE_LAST_DISTRIBUTE_STATE, - param2.assigned, - ) - self.assertEqual( - parameters.SINCE_ANYTHING - & ~parameters.SINCE_LAST_DISTRIBUTE_STATE, - param3.assigned, - ) - self.assertEqual( - parameters.SINCE_ANYTHING - & ~parameters.SINCE_LAST_DISTRIBUTE_STATE, - a.p.assigned, - ) - - self.assertEqual(parameters.NEVER, param1.assigned) - self.assertEqual(parameters.SINCE_ANYTHING, param2.assigned) - self.assertEqual(parameters.NEVER, param3.assigned) - self.assertEqual(parameters.SINCE_ANYTHING, a.p.assigned) - do_assert(passNum) - - param1 = self.r.p.paramDefs["param1"] - param2 = self.r.p.paramDefs["param2"] - param3 = self.r.p.paramDefs["param3"] - - def do_assert(passNum): - # ensure all assemblies and blocks set values for param2, but param1 is - # empty - for rank in range(context.MPI_SIZE): - a = self.r.core[passNum * context.MPI_SIZE + rank] - assert "param1" not in a.p - assert "param3" not in a.p - self.assertEqual(rank * 20, a.p.param2) - for b in a: - self.assertEqual(rank * 10, b.p.param2) - assert "param1" not in b.p - assert "param3" not in b.p - - if context.MPI_RANK == 0: - with self.r.retainState(parameters.inCategory("cat2")): - context.MPI_COMM.bcast(self.r) - do() - [do_assert(passNum) for passNum in range(4)] - [do_assert(passNum) for passNum in range(4)] - else: - del self.r - self.r = context.MPI_COMM.bcast(None) - do()