From 252f3e74a59a8e7efa09d13a51474ac1f73fd78c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 22 Nov 2023 02:50:04 +0100 Subject: [PATCH] forward process output using a queue --- optimum_benchmark/launchers/base.py | 7 ++- .../launchers/process/launcher.py | 43 +++++++++++++------ .../launchers/torchrun/launcher.py | 8 +++- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/optimum_benchmark/launchers/base.py b/optimum_benchmark/launchers/base.py index 30c13dd3..d738ecef 100644 --- a/optimum_benchmark/launchers/base.py +++ b/optimum_benchmark/launchers/base.py @@ -1,7 +1,10 @@ from abc import ABC from dataclasses import dataclass from logging import getLogger -from typing import Callable, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, Callable, ClassVar, Generic, TypeVar + +if TYPE_CHECKING: + from ..benchmarks.base import Benchmark LOGGER = getLogger("launcher") @@ -27,5 +30,5 @@ def configure(self, config: "LauncherConfigT") -> None: LOGGER.info(f"Configuring {self.NAME} launcher") self.config = config - def launch(self, worker: Callable, **worker_kwargs) -> None: + def launch(self, worker: Callable, *worker_args) -> "Benchmark": raise NotImplementedError("Launcher must implement launch method") diff --git a/optimum_benchmark/launchers/process/launcher.py b/optimum_benchmark/launchers/process/launcher.py index afd42120..230cac0c 100644 --- a/optimum_benchmark/launchers/process/launcher.py +++ b/optimum_benchmark/launchers/process/launcher.py @@ -1,16 +1,22 @@ import logging.config import multiprocessing as mp from logging import getLogger -from multiprocessing import Process -from typing import Callable +from multiprocessing import Process, Queue +from typing import TYPE_CHECKING, Callable from omegaconf import OmegaConf from ..base import Launcher from .config import ProcessConfig +if TYPE_CHECKING: + from ...benchmarks.base import Benchmark + LOGGER = getLogger("process") +# Create the Queue +QUEUE = Queue() + class ProcessLauncher(Launcher[ProcessConfig]): NAME = "process" @@ -21,26 +27,31 @@ def __init__(self) -> None: def configure(self, config: ProcessConfig) -> None: super().configure(config) - def launch(self, worker: Callable, *worker_args) -> None: + def launch(self, worker: Callable, *worker_args) -> "Benchmark": # Set the multiprocessing start method if not already set if mp.get_start_method(allow_none=True) is None: mp.set_start_method(self.config.start_method) - # Execute in a separate process - p = Process( + # Create the process + process = Process( target=target, args=(worker, *worker_args), daemon=True, ) - p.start() - benchmark = p.join() - # Exit with the same exit code as the child process - if p.exitcode != 0: - LOGGER.error(f"Child process exited with code {p.exitcode}") - exit(p.exitcode) - else: - return benchmark + # Start the process + process.start() + + # Wait for the process to finish + process.join() + + if process.exitcode != 0: + raise RuntimeError(f"Process exited with code {process.exitcode}") + + # Get the benchmark from the queue + benchmark = QUEUE.get() + + return benchmark def target(fn, *args): @@ -50,4 +61,8 @@ def target(fn, *args): hydra_conf = OmegaConf.load(".hydra/hydra.yaml") logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True)) - return fn(*args) + # Run the function + result = fn(*args) + + # Put the result in the queue + QUEUE.put(result) diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py index 0025acc1..0ddf7035 100644 --- a/optimum_benchmark/launchers/torchrun/launcher.py +++ b/optimum_benchmark/launchers/torchrun/launcher.py @@ -1,7 +1,7 @@ import logging.config import os from logging import getLogger -from typing import Callable +from typing import TYPE_CHECKING, Callable from omegaconf import OmegaConf from torch.distributed.elastic.multiprocessing import Std @@ -11,6 +11,10 @@ from ..base import Launcher from .config import TorchrunConfig +if TYPE_CHECKING: + from ...benchmarks.base import Benchmark + + LOGGER = getLogger("torchrun") @@ -25,7 +29,7 @@ def configure(self, config: "TorchrunConfig") -> None: LOGGER.info(f"Running {self.config.nproc_per_node} processes per node") - def launch(self, worker: Callable, *worker_args) -> None: + def launch(self, worker: Callable, *worker_args) -> "Benchmark": launch_config = LaunchConfig( min_nodes=self.config.min_nodes, max_nodes=self.config.max_nodes,