diff --git a/src/AMSWorkflow/ams/monitor.py b/src/AMSWorkflow/ams/monitor.py index 903ec3e1..13542bf2 100644 --- a/src/AMSWorkflow/ams/monitor.py +++ b/src/AMSWorkflow/ams/monitor.py @@ -119,10 +119,12 @@ def __str__(self) -> str: def __repr__(self) -> str: return self.__str__() - def lock(self): + @staticmethod + def lock(): AMSMonitor._lock.acquire() - def unlock(self): + @staticmethod + def unlock(): AMSMonitor._lock.release() def __enter__(self): @@ -153,12 +155,12 @@ def info(cls) -> str: @classmethod @property def stats(cls): - return AMSMonitor._stats + return cls._stats @classmethod @property def format_ts(cls): - return AMSMonitor._ts_format + return cls._ts_format @classmethod def convert_ts(cls, ts: str) -> datetime.datetime: @@ -175,6 +177,12 @@ def json(cls, json_output: str): # To avoid partial line at the end of the file fp.write("\n") + @classmethod + def reset(cls): + cls.lock() + cls._stats = cls._manager.dict() + cls.unlock() + def start_monitor(self, *args, **kwargs): self.start_time = time.time() self.internal_ts = datetime.datetime.now().strftime(self._ts_format) @@ -229,7 +237,7 @@ def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str): """ This function update the hashmap containing all the records. """ - self.lock() + AMSMonitor.lock() if class_name not in AMSMonitor._stats: AMSMonitor._stats[class_name] = {} @@ -249,7 +257,7 @@ def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str): temp[func_name][ts][k] = v # This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory) AMSMonitor._stats[class_name] = temp - self.unlock() + AMSMonitor.unlock() def _remove_reserved_keys(self, d: Union[dict, List]) -> dict: for key in self._reserved_keys: diff --git a/tests/AMSWorkflow/test_amsmonitor.py b/tests/AMSWorkflow/test_amsmonitor.py new file mode 100644 index 00000000..e13aa3d3 --- /dev/null +++ b/tests/AMSWorkflow/test_amsmonitor.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import datetime +import json +import os +import time +import unittest + +from ams.monitor import AMSMonitor +from ams.stage import Task + + +class ExampleTask1(Task): + def __init__(self): + self.x = 0 + self.y = 100 + + @AMSMonitor() + def __call__(self): + i = 0 + with AMSMonitor(obj=self, record=["x"], tag="while_loop", accumulate=False): + while 1: + time.sleep(1) + self.x += i + if i == 3: + break + i += 1 + self.y += 100 + + +def read_json(path: str): + with open(path) as f: + d = json.load(f) + return d + + +class TestMonitorTask1(unittest.TestCase): + def setUp(self): + self.task1 = ExampleTask1() + + def test_populating_monitor(self): + AMSMonitor.reset() + self.task1() + + self.assertNotEqual(AMSMonitor.stats.copy(), {}) + self.assertIn("ExampleTask1", AMSMonitor.stats) + self.assertIn("while_loop", AMSMonitor.stats["ExampleTask1"]) + self.assertIn("__call__", AMSMonitor.stats["ExampleTask1"]) + + for ts in AMSMonitor.stats["ExampleTask1"]["__call__"].keys(): + self.assertIsInstance(datetime.datetime.strptime(ts, AMSMonitor.format_ts), datetime.datetime) + self.assertIn("x", AMSMonitor.stats["ExampleTask1"]["__call__"][ts]) + self.assertIn("y", AMSMonitor.stats["ExampleTask1"]["__call__"][ts]) + self.assertIn("amsmonitor_duration", AMSMonitor.stats["ExampleTask1"]["__call__"][ts]) + self.assertEqual(AMSMonitor.stats["ExampleTask1"]["__call__"][ts]["x"], 6) + self.assertEqual(AMSMonitor.stats["ExampleTask1"]["__call__"][ts]["y"], 200) + + for ts in AMSMonitor.stats["ExampleTask1"]["while_loop"].keys(): + self.assertIsInstance(datetime.datetime.strptime(ts, AMSMonitor.format_ts), datetime.datetime) + self.assertIn("x", AMSMonitor.stats["ExampleTask1"]["while_loop"][ts]) + self.assertIn("amsmonitor_duration", AMSMonitor.stats["ExampleTask1"]["while_loop"][ts]) + self.assertEqual(AMSMonitor.stats["ExampleTask1"]["while_loop"][ts]["x"], 6) + + def test_json_output(self): + print(f"test_json_output {AMSMonitor.stats.copy()}") + AMSMonitor.reset() + self.task1() + path = "test_amsmonitor.json" + AMSMonitor.json(path) + self.assertTrue(os.path.isfile(path)) + d = read_json(path) + self.assertEqual(AMSMonitor.stats.copy(), d) + + def tearDown(self): + try: + os.remove("test_amsmonitor.json") + except OSError: + pass + + +if __name__ == "__main__": + unittest.main()