Skip to content

Commit

Permalink
fix: failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
dPys committed Dec 27, 2024
1 parent 08c3413 commit 239b1ab
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 29 deletions.
2 changes: 1 addition & 1 deletion nxbench/backends/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def teardown_networkx():
# ---- Nx-Parallel backend ----
def convert_parallel(original_graph: nx.Graph, num_threads: int):
nxp = import_module("nx_parallel")
from joblib import cpu_count
from multiprocessing import cpu_count

total_cores = cpu_count()

Expand Down
46 changes: 18 additions & 28 deletions nxbench/benchmarking/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,65 +87,55 @@ def configure_backend(original_graph: nx.Graph, backend: str, num_thread: int) -
def run_algorithm(
graph: Any, algo_config: AlgorithmConfig, num_thread: int, backend: str
) -> tuple[Any, float, int, str | None]:
"""
Attempt to run the algorithm on the configured backend, but gracefully
skip if that algorithm is not actually implemented for this backend.
"""
"""Run the algorithm on the configured backend"""
logger = get_run_logger()

try:
# retrieve the callable for the chosen backend
algo_func = algo_config.get_callable(backend)
except ImportError as e:
logger.exception(
f"Could not get a callable for {algo_config.name} from {backend}."
)
return None, 0.0, 0, str(e)

# parse and prepare the parameters
pos_args, kwargs = process_algorithm_params(algo_config.params)

kwargs = add_seeding(kwargs, algo_func, algo_config.name)

error = None
original_env = {}
vars_to_set = [
"NUM_THREAD",
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"NUMEXPR_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS",
]
try:
original_env = {}
vars_to_set = [
"NUM_THREAD",
"OMP_NUM_THREADS",
"MKL_NUM_THREADS",
"OPENBLAS_NUM_THREADS",
"NUMEXPR_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS",
]
for var_name in vars_to_set:
original_env[var_name] = os.environ.get(var_name)
os.environ[var_name] = str(num_thread)

with memory_tracker() as mem:
start_time = time.perf_counter()

try:
result = algo_func(graph, *pos_args, **kwargs)
except NotImplementedError as nie:
logger.info(
f"Skipping {algo_config.name} for backend '{backend}' because "
f"it's not implemented (NotImplementedError)."
)
return None, 0.0, 0, str(nie)

# pass the graph plus the processed pos_args and kwargs
result = algo_func(graph, *pos_args, **kwargs)
end_time = time.perf_counter()
execution_time = end_time - start_time
peak_memory = mem["peak"]

execution_time = end_time - start_time
peak_memory = mem["peak"]
logger.debug(f"Algorithm '{algo_config.name}' executed successfully.")

except Exception as e:
logger.exception("Algorithm run failed")
execution_time = time.perf_counter() - start_time
peak_memory = mem.get("peak", 0)
result = None
error = str(e)

finally:
# restore environment
# restore environment variables
for var_name in vars_to_set:
if original_env[var_name] is None:
del os.environ[var_name]
Expand Down

0 comments on commit 239b1ab

Please sign in to comment.