From 1c14093f32b9db19f6d9af3bbe91a37ba16cc2de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A1nos=20Kukovecz?= Date: Mon, 24 Jan 2022 10:42:10 +0100 Subject: [PATCH] Add multiprocessing_breakpoint function Using the builtin `breakpoint` function does not work with multiprocessing, because each forked Process has a different STDIN, so pdb cannot attach to it. Use the `multiprocessing_breakpoint` function instead when you need a breakpoint for debugging. --- pyproject.toml | 1 + unblob/logging.py | 17 +++++++++++++++++ unblob/processing.py | 6 +++++- 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4c5bb5c479..bf562577e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ addopts = "--cov=unblob --cov=tests --cov-branch --cov-fail-under=90" [tool.vulture] paths = ["unblob/"] exclude = ["unblob/_py/"] +ignore_names = ["breakpointhook"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/unblob/logging.py b/unblob/logging.py index 5e539107a7..2c1cc27566 100644 --- a/unblob/logging.py +++ b/unblob/logging.py @@ -1,4 +1,6 @@ import logging +import pdb +import sys from os import getpid from pathlib import Path from typing import Any @@ -79,3 +81,18 @@ def configure_logger(verbose: bool, extract_root: Path): wrapper_class=structlog.make_filtering_bound_logger(log_level), processors=processors, ) + + +class _MultiprocessingPdb(pdb.Pdb): + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +def multiprocessing_breakpoint(): + """Call this in Process forks instead of the builtin `breakpoint` function for debugging with PDB.""" + return _MultiprocessingPdb().set_trace(frame=sys._getframe(1)) diff --git a/unblob/processing.py b/unblob/processing.py index b16691ea09..0d032471cf 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -1,6 +1,7 @@ import
multiprocessing import stat import statistics +import sys from operator import attrgetter from pathlib import Path from typing import List, Optional @@ -12,7 +13,7 @@ from .file_utils import iterate_file from .finder import search_chunks_by_priority from .iter_utils import pairwise -from .logging import noformat +from .logging import multiprocessing_breakpoint, noformat from .math import shannon_entropy from .models import ProcessingConfig, Task, UnknownChunk, ValidChunk @@ -83,6 +84,9 @@ def process_file( # noqa: C901 def _process_task_queue( task_queue: multiprocessing.JoinableQueue, config: ProcessingConfig ): + # Set custom function to breakpoint() call in the sub-processes for easier debugging + sys.breakpointhook = multiprocessing_breakpoint + while True: logger.debug("Waiting for Task") task = task_queue.get()