diff --git a/dspy/utils/parallelizer.py b/dspy/utils/parallelizer.py
index 08dae8459..51d779f83 100644
--- a/dspy/utils/parallelizer.py
+++ b/dspy/utils/parallelizer.py
@@ -4,9 +4,10 @@
 import logging
 import threading
 import traceback
+import time
 import contextlib
 
-from tqdm.contrib.logging import logging_redirect_tqdm
-from concurrent.futures import ThreadPoolExecutor, as_completed
+
+from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED
 
 logger = logging.getLogger(__name__)
 
@@ -19,189 +20,195 @@ def __init__(
         disable_progress_bar=False,
         provide_traceback=False,
         compare_results=False,
+        timeout=120,
+        straggler_limit=3,
     ):
-        """Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
+        """
+        Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1.
+        Also handles straggler timeouts.
+        """
+
         self.num_threads = num_threads
-        self.disable_progress_bar = disable_progress_bar
         self.max_errors = max_errors
+        self.disable_progress_bar = disable_progress_bar
         self.provide_traceback = provide_traceback
         self.compare_results = compare_results
+        self.timeout = timeout
+        self.straggler_limit = straggler_limit
 
         self.error_count = 0
         self.error_lock = threading.Lock()
         self.cancel_jobs = threading.Event()
 
     def execute(self, function, data):
-        wrapped_function = self._wrap_function(function)
-        if self.num_threads == 1:
-            return self._execute_isolated_single_thread(wrapped_function, data)
-        else:
-            return self._execute_multi_thread(wrapped_function, data)
+        wrapped = self._wrap_function(function)
+        return self._execute_parallel(wrapped, data)
 
-    def _wrap_function(self, function):
-        # Wrap the function with error handling
-        def wrapped(item):
+    def _wrap_function(self, user_function):
+        def safe_func(item):
             if self.cancel_jobs.is_set():
                 return None
             try:
-                return function(item)
+                return user_function(item)
             except Exception as e:
                 with self.error_lock:
                     self.error_count += 1
-                    current_error_count = self.error_count
-                if current_error_count >= self.max_errors:
-                    self.cancel_jobs.set()
-                    raise e
+                    if self.error_count >= self.max_errors:
+                        self.cancel_jobs.set()
                 if self.provide_traceback:
-                    logger.error(
-                        f"Error processing item {item}: {e}\nStack trace:\n{traceback.format_exc()}"
-                    )
+                    logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
                 else:
                     logger.error(
-                        f"Error processing item {item}: {e}. Set `provide_traceback=True` to see the stack trace."
+                        f"Error for {item}: {e}. "
+                        "Set `provide_traceback=True` for traceback."
                     )
             return None
 
-        return wrapped
-
-    def _execute_isolated_single_thread(self, function, data):
-        results = []
-        pbar = tqdm.tqdm(
-            total=len(data),
-            dynamic_ncols=True,
-            disable=self.disable_progress_bar,
-            file=sys.stdout
-        )
-
-        from dspy.dsp.utils.settings import thread_local_overrides
-        original_overrides = thread_local_overrides.overrides
-        for item in data:
-            with logging_redirect_tqdm():
-                if self.cancel_jobs.is_set():
-                    break
-
-                # Create an isolated context for each task by copying current overrides
-                # This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
-                thread_local_overrides.overrides = original_overrides.copy()
-
-                try:
-                    result = function(item)
-                    results.append(result)
-                finally:
-                    thread_local_overrides.overrides = original_overrides
-
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in data if r is not None]),
-                    )
-                else:
-                    self._update_progress(pbar, len(results), len(data))
+        return safe_func
 
-        pbar.close()
+    def _execute_parallel(self, function, data):
+        results = [None] * len(data)
+        job_cancelled = "cancelled"
 
-        if self.cancel_jobs.is_set():
-            logger.warning("Execution was cancelled due to errors.")
-            raise Exception("Execution was cancelled due to errors.")
+        # We resubmit at most once per item.
+        start_time_map = {}
+        start_time_lock = threading.Lock()
+        resubmitted = set()
 
-        return results
+        # This is the worker function each thread will run.
+        def worker(parent_overrides, submission_id, index, item):
+            if self.cancel_jobs.is_set():
+                return index, job_cancelled
+            # Record actual start time
+            with start_time_lock:
+                start_time_map[submission_id] = time.time()
 
-    def _update_progress(self, pbar, nresults, ntotal):
-        if self.compare_results:
-            percentage = round(100 * nresults / ntotal, 1) if ntotal > 0 else 0
-            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({percentage}%)")
-        else:
-            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
+            # Apply parent's thread-local overrides
+            from dspy.dsp.utils.settings import thread_local_overrides
 
-        pbar.update()
+            original = thread_local_overrides.overrides
+            thread_local_overrides.overrides = parent_overrides.copy()
 
-    def _execute_multi_thread(self, function, data):
-        results = [None] * len(data)  # Pre-allocate results list to maintain order
-        job_cancelled = "cancelled"
+            try:
+                return index, function(item)
+            finally:
+                thread_local_overrides.overrides = original
 
+        # Handle Ctrl-C in the main thread
         @contextlib.contextmanager
-        def interrupt_handler_manager():
-            """Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""
-
-            # TODO: Is this check conducive to nested usage of ParallelExecutor?
+        def interrupt_manager():
             if threading.current_thread() is threading.main_thread():
-                default_handler = signal.getsignal(signal.SIGINT)
+                orig_handler = signal.getsignal(signal.SIGINT)
 
-                def interrupt_handler(sig, frame):
+                def handler(sig, frame):
                     self.cancel_jobs.set()
-                    logger.warning("Received SIGINT. Cancelling execution.")
-                    # Re-raise the signal to allow default behavior
-                    default_handler(sig, frame)
+                    logger.warning("SIGINT received. Cancelling.")
+                    orig_handler(sig, frame)
 
-                signal.signal(signal.SIGINT, interrupt_handler)
+                signal.signal(signal.SIGINT, handler)
                 try:
                     yield
                 finally:
-                    signal.signal(signal.SIGINT, default_handler)
+                    signal.signal(signal.SIGINT, orig_handler)
             else:
-                # If not in the main thread, skip setting signal handlers
                 yield
 
-        def cancellable_function(parent_overrides, index_item):
-            index, item = index_item
-            if self.cancel_jobs.is_set():
-                return index, job_cancelled
+        executor = ThreadPoolExecutor(max_workers=self.num_threads)
+        try:
+            with interrupt_manager():
+                from dspy.dsp.utils.settings import thread_local_overrides
 
-            # Create an isolated context for each task by copying parent's overrides
-            from dspy.dsp.utils.settings import thread_local_overrides
-            original_overrides = thread_local_overrides.overrides
-            thread_local_overrides.overrides = parent_overrides.copy()
+                parent_overrides = thread_local_overrides.overrides.copy()
 
-            try:
-                return index, function(item)
-            finally:
-                thread_local_overrides.overrides = original_overrides
+                futures_map = {}
+                futures_set = set()
+                submission_counter = 0
 
-        with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
-            from dspy.dsp.utils.settings import thread_local_overrides
-            parent_overrides = thread_local_overrides.overrides.copy()
-
-            futures = {}
-            for pair in enumerate(data):
-                # Pass the parent thread's overrides to each thread
-                future = executor.submit(cancellable_function, parent_overrides, pair)
-                futures[future] = pair
-
-            pbar = tqdm.tqdm(
-                total=len(data),
-                dynamic_ncols=True,
-                disable=self.disable_progress_bar,
-                file=sys.stdout
-            )
-
-            for future in as_completed(futures):
-                index, result = future.result()
-
-                if result is job_cancelled:
-                    continue
-
-                results[index] = result
-
-                if self.compare_results:
-                    # Assumes score is the last element of the result tuple
-                    self._update_progress(
-                        pbar,
-                        sum([r[-1] for r in results if r is not None]),
-                        len([r for r in results if r is not None]),
+                for idx, item in enumerate(data):
+                    f = executor.submit(
+                        worker, parent_overrides, submission_counter, idx, item
                     )
-                else:
-                    self._update_progress(
-                        pbar,
-                        len([r for r in results if r is not None]),
-                        len(data),
+                    futures_map[f] = (submission_counter, idx, item)
+                    futures_set.add(f)
+                    submission_counter += 1
+
+                pbar = tqdm.tqdm(
+                    total=len(data),
+                    dynamic_ncols=True,
+                    disable=self.disable_progress_bar,
+                    file=sys.stdout,
+                )
+
+                def all_done():
+                    return all(r is not None for r in results)
+
+                while futures_set and not self.cancel_jobs.is_set():
+                    if all_done():
+                        break
+                    done, not_done = wait(
+                        futures_set, timeout=1, return_when=FIRST_COMPLETED
                     )
-
-            pbar.close()
+                    for f in done:
+                        futures_set.remove(f)
+                        try:
+                            index, outcome = f.result()
+                        except Exception:
+                            pass
+                        else:
+                            if outcome != job_cancelled and results[index] is None:
+                                results[index] = outcome
+
+                            # Update progress
+                            if self.compare_results:
+                                vals = [r[-1] for r in results if r is not None]
+                                self._update_progress(pbar, sum(vals), len(vals))
+                            else:
+                                self._update_progress(
+                                    pbar,
+                                    len([r for r in results if r is not None]),
+                                    len(data),
+                                )
+
+                    if all_done():
+                        break
+
+                    # Check stragglers if few remain
+                    if 0 < self.timeout and len(not_done) <= self.straggler_limit:
+                        now = time.time()
+                        for f in list(not_done):
+                            if f not in resubmitted:
+                                sid, idx, item = futures_map[f]
+                                with start_time_lock:
+                                    st = start_time_map.get(sid, None)
+                                if st and (now - st) >= self.timeout:
+                                    resubmitted.add(f)
+                                    nf = executor.submit(
+                                        worker,
+                                        parent_overrides,
+                                        submission_counter,
+                                        idx,
+                                        item,
+                                    )
+                                    futures_map[nf] = (submission_counter, idx, item)
+                                    futures_set.add(nf)
+                                    submission_counter += 1
+
+                pbar.close()
+
+        finally:
+            # Avoid waiting on leftover tasks that no longer matter
+            executor.shutdown(wait=False)
 
         if self.cancel_jobs.is_set():
-            logger.warning("Execution was cancelled due to errors.")
-            raise Exception("Execution was cancelled due to errors.")
+            logger.warning("Execution cancelled due to errors or interruption.")
+            raise Exception("Execution cancelled due to errors or interruption.")
 
         return results
+
+    def _update_progress(self, pbar, nresults, ntotal):
+        if self.compare_results:
+            pct = round(100 * nresults / ntotal, 1) if ntotal else 0
+            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)")
+        else:
+            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
+        pbar.update()
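
Usage sketch (not part of the patch): a minimal way to drive the patched executor, assuming this module is importable as dspy.utils.parallelizer. The square task and the item list are hypothetical, and num_threads/max_errors are passed explicitly because their defaults sit outside this hunk.

from dspy.utils.parallelizer import ParallelExecutor

def square(x):
    return x * x  # stand-in for real per-item work, e.g. evaluating one example

executor = ParallelExecutor(
    num_threads=4,
    max_errors=10,
    disable_progress_bar=True,  # keep stdout clean for this demo
    timeout=120,                # new in this patch: seconds before an in-flight item counts as a straggler
    straggler_limit=3,          # new in this patch: resubmit only once few items remain in flight
)

print(executor.execute(square, [1, 2, 3, 4, 5]))  # input order preserved: [1, 4, 9, 16, 25]

Because results is pre-allocated by index and the "results[index] is None" guard fills each slot at most once, the first attempt at an item to finish wins and any duplicate outcome from a resubmission is discarded.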
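A second hypothetical sketch exercises the straggler path itself: the first attempt at one item sleeps well past timeout, so the executor resubmits that item once and keeps whichever attempt returns first. The attempt counter and sleep durations are invented for this demo.

import time
import threading

from dspy.utils.parallelizer import ParallelExecutor

attempts = {}  # item -> number of times the worker has been invoked
attempts_lock = threading.Lock()

def sometimes_slow(item):
    with attempts_lock:
        attempts[item] = attempts.get(item, 0) + 1
        nth = attempts[item]
    if item == 3 and nth == 1:
        time.sleep(15)  # simulated straggler: only the first attempt hangs
    return item * 10

executor = ParallelExecutor(
    num_threads=4, max_errors=10, disable_progress_bar=True, timeout=5, straggler_limit=3
)
print(executor.execute(sometimes_slow, [1, 2, 3, 4]))  # -> [10, 20, 30, 40] after roughly 5-6 seconds
print(attempts[3])  # -> 2: the stuck attempt was resubmitted once past the 5-second timeout

Note that executor.shutdown(wait=False) abandons the stuck attempt rather than killing it: the worker thread keeps sleeping in the background, and since the interpreter joins surviving pool threads at exit, a script like this lingers until the 15-second sleep completes.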