Skip to content

Commit

Permalink
Allow multiple runs and handle connection communication errors (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored Sep 22, 2024
1 parent a1444c3 commit 01e4e59
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 175 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_cli_misc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
run: |
pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -e .[testing,timm,diffusers,codecarbon]
pip install -e .[testing]
- name: Run tests
run: pytest -s -k "cli and not (cpu or cuda or rocm or mps)"
10 changes: 10 additions & 0 deletions optimum_benchmark/launchers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def device_isolation(self, pid: int, device_ids: Optional[str] = None):

@contextmanager
def numactl_executable(self):
self.logger.info("\t+ Warming up multiprocessing context")
dummy_process = Process(target=dummy_target, daemon=False)
dummy_process.start()
dummy_process.join()
dummy_process.close()

self.logger.info("\t+ Creating numactl wrapper executable for multiprocessing")
python_path = sys.executable
numactl_path = shutil.which("numactl")
Expand All @@ -84,3 +90,7 @@ def numactl_executable(self):
self.logger.info("\t+ Resetting default multiprocessing executable")
os.unlink(numa_executable.name)
set_executable(sys.executable)


def dummy_target() -> None:
exit(0)
2 changes: 1 addition & 1 deletion optimum_benchmark/launchers/inline/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ def __init__(self, config: InlineConfig):
super().__init__(config)

def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any]) -> BenchmarkReport:
self.logger.warn("The inline launcher is only recommended for debugging purposes and not for benchmarking")
self.logger.warning("The inline launcher is only recommended for debugging purposes and not for benchmarking")
report = worker(*worker_args)
return report
69 changes: 33 additions & 36 deletions optimum_benchmark/launchers/process/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from multiprocessing.connection import Connection
from typing import Any, Callable, List

import psutil

from ...benchmark.report import BenchmarkReport
from ...logging_utils import setup_logging
from ...process_utils import sync_with_child, sync_with_parent
from ..base import Launcher
from .config import ProcessConfig

Expand All @@ -21,46 +24,42 @@ def __init__(self, config: ProcessConfig):
if get_start_method(allow_none=True) != self.config.start_method:
self.logger.info(f"\t+ Setting multiprocessing start method to {self.config.start_method}")
set_start_method(self.config.start_method, force=True)
# creates the resource tracker with default executable
self.logger.info("\t+ Warming up multiprocessing context")
dummy_process = Process(target=dummy_target, daemon=False)
dummy_process.start()
dummy_process.join()
dummy_process.close()

def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any]) -> BenchmarkReport:
child_connection, parent_connection = Pipe()
main_process_pid = os.getpid()
isolated_process = Process(
target=target, args=(worker, worker_args, child_connection, self.logger), daemon=False
target=target, args=(worker, worker_args, child_connection, main_process_pid, self.logger), daemon=False
)

with ExitStack() as stack:
if self.config.numactl:
stack.enter_context(self.numactl_executable())

self.logger.info("\t+ Starting isolated process")
isolated_process.start()
while True:
if parent_connection.poll():
message = parent_connection.recv()
if message == "READY":
self.logger.info("\t+ Isolated process is ready")
break
else:
raise RuntimeError(f"Unexpected message from isolated process: {message}")

with ExitStack() as stack:
if isolated_process.is_alive():
sync_with_child(parent_connection)
else:
raise RuntimeError("Could not synchronize with isolated process")

if self.config.device_isolation:
stack.enter_context(self.device_isolation(isolated_process.pid))

parent_connection.send("START")
if isolated_process.is_alive():
sync_with_child(parent_connection)
else:
raise RuntimeError("Could not synchronize with isolated process")

isolated_process.join()

if isolated_process.exitcode != 0:
raise RuntimeError(f"Isolated process exited with non-zero code {isolated_process.exitcode}")

if parent_connection.poll():
response = parent_connection.recv()
else:
raise RuntimeError("Received no response from isolated process")

if "traceback" in response:
self.logger.error("\t+ Received traceback from isolated process")
Expand All @@ -81,37 +80,35 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
def target(
worker: Callable[..., BenchmarkReport],
worker_args: List[Any],
connection: Connection,
child_connection: Connection,
main_process_pid: int,
logger: Logger,
) -> None:
main_process = psutil.Process(main_process_pid)

if main_process.is_running():
sync_with_parent(child_connection)
else:
raise RuntimeError("Could not synchronize with main process")

log_level = os.environ.get("LOG_LEVEL", "INFO")
log_to_file = os.environ.get("LOG_TO_FILE", "1") == "1"
setup_logging(level=log_level, to_file=log_to_file, prefix="ISOLATED-PROCESS")

connection.send("READY")

while True:
if connection.poll():
message = connection.recv()
if message == "START":
logger.info("\t+ Starting benchmark in isolated process")
break
else:
raise RuntimeError(f"Unexpected message from main process: {message}")
if main_process.is_running():
sync_with_parent(child_connection)
else:
raise RuntimeError("Could not synchronize with main process")

try:
report = worker(*worker_args)
except Exception:
logger.error("\t+ Sending traceback to main process")
connection.send({"traceback": traceback.format_exc()})
child_connection.send({"traceback": traceback.format_exc()})
else:
logger.info("\t+ Sending report to main process")
connection.send({"report": report.to_dict()})
child_connection.send({"report": report.to_dict()})
finally:
logger.info("\t+ Exiting isolated process")
connection.close()
child_connection.close()
exit(0)


def dummy_target() -> None:
exit(0)
63 changes: 32 additions & 31 deletions optimum_benchmark/launchers/torchrun/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from multiprocessing.connection import Connection
from typing import Any, Callable, List

import psutil
import torch.distributed
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from ...benchmark.report import BenchmarkReport
from ...logging_utils import setup_logging
from ...process_utils import sync_with_child, sync_with_parent
from ..base import Launcher
from .config import TorchrunConfig

Expand All @@ -24,11 +26,6 @@ def __init__(self, config: TorchrunConfig):
if get_start_method(allow_none=True) != self.config.start_method:
self.logger.info(f"\t+ Setting multiprocessing start method to {self.config.start_method}")
set_start_method(self.config.start_method, force=True)
self.logger.info("\t+ Warming up multiprocessing context")
# creates the resource tracker with default executable
dummy_process = Process()
dummy_process.start()
dummy_process.join()

self.launch_config = LaunchConfig(
min_nodes=self.config.min_nodes,
Expand All @@ -48,30 +45,32 @@ def __init__(self, config: TorchrunConfig):

def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any]) -> BenchmarkReport:
parent_connection, child_connection = Pipe()
main_process_pid = os.getpid()
isolated_process = Process(
target=target, args=(worker, worker_args, child_connection, self.launch_config, self.logger), daemon=False
target=target,
args=(worker, worker_args, child_connection, main_process_pid, self.launch_config, self.logger),
daemon=False,
)

with ExitStack() as stack:
if self.config.numactl:
stack.enter_context(self.numactl_executable())

self.logger.info("\t+ Starting isolated process")
isolated_process.start()
while True:
if parent_connection.poll():
message = parent_connection.recv()
if message == "READY":
self.logger.info("\t+ Isolated process is ready")
break
else:
raise RuntimeError(f"Unexpected message from isolated process: {message}")

with ExitStack() as stack:
if isolated_process.is_alive():
sync_with_child(parent_connection)
else:
raise RuntimeError("Could not synchronize with isolated process")

if self.config.device_isolation:
stack.enter_context(self.device_isolation(isolated_process.pid))

parent_connection.send("START")
if isolated_process.is_alive():
sync_with_child(parent_connection)
else:
raise RuntimeError("Could not synchronize with isolated process")

isolated_process.join()

if isolated_process.exitcode != 0:
Expand Down Expand Up @@ -110,37 +109,39 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
def target(
worker: Callable[..., BenchmarkReport],
worker_args: List[Any],
connection: Connection,
child_connection: Connection,
main_process_pid: int,
config: LaunchConfig,
logger: Logger,
):
main_process = psutil.Process(main_process_pid)

if main_process.is_running():
sync_with_parent(child_connection)
else:
raise RuntimeError("Could not synchronize with main process")

log_level = os.environ.get("LOG_LEVEL", "INFO")
log_to_file = os.environ.get("LOG_TO_FILE", "1") == "1"
setup_logging(level=log_level, to_file=log_to_file, prefix="ISOLATED-PROCESS")

connection.send("READY")

while True:
if connection.poll():
message = connection.recv()
if message == "START":
logger.info("\t+ Starting benchmark in isolated process")
break
else:
raise RuntimeError(f"Unexpected message from main process: {message}")
if main_process.is_running():
sync_with_parent(child_connection)
else:
raise RuntimeError("Could not synchronize with main process")

try:
elastic_agent_launcher = elastic_launch(config=config, entrypoint=entrypoint)
outputs = elastic_agent_launcher(worker, worker_args, logger)
except Exception:
logger.error("\t+ Sending traceback to main process")
connection.send([{"traceback": traceback.format_exc()}])
child_connection.send([{"traceback": traceback.format_exc()}])
else:
logger.info("\t+ Sending outputs to main process")
connection.send(list(outputs.values()))
child_connection.send(list(outputs.values()))
finally:
logger.info("\t+ Exiting isolated process")
connection.close()
child_connection.close()
exit(0)


Expand Down
11 changes: 11 additions & 0 deletions optimum_benchmark/process_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from multiprocessing.connection import Connection


def sync_with_parent(child_connection: Connection) -> None:
child_connection.recv()
child_connection.send(0)


def sync_with_child(parent_connection: Connection) -> None:
parent_connection.send(0)
parent_connection.recv()
4 changes: 0 additions & 4 deletions optimum_benchmark/scenarios/inference/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
self.report.load.memory = memory_tracker.get_max_memory()
if self.config.energy:
self.report.load.energy = energy_tracker.get_energy()
energy_tracker.stop()

## Memory tracking
def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
Expand Down Expand Up @@ -368,7 +367,6 @@ def run_text_generation_energy_tracking(self, backend: Backend[BackendConfigT]):
self.report.decode.efficiency = Efficiency.from_energy(
decode_energy, decode_volume, unit=TEXT_GENERATION_EFFICIENCY_UNIT
)
energy_tracker.stop()

def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running Image Diffusion energy tracking")
Expand All @@ -393,7 +391,6 @@ def run_image_diffusion_energy_tracking(self, backend: Backend[BackendConfigT]):
self.report.call.efficiency = Efficiency.from_energy(
call_energy, call_volume, unit=IMAGE_DIFFUSION_EFFICIENCY_UNIT
)
energy_tracker.stop()

def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
self.logger.info("\t+ Running energy tracking")
Expand All @@ -418,7 +415,6 @@ def run_inference_energy_tracking(self, backend: Backend[BackendConfigT]):
self.report.forward.efficiency = Efficiency.from_energy(
forward_energy, forward_volume, unit=INFERENCE_EFFICIENCY_UNIT
)
energy_tracker.stop()

@property
def atomic_forward_volume(self) -> int: # in samples
Expand Down
1 change: 0 additions & 1 deletion optimum_benchmark/scenarios/training/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
self.report.overall.efficiency = Efficiency.from_energy(
self.report.overall.energy, volume=self.overall_volume, unit=TRAIN_EFFICIENCY_UNIT
)
energy_tracker.stop()

return self.report

Expand Down
Loading

0 comments on commit 01e4e59

Please sign in to comment.