Merge branch 'DOR-927_low_latency_timeout_patch' into 'master'
DOR-927 Low latency timeout patch

See merge request machine-learning/dorado!1250
kdolan1973 committed Oct 24, 2024
2 parents 2b053b6 + 56c47b1 commit 7bca166
Showing 1 changed file with 11 additions and 8 deletions.
dorado/read_pipeline/BasecallerNode.cpp: 11 additions & 8 deletions

@@ -256,6 +256,7 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
     const size_t chunk_size = m_model_runners[worker_id]->chunk_size();
     const bool is_low_latency = m_model_runners[worker_id]->is_low_latency();
     const int chunk_queue_idx = worker_id % int(m_chunk_in_queues.size());
+    auto &worker_chunks = m_batched_chunks[worker_id];
 
     const int batch_timeout_ms = (is_low_latency && m_low_latency_batch_timeout_ms != 0)
                                          ? m_low_latency_batch_timeout_ms
@@ -278,7 +279,7 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
 
         if (pop_status == utils::AsyncQueueStatus::Timeout) {
             // try_pop_until timed out without getting a new chunk.
-            if (!m_batched_chunks[worker_id].empty()) {
+            if (!worker_chunks.empty()) {
                 // get scores for whatever chunks are available.
                 basecall_current_batch(worker_id);
             }
@@ -289,7 +290,7 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
 
         // There's chunks to get_scores, so let's add them to our input tensor
         // FIXME -- it should not be possible to for this condition to be untrue.
-        if (m_batched_chunks[worker_id].size() != batch_size) {
+        if (worker_chunks.size() != batch_size) {
             // Copy the chunk into the input tensor
             auto &source_read = chunk->owning_read->read;
 
@@ -312,24 +313,26 @@ void BasecallerNode::basecall_worker_thread(int worker_id) {
             }
 
             // Insert the chunk in the input tensor
-            m_model_runners[worker_id]->accept_chunk(
-                    static_cast<int>(m_batched_chunks[worker_id].size()), input_slice);
+            m_model_runners[worker_id]->accept_chunk(static_cast<int>(worker_chunks.size()),
+                                                     input_slice);
 
-            m_batched_chunks[worker_id].push_back(std::move(chunk));
+            worker_chunks.push_back(std::move(chunk));
 
-            if (m_batched_chunks.size() == 1 || !measure_timeout_from_first_chunk) {
+            if (worker_chunks.size() == 1 || !measure_timeout_from_first_chunk) {
+                // If we're measuring the timeout from the first chunk, we only reset the timer
+                // if this is the first chunk to be added to the buffer.
                 chunk_reserve_time = std::chrono::system_clock::now();
             }
         }
 
-        if (m_batched_chunks[worker_id].size() == batch_size) {
+        if (worker_chunks.size() == batch_size) {
             // Input tensor is full, let's get_scores.
             basecall_current_batch(worker_id);
             chunk_reserve_time = std::chrono::system_clock::now();
         }
     }
 
-    if (!m_batched_chunks[worker_id].empty()) {
+    if (!worker_chunks.empty()) {
         basecall_current_batch(worker_id);
     }

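For readers skimming the diff: the patch binds a reference to this worker's chunk buffer once (worker_chunks) and uses it for every subsequent check, which also corrects the timer-reset condition that previously compared the size of the outer per-worker container (m_batched_chunks) rather than the current worker's buffer. The sketch below is a minimal illustration of that pattern under stated assumptions, not dorado code; Chunk, NUM_WORKERS, and batch_size are hypothetical stand-ins.

// Minimal sketch of the pattern applied in this commit (not dorado code).
// Chunk, NUM_WORKERS and batch_size are hypothetical stand-ins.
#include <cstddef>
#include <iostream>
#include <vector>

struct Chunk {
    int id;
};

int main() {
    constexpr std::size_t NUM_WORKERS = 4;
    constexpr std::size_t batch_size = 2;
    const std::size_t worker_id = 1;

    // One chunk buffer per worker, mirroring m_batched_chunks.
    std::vector<std::vector<Chunk>> batched_chunks(NUM_WORKERS);

    // Bind this worker's buffer once; every later check refers to it.
    auto &worker_chunks = batched_chunks[worker_id];

    for (int i = 0; i < 3; ++i) {
        worker_chunks.push_back(Chunk{i});

        // Per-worker count: this is what the timer-reset condition needs.
        // Checking batched_chunks.size() here would always yield NUM_WORKERS.
        if (worker_chunks.size() == 1) {
            std::cout << "first chunk buffered; restart the batch timeout\n";
        }
        if (worker_chunks.size() == batch_size) {
            std::cout << "batch full; dispatch for basecalling\n";
            worker_chunks.clear();
        }
    }
    return 0;
}

Binding the reference once keeps every size()/empty() check pointing at the same container, so the low-latency batch timeout can genuinely be measured from the first chunk placed in this worker's buffer.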