Skip to content

Commit

Permalink
forward process output using a queue
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Nov 22, 2023
1 parent 60be0c2 commit 252f3e7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
7 changes: 5 additions & 2 deletions optimum_benchmark/launchers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC
from dataclasses import dataclass
from logging import getLogger
from typing import Callable, ClassVar, Generic, TypeVar
from typing import TYPE_CHECKING, Callable, ClassVar, Generic, TypeVar

if TYPE_CHECKING:
from ..benchmarks.base import Benchmark

LOGGER = getLogger("launcher")

Expand All @@ -27,5 +30,5 @@ def configure(self, config: "LauncherConfigT") -> None:
LOGGER.info(f"Configuring {self.NAME} launcher")
self.config = config

def launch(self, worker: Callable, **worker_kwargs) -> None:
def launch(self, worker: Callable, *worker_args) -> "Benchmark":
raise NotImplementedError("Launcher must implement launch method")
43 changes: 29 additions & 14 deletions optimum_benchmark/launchers/process/launcher.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import logging.config
import multiprocessing as mp
from logging import getLogger
from multiprocessing import Process
from typing import Callable
from multiprocessing import Process, Queue
from typing import TYPE_CHECKING, Callable

from omegaconf import OmegaConf

from ..base import Launcher
from .config import ProcessConfig

if TYPE_CHECKING:
from ...benchmarks.base import Benchmark

LOGGER = getLogger("process")

# Create the Queue
QUEUE = Queue()


class ProcessLauncher(Launcher[ProcessConfig]):
NAME = "process"
Expand All @@ -21,26 +27,31 @@ def __init__(self) -> None:
def configure(self, config: ProcessConfig) -> None:
super().configure(config)

def launch(self, worker: Callable, *worker_args) -> None:
def launch(self, worker: Callable, *worker_args) -> "Benchmark":
# Set the multiprocessing start method if not already set
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method(self.config.start_method)

# Execute in a separate process
p = Process(
# Create the process
process = Process(
target=target,
args=(worker, *worker_args),
daemon=True,
)
p.start()
benchmark = p.join()

# Exit with the same exit code as the child process
if p.exitcode != 0:
LOGGER.error(f"Child process exited with code {p.exitcode}")
exit(p.exitcode)
else:
return benchmark
# Start the process
process.start()

# Wait for the process to finish
process.join()

if process.exitcode != 0:
raise RuntimeError(f"Process exited with code {process.exitcode}")

# Get the benchmark from the queue
benchmark = QUEUE.get()

return benchmark


def target(fn, *args):
Expand All @@ -50,4 +61,8 @@ def target(fn, *args):
hydra_conf = OmegaConf.load(".hydra/hydra.yaml")
logging.config.dictConfig(OmegaConf.to_container(hydra_conf.hydra.job_logging, resolve=True))

return fn(*args)
# Run the function
result = fn(*args)

# Put the result in the queue
QUEUE.put(result)
8 changes: 6 additions & 2 deletions optimum_benchmark/launchers/torchrun/launcher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging.config
import os
from logging import getLogger
from typing import Callable
from typing import TYPE_CHECKING, Callable

from omegaconf import OmegaConf
from torch.distributed.elastic.multiprocessing import Std
Expand All @@ -11,6 +11,10 @@
from ..base import Launcher
from .config import TorchrunConfig

if TYPE_CHECKING:
from ...benchmarks.base import Benchmark


LOGGER = getLogger("torchrun")


Expand All @@ -25,7 +29,7 @@ def configure(self, config: "TorchrunConfig") -> None:

LOGGER.info(f"Running {self.config.nproc_per_node} processes per node")

def launch(self, worker: Callable, *worker_args) -> None:
def launch(self, worker: Callable, *worker_args) -> "Benchmark":
launch_config = LaunchConfig(
min_nodes=self.config.min_nodes,
max_nodes=self.config.max_nodes,
Expand Down

0 comments on commit 252f3e7

Please sign in to comment.