diff --git a/src/RepoAuditor/Executor.py b/src/RepoAuditor/Executor.py new file mode 100644 index 0000000..190d95d --- /dev/null +++ b/src/RepoAuditor/Executor.py @@ -0,0 +1,203 @@ +# ------------------------------------------------------------------------------- +# | | +# | Copyright (c) 2024 Scientific Software Engineering Center at Georgia Tech | +# | Distributed under the MIT License. | +# | | +# ------------------------------------------------------------------------------- +"""Executes one or more Modules""" + +import sys + +from typing import Any, Callable, Optional + +from dbrownell_Common import ExecuteTasks # type: ignore[import-untyped] +from dbrownell_Common.InflectEx import inflect # type: ignore[import-untyped] +from dbrownell_Common.Streams.Capabilities import Capabilities # type: ignore[import-untyped] +from dbrownell_Common.Streams.DoneManager import DoneManager # type: ignore[import-untyped] +from rich.progress import Progress, TimeElapsedColumn + +from .Module import EvaluateResult, ExecutionStyle, Module + + +# ---------------------------------------------------------------------- +def Execute( + dm: DoneManager, + modules: list[Module], + *, + warnings_as_errors_module_names: Optional[set[str]] = None, + ignore_warnings_module_names: Optional[set[str]] = None, + max_num_threads: Optional[int] = None, +) -> None: + warnings_as_errors_module_names = warnings_as_errors_module_names or set() + ignore_warnings_module_names = ignore_warnings_module_names or set() + + with dm.Nested("Processing {}...".format(inflect.no("module", len(modules)))) as modules_dm: + parallel: list[tuple[int, Module]] = [] + sequential: list[tuple[int, Module]] = [] + + for index, module in enumerate(modules): + if module.style == ExecutionStyle.Parallel: + parallel.append((index, module)) + elif module.style == ExecutionStyle.Sequential: + sequential.append((index, module)) + else: + assert False, module.style # pragma: no cover + + # Calculate the results + + # ---------------------------------------------------------------------- + def CreateStatusString( + num_completed: int, # pylint: disable=unused-argument + num_success: int, + num_error: int, + num_warning: int, + num_does_not_apply: int, + ) -> str: + return f"✅: {num_success} ❌: {num_error} ⚠️: {num_warning} 🚫: {num_does_not_apply}" + + # ---------------------------------------------------------------------- + def CalcResultInfo( + module: Module, + eval_infos: list[list[Module.EvaluateInfo]], + ) -> tuple[int, str]: + return_code = 0 + + for eval_info_items in eval_infos: + for eval_info in eval_info_items: + result = eval_info.result + if result == EvaluateResult.Warning: + if module.name in warnings_as_errors_module_names: + result = EvaluateResult.Error + elif module.name in ignore_warnings_module_names: + continue + + if result == EvaluateResult.Error: + return -1, "errors were encountered" + elif result == EvaluateResult.Warning: + return_code = 1 + + return return_code, "" if return_code == 0 else "warnings were encountered" + + # ---------------------------------------------------------------------- + + results: list[Optional[list[list[Module.EvaluateInfo]]]] = [None] * len(modules) + + if parallel: + # ---------------------------------------------------------------------- + def Prepare( + context: Any, + on_simple_status_func: Callable[[str], None], # pylint: disable=unused-argument + ) -> tuple[int, ExecuteTasks.TransformTasksExTypes.TransformFuncType]: + module = context + del context + + # ---------------------------------------------------------------------- + def Transform( + status: ExecuteTasks.Status, + ) -> ExecuteTasks.TransformResultComplete: + + # ---------------------------------------------------------------------- + def OnStatus(num_completed: int, *args, **kwargs): + status.OnProgress( + num_completed, CreateStatusString(num_completed, *args, **kwargs) + ) + + # ---------------------------------------------------------------------- + + result: list[list[Module.EvaluateInfo]] = module.Evaluate( + OnStatus, + max_num_threads=max_num_threads, + ) + + result_code, result_status = CalcResultInfo(module, result) + + return ExecuteTasks.TransformResultComplete(result, result_code, result_status) + + # ---------------------------------------------------------------------- + + return module.GetNumRequirements(), Transform + + # ---------------------------------------------------------------------- + + for (results_index, _), result in zip( + parallel, + ExecuteTasks.TransformTasksEx( + modules_dm, + "Processing parallel modules...", + [ExecuteTasks.TaskData(module.name, module) for _, module in parallel], + Prepare, + max_num_threads=max_num_threads, + ), + ): + assert results[results_index] is None + assert isinstance(result, list), result + + results[results_index] = result + + for index, (results_index, module) in enumerate(sequential): + with modules_dm.Nested( + "Processing '{}' ({} of {})...".format( + module.name, + index + 1 + len(parallel), + len(modules), + ), + ) as this_module_dm: + with this_module_dm.YieldStdout() as stdout_context: + stdout_context.persist_content = False + + # Technically speaking, it would be more correct to use `stdout_context.stream` here + # rather than referencing `sys.stdout` directly, but it is really hard to work with mocked + # stream as mocks will create mocks for everything called on the mock. Use sys.stdout + # directly to avoid that particular problem. + from unittest.mock import Mock, MagicMock + + assert stdout_context.stream is sys.stdout or isinstance( + stdout_context.stream, (Mock, MagicMock) + ), stdout_context.stream + + with Progress( + *Progress.get_default_columns(), + TimeElapsedColumn(), + "{task.fields[status]}", + console=Capabilities.Get(sys.stdout).CreateRichConsole(sys.stdout), + transient=True, + refresh_per_second=10, + ) as progress_bar: + progress_bar_task_id = progress_bar.add_task( + stdout_context.line_prefix, + status="", + total=module.GetNumRequirements(), + visible=True, + ) + + # ---------------------------------------------------------------------- + def OnStatus( + num_completed: int, + num_success: int, + num_error: int, + num_warning: int, + num_does_not_apply: int, + ) -> None: + progress_bar.update( + progress_bar_task_id, + completed=num_completed, + status=CreateStatusString( + num_completed, + num_success, + num_error, + num_warning, + num_does_not_apply, + ), + ) + + # ---------------------------------------------------------------------- + + this_results: list[list[Module.EvaluateInfo]] = module.Evaluate( + OnStatus, + max_num_threads=max_num_threads, + ) + + assert results[results_index] is None + results[results_index] = this_results + + this_module_dm.result = CalcResultInfo(module, this_results)[0] diff --git a/tests/Executor_UnitTest.py b/tests/Executor_UnitTest.py new file mode 100644 index 0000000..1801625 --- /dev/null +++ b/tests/Executor_UnitTest.py @@ -0,0 +1,188 @@ +# ------------------------------------------------------------------------------- +# | | +# | Copyright (c) 2024 Scientific Software Engineering Center at Georgia Tech | +# | Distributed under the MIT License. | +# | | +# ------------------------------------------------------------------------------- +"""Unit test for Executor.py""" + +import sys +import time + +from dataclasses import dataclass + +import pytest + +from dbrownell_Common.Types import override + +from RepoAuditor.Executor import * +from RepoAuditor.Module import * +from RepoAuditor.Requirement import * + + +# ---------------------------------------------------------------------- +@dataclass(frozen=True) +class MyModule(Module): + @override + def _GetData(self) -> Optional[dict[str, Any]]: + return {"module_name": self.name} + + +# ---------------------------------------------------------------------- +@dataclass(frozen=True) +class MyQuery(Query): + @override + def GetData( + self, + module_data: dict[str, Any], + ) -> Optional[dict[str, Any]]: + module_data["query_name"] = self.name + return module_data + + +# ---------------------------------------------------------------------- +@dataclass(frozen=True) +class MyRequirement(Requirement): + result: EvaluateResult + + # ---------------------------------------------------------------------- + # ---------------------------------------------------------------------- + # ---------------------------------------------------------------------- + @override + def _EvaluateImpl( + self, + query_data: dict[str, Any], + ) -> tuple[EvaluateResult, Optional[str]]: + # Introduce a delay so that we can see things happening + time.sleep(0.1) + + return self.result, None + + +# ---------------------------------------------------------------------- +@pytest.mark.parametrize("single_threaded", [False, True]) +def test_Successful(single_threaded): + modules: list[Module] = [] + + # ---------------------------------------------------------------------- + def GetExecutionStyle(index: int) -> ExecutionStyle: + return ExecutionStyle.Parallel if index % 2 == 0 else ExecutionStyle.Sequential + + # ---------------------------------------------------------------------- + + for module_index in range(5): + queries: list[Query] = [] + + for query_index in range(4): + requirements: list[Requirement] = [] + + for requirement_index in range(5): + requirements.append( + MyRequirement( + f"Requirement-{module_index}-{query_index}-{requirement_index}", + "", + GetExecutionStyle(requirement_index), + "", + "", + EvaluateResult.Success, + ), + ) + + queries.append( + MyQuery( + f"Query-{module_index}-{query_index}", + "", + GetExecutionStyle(query_index), + requirements, + ), + ) + + modules.append( + MyModule( + f"Module-{module_index}", + "", + GetExecutionStyle(module_index), + queries, + ), + ) + + with DoneManager.Create(sys.stdout, "", line_prefix="") as dm: + Execute( + dm, + modules, + max_num_threads=1 if single_threaded else None, + ) + + assert dm.result == 0 + + +# ---------------------------------------------------------------------- +@pytest.mark.parametrize( + "data", + [ + (EvaluateResult.Error, -1, False, False), + (EvaluateResult.Warning, 1, False, False), + (EvaluateResult.Warning, -1, True, False), + (EvaluateResult.Warning, 0, False, True), + ], +) +def test_NotSuccess(data): + test_result, expected_result, warnings_as_errors, ignore_warnings = data + + modules: list[Module] = [] + + # ---------------------------------------------------------------------- + def GetExecutionStyle(index: int) -> ExecutionStyle: + return ExecutionStyle.Parallel if index % 2 == 0 else ExecutionStyle.Sequential + + # ---------------------------------------------------------------------- + + for module_index in range(5): + queries: list[Query] = [] + + for query_index in range(4): + requirements: list[Requirement] = [] + + for requirement_index in range(5): + requirements.append( + MyRequirement( + f"Requirement-{module_index}-{query_index}-{requirement_index}", + "", + GetExecutionStyle(requirement_index), + "", + "", + (test_result if requirement_index % 3 == 0 else EvaluateResult.Success), + ), + ) + + queries.append( + MyQuery( + f"Query-{module_index}-{query_index}", + "", + GetExecutionStyle(query_index), + requirements, + ), + ) + + modules.append( + MyModule( + f"Module-{module_index}", + "", + GetExecutionStyle(module_index), + queries, + ), + ) + + with DoneManager.Create(sys.stdout, "", line_prefix="") as dm: + Execute( + dm, + modules, + warnings_as_errors_module_names=( + set() if not warnings_as_errors else {module.name for module in modules} + ), + ignore_warnings_module_names=( + set() if not ignore_warnings else {module.name for module in modules} + ), + ) + + assert dm.result == expected_result