Skip to content

Commit

Permalink
Add def multiprocessing_breakpoint function
Browse files Browse the repository at this point in the history
Using the builtin `breakpoint` function does not work with
multiprocessing, as the Process forks has different STDIN so the pdb can
not attach to it. Use this `multiprocessing_breakpoint` function
instead, when you need a breakpoint for debugging.
  • Loading branch information
kukovecz committed Jan 27, 2022
1 parent 5fb4f0d commit 1c14093
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ addopts = "--cov=unblob --cov=tests --cov-branch --cov-fail-under=90"
[tool.vulture]
paths = ["unblob/"]
exclude = ["unblob/_py/"]
ignore_names = ["breakpointhook"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
17 changes: 17 additions & 0 deletions unblob/logging.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import logging
import pdb
import sys
from os import getpid
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -79,3 +81,18 @@ def configure_logger(verbose: bool, extract_root: Path):
wrapper_class=structlog.make_filtering_bound_logger(log_level),
processors=processors,
)


class _MultiprocessingPdb(pdb.Pdb):
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin


def multiprocessing_breakpoint():
"""Call this in Process forks instead of the builtin `breakpoint` function for debugging with PDB."""
return _MultiprocessingPdb().set_trace(frame=sys._getframe(1))
6 changes: 5 additions & 1 deletion unblob/processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import multiprocessing
import stat
import statistics
import sys
from operator import attrgetter
from pathlib import Path
from typing import List, Optional
Expand All @@ -12,7 +13,7 @@
from .file_utils import iterate_file
from .finder import search_chunks_by_priority
from .iter_utils import pairwise
from .logging import noformat
from .logging import multiprocessing_breakpoint, noformat
from .math import shannon_entropy
from .models import ProcessingConfig, Task, UnknownChunk, ValidChunk

Expand Down Expand Up @@ -83,6 +84,9 @@ def process_file( # noqa: C901
def _process_task_queue(
task_queue: multiprocessing.JoinableQueue, config: ProcessingConfig
):
# Set custom function to breakpoint() call in the sub-processes for easier debugging
sys.breakpointhook = multiprocessing_breakpoint

while True:
logger.debug("Waiting for Task")
task = task_queue.get()
Expand Down

0 comments on commit 1c14093

Please sign in to comment.