Skip to content

Commit

Permalink
Merge branch 'smalton/INSTX-4095-client-stalls' into 'master'
Browse files Browse the repository at this point in the history
INSTX-4095: Client stalls

See merge request machine-learning/dorado!1052
  • Loading branch information
malton-ont committed Jun 5, 2024
2 parents 3c631dc + d8953f7 commit 8e000e2
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
7 changes: 6 additions & 1 deletion dorado/read_pipeline/BasecallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,13 @@ void BasecallerNode::basecall_current_batch(int worker_id) {
m_processed_chunks.try_push(std::move(complete_chunk));
}

if (m_batched_chunks[worker_id].size() == model_runner->batch_size()) {
++m_num_batches_called;
} else {
++m_num_partial_batches_called;
}

m_batched_chunks[worker_id].clear();
++m_num_batches_called;
}

void BasecallerNode::working_reads_manager() {
Expand Down
15 changes: 13 additions & 2 deletions dorado/read_pipeline/MessageSink.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#pragma once

#include "read_pipeline/flush_options.h"
#include "read_pipeline/messages.h"
#include "ClientInfo.h"
#include "flush_options.h"
#include "messages.h"
#include "utils/AsyncQueue.h"
#include "utils/stats.h"

Expand Down Expand Up @@ -50,6 +51,8 @@ class MessageSink {
virtual void restart() = 0;

protected:
virtual bool forward_on_disconnected() const { return true; }

// Terminates waits on the input queue.
void terminate_input_queue() { m_work_queue.terminate(); }

Expand All @@ -75,6 +78,14 @@ class MessageSink {
// If terminating, returns false.
bool get_input_message(Message& message) {
auto status = m_work_queue.try_pop(message);
if (!m_sinks.empty() && forward_on_disconnected()) {
while (status == utils::AsyncQueueStatus::Success && is_read_message(message) &&
get_read_common_data(message).client_info &&
get_read_common_data(message).client_info->is_disconnected()) {
send_message_to_sink(0, std::move(message));
status = m_work_queue.try_pop(message);
}
}
return status == utils::AsyncQueueStatus::Success;
}

Expand Down
7 changes: 6 additions & 1 deletion dorado/read_pipeline/ModBaseCallerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,13 @@ void ModBaseCallerNode::call_current_batch(
m_processed_chunks.try_push(std::move(chunk));
}

if (batched_chunks.size() == m_batch_size) {
++m_num_batches_called;
} else {
++m_num_partial_batches_called;
}

batched_chunks.clear();
++m_num_batches_called;
}

void ModBaseCallerNode::output_worker_thread() {
Expand Down
5 changes: 5 additions & 0 deletions dorado/read_pipeline/SubreadTaggerNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ class SubreadTaggerNode : public MessageSink {
void terminate(const FlushOptions &) override { terminate_impl(); }
void restart() override { start_threads(); }

protected:
// Ensure this node processes reads from disconnected clients, otherwise we may not properly
// calculate the subtag ids for duplex reads that are in-flight when a disconnect occurs
bool forward_on_disconnected() const override { return false; }

private:
void start_threads();
void terminate_impl();
Expand Down

0 comments on commit 8e000e2

Please sign in to comment.