Skip to content

Commit

Permalink
fix mp_runner.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Feb 19, 2025
1 parent a7bb37a commit e994045
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions src/para_attn/distributed/mp_runner.py
Original file line number Diff line number Diff line change
Expand Up
@@ -114,11 +114,35 @@ def start(self, args=(), kwargs=None, *, timeout=None):
             process.start()
             processes.append(process)
 
-        output_queue.get(timeout=timeout)
         exceptions = []
-        for rank, exception_queue in enumerate(exception_queues):
-            if not exception_queue.empty():
-                exceptions.append((rank, exception_queue.get()))
+        begin_time = time.time()
+        while True:
+            if timeout is not None and time.time() - begin_time >= timeout:
+                raise RuntimeError("Timeout occurred")
+            for rank, (process, exception_queue) in enumerate(zip(processes, exception_queues)):
+                if process.is_alive():
+                    if exception_queue.empty():
+                        continue
+                    else:
+                        exception = exception_queue.get()
+                        if exception is None:
+                            exceptions.append(None)
+                        else:
+                            exceptions.append((rank, exception))
+                else:
+                    if exception_queue.empty():
+                        exceptions.append((rank, RuntimeError(f"Process {rank} is not alive")))
+                    else:
+                        exception = exception_queue.get()
+                        if exception is None:
+                            exceptions.append(
+                                (rank, RuntimeError(f"Process {rank} is not alive after initialization"))
+                            )
+                        else:
+                            exceptions.append((rank, exception))
+            if len(exceptions) == world_size or any(e is not None for e in exceptions):
+                break
+        exceptions = [e for e in exceptions if e is not None]
         if exceptions:
             msg = "\n".join(f"Rank {rank}: {exception}" for rank, exception in exceptions)
             raise RuntimeError(f"Exceptions occurred:\n{msg}")
Expand Down
Expand Up
@@ -245,9 +269,10 @@ def worker(
             exception = RuntimeError(f"Failed to initialize processor: {e}\n{traceback.format_exc()}")
             if exception_queue is not None:
                 exception_queue.put(exception)
-        barrier.wait()
-        if output_queue is not None:
-            output_queue.put(True)
+            raise
+
+        if exception_queue is not None:
+            exception_queue.put(None)
 
         while True:
             if input_queue is None:
Expand Down

0 comments on commit e994045

Please sign in to comment.