Skip to content

Commit

Permalink
Added basic tests for AMSMonitor (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Loic Pottier <pottier1@llnl.gov>
  • Loading branch information
lpottier authored Apr 9, 2024
1 parent 41687c2 commit efbbd22
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/AMSWorkflow/ams/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.__str__()

def lock(self):
@staticmethod
def lock():
AMSMonitor._lock.acquire()

def unlock(self):
@staticmethod
def unlock():
AMSMonitor._lock.release()

def __enter__(self):
Expand Down Expand Up @@ -153,12 +155,12 @@ def info(cls) -> str:
@classmethod
@property
def stats(cls):
return AMSMonitor._stats
return cls._stats

@classmethod
@property
def format_ts(cls):
return AMSMonitor._ts_format
return cls._ts_format

@classmethod
def convert_ts(cls, ts: str) -> datetime.datetime:
Expand All @@ -175,6 +177,12 @@ def json(cls, json_output: str):
# To avoid partial line at the end of the file
fp.write("\n")

@classmethod
def reset(cls):
cls.lock()
cls._stats = cls._manager.dict()
cls.unlock()

def start_monitor(self, *args, **kwargs):
self.start_time = time.time()
self.internal_ts = datetime.datetime.now().strftime(self._ts_format)
Expand Down Expand Up @@ -229,7 +237,7 @@ def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
"""
This function update the hashmap containing all the records.
"""
self.lock()
AMSMonitor.lock()
if class_name not in AMSMonitor._stats:
AMSMonitor._stats[class_name] = {}

Expand All @@ -249,7 +257,7 @@ def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
temp[func_name][ts][k] = v
# This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
AMSMonitor._stats[class_name] = temp
self.unlock()
AMSMonitor.unlock()

def _remove_reserved_keys(self, d: Union[dict, List]) -> dict:
for key in self._reserved_keys:
Expand Down
87 changes: 87 additions & 0 deletions tests/AMSWorkflow/test_amsmonitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#!/usr/bin/env python3
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import datetime
import json
import os
import time
import unittest

from ams.monitor import AMSMonitor
from ams.stage import Task


class ExampleTask1(Task):
def __init__(self):
self.x = 0
self.y = 100

@AMSMonitor()
def __call__(self):
i = 0
with AMSMonitor(obj=self, record=["x"], tag="while_loop", accumulate=False):
while 1:
time.sleep(1)
self.x += i
if i == 3:
break
i += 1
self.y += 100


def read_json(path: str):
with open(path) as f:
d = json.load(f)
return d


class TestMonitorTask1(unittest.TestCase):
def setUp(self):
self.task1 = ExampleTask1()

def test_populating_monitor(self):
AMSMonitor.reset()
self.task1()

self.assertNotEqual(AMSMonitor.stats.copy(), {})
self.assertIn("ExampleTask1", AMSMonitor.stats)
self.assertIn("while_loop", AMSMonitor.stats["ExampleTask1"])
self.assertIn("__call__", AMSMonitor.stats["ExampleTask1"])

for ts in AMSMonitor.stats["ExampleTask1"]["__call__"].keys():
self.assertIsInstance(datetime.datetime.strptime(ts, AMSMonitor.format_ts), datetime.datetime)
self.assertIn("x", AMSMonitor.stats["ExampleTask1"]["__call__"][ts])
self.assertIn("y", AMSMonitor.stats["ExampleTask1"]["__call__"][ts])
self.assertIn("amsmonitor_duration", AMSMonitor.stats["ExampleTask1"]["__call__"][ts])
self.assertEqual(AMSMonitor.stats["ExampleTask1"]["__call__"][ts]["x"], 6)
self.assertEqual(AMSMonitor.stats["ExampleTask1"]["__call__"][ts]["y"], 200)

for ts in AMSMonitor.stats["ExampleTask1"]["while_loop"].keys():
self.assertIsInstance(datetime.datetime.strptime(ts, AMSMonitor.format_ts), datetime.datetime)
self.assertIn("x", AMSMonitor.stats["ExampleTask1"]["while_loop"][ts])
self.assertIn("amsmonitor_duration", AMSMonitor.stats["ExampleTask1"]["while_loop"][ts])
self.assertEqual(AMSMonitor.stats["ExampleTask1"]["while_loop"][ts]["x"], 6)

def test_json_output(self):
print(f"test_json_output {AMSMonitor.stats.copy()}")
AMSMonitor.reset()
self.task1()
path = "test_amsmonitor.json"
AMSMonitor.json(path)
self.assertTrue(os.path.isfile(path))
d = read_json(path)
self.assertEqual(AMSMonitor.stats.copy(), d)

def tearDown(self):
try:
os.remove("test_amsmonitor.json")
except OSError:
pass


if __name__ == "__main__":
unittest.main()

0 comments on commit efbbd22

Please sign in to comment.