
Commit

Add timeout-based straggler resubmission in ParallelExecutor (stanfor…
okhat authored Mar 6, 2025
1 parent 5660da7 commit a2fc6a1
Showing 1 changed file with 144 additions and 137 deletions.
281 changes: 144 additions & 137 deletions dspy/utils/parallelizer.py
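Before the diff, a minimal usage sketch of the executor this commit touches. The sketch is not part of the commit; it assumes ParallelExecutor can be imported from dspy.utils.parallelizer (the changed file) and that the constructor also accepts num_threads and max_errors, which the __init__ body below assigns but whose declarations fall outside the hunk shown.

# Hypothetical usage sketch; not part of this commit.
from dspy.utils.parallelizer import ParallelExecutor

def square(x):
    return x * x

executor = ParallelExecutor(
    num_threads=8,
    max_errors=10,      # cancel all remaining work once this many tasks have raised
    timeout=120,        # seconds before a lingering task is treated as a straggler
    straggler_limit=3,  # only resubmit when at most this many tasks are still pending
)

# One result per input item, in input order; items that errored come back as None.
results = executor.execute(square, list(range(100)))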
@@ -4,9 +4,10 @@
import logging
import threading
import traceback
import time
import contextlib

from concurrent.futures import ThreadPoolExecutor, wait, FIRST_COMPLETED

logger = logging.getLogger(__name__)

@@ -19,189 +20,195 @@ def __init__(
        disable_progress_bar=False,
        provide_traceback=False,
        compare_results=False,
        timeout=120,
        straggler_limit=3,
    ):
        """
        Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1.
        Also handles straggler timeouts.
        """

        self.num_threads = num_threads
        self.max_errors = max_errors
        self.disable_progress_bar = disable_progress_bar
        self.provide_traceback = provide_traceback
        self.compare_results = compare_results
        self.timeout = timeout
        self.straggler_limit = straggler_limit

        self.error_count = 0
        self.error_lock = threading.Lock()
        self.cancel_jobs = threading.Event()

    def execute(self, function, data):
        wrapped = self._wrap_function(function)
        return self._execute_parallel(wrapped, data)

    def _wrap_function(self, user_function):
        def safe_func(item):
            if self.cancel_jobs.is_set():
                return None
            try:
                return user_function(item)
            except Exception as e:
                with self.error_lock:
                    self.error_count += 1
                    if self.error_count >= self.max_errors:
                        self.cancel_jobs.set()
                if self.provide_traceback:
                    logger.error(f"Error for {item}: {e}\n{traceback.format_exc()}")
                else:
                    logger.error(
                        f"Error for {item}: {e}. "
                        "Set `provide_traceback=True` for traceback."
                    )
                return None

        return safe_func

    def _execute_parallel(self, function, data):
        results = [None] * len(data)
        job_cancelled = "cancelled"

        # We resubmit at most once per item.
        start_time_map = {}
        start_time_lock = threading.Lock()
        resubmitted = set()

        # This is the worker function each thread will run.
        def worker(parent_overrides, submission_id, index, item):
            if self.cancel_jobs.is_set():
                return index, job_cancelled
            # Record the actual start time.
            with start_time_lock:
                start_time_map[submission_id] = time.time()

            # Apply the parent thread's thread-local overrides.
            from dspy.dsp.utils.settings import thread_local_overrides

            original = thread_local_overrides.overrides
            thread_local_overrides.overrides = parent_overrides.copy()

            try:
                return index, function(item)
            finally:
                thread_local_overrides.overrides = original

        # Handle Ctrl-C in the main thread.
        @contextlib.contextmanager
        def interrupt_manager():
            if threading.current_thread() is threading.main_thread():
                orig_handler = signal.getsignal(signal.SIGINT)

                def handler(sig, frame):
                    self.cancel_jobs.set()
                    logger.warning("SIGINT received. Cancelling.")
                    orig_handler(sig, frame)

                signal.signal(signal.SIGINT, handler)
                try:
                    yield
                finally:
                    signal.signal(signal.SIGINT, orig_handler)
            else:
                # If not in the main thread, skip installing a signal handler.
                yield

        executor = ThreadPoolExecutor(max_workers=self.num_threads)
        try:
            with interrupt_manager():
                from dspy.dsp.utils.settings import thread_local_overrides

                parent_overrides = thread_local_overrides.overrides.copy()

                futures_map = {}
                futures_set = set()
                submission_counter = 0

                for idx, item in enumerate(data):
                    f = executor.submit(
                        worker, parent_overrides, submission_counter, idx, item
                    )
                    futures_map[f] = (submission_counter, idx, item)
                    futures_set.add(f)
                    submission_counter += 1

                pbar = tqdm.tqdm(
                    total=len(data),
                    dynamic_ncols=True,
                    disable=self.disable_progress_bar,
                    file=sys.stdout,
                )

                def all_done():
                    return all(r is not None for r in results)

                while futures_set and not self.cancel_jobs.is_set():
                    if all_done():
                        break
                    done, not_done = wait(
                        futures_set, timeout=1, return_when=FIRST_COMPLETED
                    )

                    for f in done:
                        futures_set.remove(f)
                        try:
                            index, outcome = f.result()
                        except Exception:
                            pass
                        else:
                            if outcome != job_cancelled and results[index] is None:
                                results[index] = outcome

                        # Update progress.
                        if self.compare_results:
                            # Assumes the score is the last element of the result tuple.
                            vals = [r[-1] for r in results if r is not None]
                            self._update_progress(pbar, sum(vals), len(vals))
                        else:
                            self._update_progress(
                                pbar,
                                len([r for r in results if r is not None]),
                                len(data),
                            )

                    if all_done():
                        break

                    # Check for stragglers if only a few tasks remain.
                    if 0 < self.timeout and len(not_done) <= self.straggler_limit:
                        now = time.time()
                        for f in list(not_done):
                            if f not in resubmitted:
                                sid, idx, item = futures_map[f]
                                with start_time_lock:
                                    st = start_time_map.get(sid, None)
                                if st and (now - st) >= self.timeout:
                                    resubmitted.add(f)
                                    nf = executor.submit(
                                        worker,
                                        parent_overrides,
                                        submission_counter,
                                        idx,
                                        item,
                                    )
                                    futures_map[nf] = (submission_counter, idx, item)
                                    futures_set.add(nf)
                                    submission_counter += 1

                pbar.close()

        finally:
            # Avoid waiting on leftover tasks that no longer matter.
            executor.shutdown(wait=False)

        if self.cancel_jobs.is_set():
            logger.warning("Execution cancelled due to errors or interruption.")
            raise Exception("Execution cancelled due to errors or interruption.")

        return results

    def _update_progress(self, pbar, nresults, ntotal):
        if self.compare_results:
            pct = round(100 * nresults / ntotal, 1) if ntotal else 0
            pbar.set_description(f"Average Metric: {nresults:.2f} / {ntotal} ({pct}%)")
        else:
            pbar.set_description(f"Processed {nresults} / {ntotal} examples")
        pbar.update()
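A note on the design: the loop polls with wait(futures_set, timeout=1, return_when=FIRST_COMPLETED) rather than blocking on as_completed, so it can periodically re-inspect the futures that have not finished. Resubmission only happens when timeout is positive and at most straggler_limit tasks remain, and each original future is resubmitted at most once (tracked in the resubmitted set). Because the write is guarded by results[index] is None, whichever copy of a resubmitted task finishes first supplies the result and the duplicate's outcome is ignored.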
