Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Add prover jobs as one multi-insert for Witness Generator #3587

Merged
merged 12 commits into from
Feb 11, 2025
Merged
1 change: 1 addition & 0 deletions prover/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions prover/crates/lib/prover_dal/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ sqlx = { workspace = true, features = [
"migrate",
"ipnetwork",
] }
tokio = { workspace = true, features = ["rt"] }
118 changes: 103 additions & 15 deletions prover/crates/lib/prover_dal/src/fri_prover_dal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{
time::{Duration, Instant},
};

use sqlx::QueryBuilder;
use zksync_basic_types::{
basic_fri_types::{
AggregationRound, CircuitIdRoundTuple, CircuitProverStatsEntry,
Expand All @@ -29,6 +30,12 @@ pub struct FriProverDal<'a, 'c> {
}

impl FriProverDal<'_, '_> {
// Postgres has a limit of 65535 push_bind parameters per query.
// We need to split the insert into chunks to avoid hitting this limit.
// A single row in insert_prover_jobs push_binds 10 parameters, therefore
// the limit is 65k / 10 ~ 6500 jobs chunk.
const INSERT_JOBS_CHUNK_SIZE: usize = 6500;

pub async fn insert_prover_jobs(
&mut self,
l1_batch_number: L1BatchNumber,
Expand All @@ -37,23 +44,66 @@ impl FriProverDal<'_, '_> {
depth: u16,
protocol_version_id: ProtocolSemanticVersion,
) {
let latency = MethodLatency::new("save_fri_prover_jobs");
for (sequence_number, (circuit_id, circuit_blob_url)) in
circuit_ids_and_urls.iter().enumerate()
let _latency = MethodLatency::new("save_fri_prover_jobs");
if circuit_ids_and_urls.is_empty() {
return;
}

for (chunk_index, chunk) in circuit_ids_and_urls
.chunks(Self::INSERT_JOBS_CHUNK_SIZE)
.enumerate()
{
self.insert_prover_job(
l1_batch_number,
*circuit_id,
depth,
sequence_number,
aggregation_round,
circuit_blob_url,
false,
protocol_version_id,
)
.await;
// Build multi-row INSERT for the current chunk
let mut query_builder = QueryBuilder::new(
r#"
INSERT INTO prover_jobs_fri (
l1_batch_number,
circuit_id,
circuit_blob_url,
aggregation_round,
sequence_number,
depth,
is_node_final_proof,
protocol_version,
status,
created_at,
updated_at,
protocol_version_patch
)
"#,
);

query_builder.push_values(
chunk.iter().enumerate(),
|mut row, (i, (circuit_id, circuit_blob_url))| {
row.push_bind(l1_batch_number.0 as i64)
.push_bind(*circuit_id as i16)
.push_bind(circuit_blob_url)
.push_bind(aggregation_round as i64)
.push_bind((chunk_index * Self::INSERT_JOBS_CHUNK_SIZE + i) as i64) // sequence_number
.push_bind(depth as i32)
.push_bind(false) // is_node_final_proof
.push_bind(protocol_version_id.minor as i32)
.push_bind("queued") // status
.push("NOW()") // created_at
.push("NOW()") // updated_at
.push_bind(protocol_version_id.patch.0 as i32);
},
);

// Add the ON CONFLICT clause
query_builder.push(
r#"
ON CONFLICT (l1_batch_number, aggregation_round, circuit_id, depth, sequence_number)
DO UPDATE
SET updated_at = NOW()
"#,
);

// Execute the built query
let query = query_builder.build();
query.execute(self.storage.conn()).await.unwrap();
}
drop(latency);
}

/// Retrieves the next prover job to be proven. Called by WVGs.
Expand Down Expand Up @@ -981,3 +1031,41 @@ impl FriProverDal<'_, '_> {
.collect::<_>()
}
}

#[cfg(test)]
mod tests {
use zksync_basic_types::protocol_version::L1VerifierConfig;
use zksync_db_connection::connection_pool::ConnectionPool;

use super::*;
use crate::ProverDal;

fn mock_circuit_ids_and_urls(num_circuits: usize) -> Vec<(u8, String)> {
(0..num_circuits)
.map(|i| (i as u8, format!("circuit{}", i)))
.collect()
}

#[tokio::test]
async fn test_insert_prover_jobs() {
let pool = ConnectionPool::<Prover>::prover_test_pool().await;
let mut conn = pool.connection().await.unwrap();

conn.fri_protocol_versions_dal()
.save_prover_protocol_version(
ProtocolSemanticVersion::default(),
L1VerifierConfig::default(),
)
.await;

conn.fri_prover_jobs_dal()
.insert_prover_jobs(
L1BatchNumber(1),
mock_circuit_ids_and_urls(10000),
AggregationRound::Scheduler,
1,
ProtocolSemanticVersion::default(),
)
.await;
}
}
Loading