Showing 2 changed files with 391 additions and 0 deletions.
@@ -0,0 +1,203 @@
# -------------------------------------------------------------------------------
# |                                                                               |
# |  Copyright (c) 2024 Scientific Software Engineering Center at Georgia Tech    |
# |  Distributed under the MIT License.                                           |
# |                                                                               |
# -------------------------------------------------------------------------------
"""Executes one or more Modules"""

import sys

from typing import Any, Callable, Optional

from dbrownell_Common import ExecuteTasks  # type: ignore[import-untyped]
from dbrownell_Common.InflectEx import inflect  # type: ignore[import-untyped]
from dbrownell_Common.Streams.Capabilities import Capabilities  # type: ignore[import-untyped]
from dbrownell_Common.Streams.DoneManager import DoneManager  # type: ignore[import-untyped]
from rich.progress import Progress, TimeElapsedColumn

from .Module import EvaluateResult, ExecutionStyle, Module


# ----------------------------------------------------------------------
def Execute(
    dm: DoneManager,
    modules: list[Module],
    *,
    warnings_as_errors_module_names: Optional[set[str]] = None,
    ignore_warnings_module_names: Optional[set[str]] = None,
    max_num_threads: Optional[int] = None,
) -> None:
    warnings_as_errors_module_names = warnings_as_errors_module_names or set()
    ignore_warnings_module_names = ignore_warnings_module_names or set()

    with dm.Nested("Processing {}...".format(inflect.no("module", len(modules)))) as modules_dm:
        parallel: list[tuple[int, Module]] = []
        sequential: list[tuple[int, Module]] = []

        for index, module in enumerate(modules):
            if module.style == ExecutionStyle.Parallel:
                parallel.append((index, module))
            elif module.style == ExecutionStyle.Sequential:
                sequential.append((index, module))
            else:
                assert False, module.style  # pragma: no cover

        # Calculate the results

        # ----------------------------------------------------------------------
        def CreateStatusString(
            num_completed: int,  # pylint: disable=unused-argument
            num_success: int,
            num_error: int,
            num_warning: int,
            num_does_not_apply: int,
        ) -> str:
            return f"✅: {num_success} ❌: {num_error} ⚠️: {num_warning} 🚫: {num_does_not_apply}"

        # ----------------------------------------------------------------------
        def CalcResultInfo(
            module: Module,
            eval_infos: list[list[Module.EvaluateInfo]],
        ) -> tuple[int, str]:
            return_code = 0

            for eval_info_items in eval_infos:
                for eval_info in eval_info_items:
                    result = eval_info.result
                    if result == EvaluateResult.Warning:
                        if module.name in warnings_as_errors_module_names:
                            result = EvaluateResult.Error
                        elif module.name in ignore_warnings_module_names:
                            continue

                    if result == EvaluateResult.Error:
                        return -1, "errors were encountered"
                    elif result == EvaluateResult.Warning:
                        return_code = 1

            return return_code, "" if return_code == 0 else "warnings were encountered"

        # ----------------------------------------------------------------------

        results: list[Optional[list[list[Module.EvaluateInfo]]]] = [None] * len(modules)

        if parallel:
            # ----------------------------------------------------------------------
            def Prepare(
                context: Any,
                on_simple_status_func: Callable[[str], None],  # pylint: disable=unused-argument
            ) -> tuple[int, ExecuteTasks.TransformTasksExTypes.TransformFuncType]:
                module = context
                del context

                # ----------------------------------------------------------------------
                def Transform(
                    status: ExecuteTasks.Status,
                ) -> ExecuteTasks.TransformResultComplete:

                    # ----------------------------------------------------------------------
                    def OnStatus(num_completed: int, *args, **kwargs):
                        status.OnProgress(
                            num_completed, CreateStatusString(num_completed, *args, **kwargs)
                        )

                    # ----------------------------------------------------------------------

                    result: list[list[Module.EvaluateInfo]] = module.Evaluate(
                        OnStatus,
                        max_num_threads=max_num_threads,
                    )

                    result_code, result_status = CalcResultInfo(module, result)

                    return ExecuteTasks.TransformResultComplete(result, result_code, result_status)

                # ----------------------------------------------------------------------

                return module.GetNumRequirements(), Transform

            # ----------------------------------------------------------------------

            for (results_index, _), result in zip(
                parallel,
                ExecuteTasks.TransformTasksEx(
                    modules_dm,
                    "Processing parallel modules...",
                    [ExecuteTasks.TaskData(module.name, module) for _, module in parallel],
                    Prepare,
                    max_num_threads=max_num_threads,
                ),
            ):
                assert results[results_index] is None
                assert isinstance(result, list), result

                results[results_index] = result

        for index, (results_index, module) in enumerate(sequential):
            with modules_dm.Nested(
                "Processing '{}' ({} of {})...".format(
                    module.name,
                    index + 1 + len(parallel),
                    len(modules),
                ),
            ) as this_module_dm:
                with this_module_dm.YieldStdout() as stdout_context:
                    stdout_context.persist_content = False

                    # Technically speaking, it would be more correct to use `stdout_context.stream` here
                    # rather than referencing `sys.stdout` directly, but it is really hard to work with a
                    # mocked stream, as mocks will create mocks for everything called on the mock. Use
                    # sys.stdout directly to avoid that particular problem.
                    from unittest.mock import Mock, MagicMock

                    assert stdout_context.stream is sys.stdout or isinstance(
                        stdout_context.stream, (Mock, MagicMock)
                    ), stdout_context.stream

                    with Progress(
                        *Progress.get_default_columns(),
                        TimeElapsedColumn(),
                        "{task.fields[status]}",
                        console=Capabilities.Get(sys.stdout).CreateRichConsole(sys.stdout),
                        transient=True,
                        refresh_per_second=10,
                    ) as progress_bar:
                        progress_bar_task_id = progress_bar.add_task(
                            stdout_context.line_prefix,
                            status="",
                            total=module.GetNumRequirements(),
                            visible=True,
                        )

                        # ----------------------------------------------------------------------
                        def OnStatus(
                            num_completed: int,
                            num_success: int,
                            num_error: int,
                            num_warning: int,
                            num_does_not_apply: int,
                        ) -> None:
                            progress_bar.update(
                                progress_bar_task_id,
                                completed=num_completed,
                                status=CreateStatusString(
                                    num_completed,
                                    num_success,
                                    num_error,
                                    num_warning,
                                    num_does_not_apply,
                                ),
                            )

                        # ----------------------------------------------------------------------

                        this_results: list[list[Module.EvaluateInfo]] = module.Evaluate(
                            OnStatus,
                            max_num_threads=max_num_threads,
                        )

                        assert results[results_index] is None
                        results[results_index] = this_results

                        this_module_dm.result = CalcResultInfo(module, this_results)[0]
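
For orientation, a minimal sketch of driving `Execute` end to end, assuming a prebuilt `list[Module]` (the `build_modules` helper is hypothetical and stands in for whatever constructs the modules); the call pattern mirrors the unit test below:

# Minimal sketch (not part of the commit): invoking Execute with a prebuilt module list.
import sys

from dbrownell_Common.Streams.DoneManager import DoneManager
from RepoAuditor.Executor import Execute

modules = build_modules()  # hypothetical helper returning list[Module]

with DoneManager.Create(sys.stdout, "", line_prefix="") as dm:
    Execute(
        dm,
        modules,
        max_num_threads=None,  # None defers to the ExecuteTasks / Module.Evaluate defaults
    )

assert dm.result in (0, 1, -1)  # 0: success, 1: warnings, -1: errors

The 0 / 1 / -1 values of `dm.result` correspond to the return codes produced by `CalcResultInfo`: success, warnings encountered, and errors encountered, respectively.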
@@ -0,0 +1,188 @@
# -------------------------------------------------------------------------------
# |                                                                               |
# |  Copyright (c) 2024 Scientific Software Engineering Center at Georgia Tech    |
# |  Distributed under the MIT License.                                           |
# |                                                                               |
# -------------------------------------------------------------------------------
"""Unit test for Executor.py"""

import sys
import time

from dataclasses import dataclass

import pytest

from dbrownell_Common.Types import override

from RepoAuditor.Executor import *
from RepoAuditor.Module import *
from RepoAuditor.Requirement import *


# ----------------------------------------------------------------------
@dataclass(frozen=True)
class MyModule(Module):
    @override
    def _GetData(self) -> Optional[dict[str, Any]]:
        return {"module_name": self.name}


# ----------------------------------------------------------------------
@dataclass(frozen=True)
class MyQuery(Query):
    @override
    def GetData(
        self,
        module_data: dict[str, Any],
    ) -> Optional[dict[str, Any]]:
        module_data["query_name"] = self.name
        return module_data


# ----------------------------------------------------------------------
@dataclass(frozen=True)
class MyRequirement(Requirement):
    result: EvaluateResult

    # ----------------------------------------------------------------------
    # ----------------------------------------------------------------------
    # ----------------------------------------------------------------------
    @override
    def _EvaluateImpl(
        self,
        query_data: dict[str, Any],
    ) -> tuple[EvaluateResult, Optional[str]]:
        # Introduce a delay so that we can see things happening
        time.sleep(0.1)

        return self.result, None


# ----------------------------------------------------------------------
@pytest.mark.parametrize("single_threaded", [False, True])
def test_Successful(single_threaded):
    modules: list[Module] = []

    # ----------------------------------------------------------------------
    def GetExecutionStyle(index: int) -> ExecutionStyle:
        return ExecutionStyle.Parallel if index % 2 == 0 else ExecutionStyle.Sequential

    # ----------------------------------------------------------------------

    for module_index in range(5):
        queries: list[Query] = []

        for query_index in range(4):
            requirements: list[Requirement] = []

            for requirement_index in range(5):
                requirements.append(
                    MyRequirement(
                        f"Requirement-{module_index}-{query_index}-{requirement_index}",
                        "",
                        GetExecutionStyle(requirement_index),
                        "",
                        "",
                        EvaluateResult.Success,
                    ),
                )

            queries.append(
                MyQuery(
                    f"Query-{module_index}-{query_index}",
                    "",
                    GetExecutionStyle(query_index),
                    requirements,
                ),
            )

        modules.append(
            MyModule(
                f"Module-{module_index}",
                "",
                GetExecutionStyle(module_index),
                queries,
            ),
        )

    with DoneManager.Create(sys.stdout, "", line_prefix="") as dm:
        Execute(
            dm,
            modules,
            max_num_threads=1 if single_threaded else None,
        )

    assert dm.result == 0


# ----------------------------------------------------------------------
@pytest.mark.parametrize(
    "data",
    [
        (EvaluateResult.Error, -1, False, False),
        (EvaluateResult.Warning, 1, False, False),
        (EvaluateResult.Warning, -1, True, False),
        (EvaluateResult.Warning, 0, False, True),
    ],
)
def test_NotSuccess(data):
    test_result, expected_result, warnings_as_errors, ignore_warnings = data

    modules: list[Module] = []

    # ----------------------------------------------------------------------
    def GetExecutionStyle(index: int) -> ExecutionStyle:
        return ExecutionStyle.Parallel if index % 2 == 0 else ExecutionStyle.Sequential

    # ----------------------------------------------------------------------

    for module_index in range(5):
        queries: list[Query] = []

        for query_index in range(4):
            requirements: list[Requirement] = []

            for requirement_index in range(5):
                requirements.append(
                    MyRequirement(
                        f"Requirement-{module_index}-{query_index}-{requirement_index}",
                        "",
                        GetExecutionStyle(requirement_index),
                        "",
                        "",
                        (test_result if requirement_index % 3 == 0 else EvaluateResult.Success),
                    ),
                )

            queries.append(
                MyQuery(
                    f"Query-{module_index}-{query_index}",
                    "",
                    GetExecutionStyle(query_index),
                    requirements,
                ),
            )

        modules.append(
            MyModule(
                f"Module-{module_index}",
                "",
                GetExecutionStyle(module_index),
                queries,
            ),
        )

    with DoneManager.Create(sys.stdout, "", line_prefix="") as dm:
        Execute(
            dm,
            modules,
            warnings_as_errors_module_names=(
                set() if not warnings_as_errors else {module.name for module in modules}
            ),
            ignore_warnings_module_names=(
                set() if not ignore_warnings else {module.name for module in modules}
            ),
        )

    assert dm.result == expected_result
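
Read together with `CalcResultInfo` in Executor.py, each parametrized case in `test_NotSuccess` exercises a distinct return-code path:

# (requirement result,     expected dm.result, warnings_as_errors, ignore_warnings)
# (EvaluateResult.Error,   -1, False, False)  -> any error yields -1
# (EvaluateResult.Warning,  1, False, False)  -> warnings alone yield 1
# (EvaluateResult.Warning, -1, True,  False)  -> warnings promoted to errors yield -1
# (EvaluateResult.Warning,  0, False, True)   -> ignored warnings yield 0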