Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Jan 31, 2025
1 parent 5ab3823 commit 4a09883
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 45 deletions.
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ sbv-primitives = { git = "https://github.com/scroll-tech/stateless-block-verifie
url = "2.5.4"

[features]
default = ["openvm"]
openvm = ["dep:sbv-utils", "dep:sbv-primitives"]

[patch.crates-io]
Expand Down
7 changes: 2 additions & 5 deletions examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl ProvingService for CloudProver {
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
async fn prove(&self, req: ProveRequest) -> ProveResponse {
async fn prove(&mut self, req: ProveRequest) -> ProveResponse {
todo!()
}
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
Expand All @@ -109,10 +109,7 @@ async fn main() -> anyhow::Result<()> {
let cfg = CloudProverConfig::from_file_and_env(args.config_file)?;
let sdk_config = cfg.sdk_config.clone();
let cloud_prover = CloudProver::new(cfg);
let prover = ProverBuilder::new(sdk_config)
.with_proving_service(Box::new(cloud_prover))
.build()
.await?;
let prover = ProverBuilder::new(sdk_config, cloud_prover).build().await?;

prover.run().await;

Expand Down
7 changes: 2 additions & 5 deletions examples/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl ProvingService for LocalProver {
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
async fn prove(&self, req: ProveRequest) -> ProveResponse {
async fn prove(&mut self, req: ProveRequest) -> ProveResponse {
todo!()
}
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
Expand All @@ -81,10 +81,7 @@ async fn main() -> anyhow::Result<()> {
let cfg = LocalProverConfig::from_file_and_env(args.config_file)?;
let sdk_config = cfg.sdk_config.clone();
let local_prover = LocalProver::new(cfg);
let prover = ProverBuilder::new(sdk_config)
.with_proving_service(Box::new(local_prover))
.build()
.await?;
let prover = ProverBuilder::new(sdk_config, local_prover).build().await?;

prover.run().await;

Expand Down
43 changes: 16 additions & 27 deletions src/prover/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use tokio::sync::RwLock;

use super::{ProofType, ProverProviderType};
use crate::{
config::Config,
Expand All @@ -12,32 +14,24 @@ use crate::{
};
use std::path::PathBuf;

pub struct ProverBuilder {
pub struct ProverBuilder<Backend: ProvingService + Send + Sync + 'static> {
cfg: Config,
proving_service: Option<Box<dyn ProvingService + Send + Sync>>,
proving_service: Backend,
}

impl ProverBuilder {
pub fn new(cfg: Config) -> Self {
impl<Backend> ProverBuilder<Backend>
where
Backend: ProvingService + Send + Sync + 'static,
{
pub fn new(cfg: Config, service: Backend) -> Self {
ProverBuilder {
cfg,
proving_service: None,
proving_service: service,
}
}

pub fn with_proving_service(
mut self,
proving_service: Box<dyn ProvingService + Send + Sync>,
) -> Self {
self.proving_service = Some(proving_service);
self
}

pub async fn build(self) -> anyhow::Result<Prover> {
if self.proving_service.is_none() {
anyhow::bail!("proving_service is not provided");
}
if self.proving_service.as_ref().unwrap().is_local() && self.cfg.prover.n_workers > 1 {
pub async fn build(self) -> anyhow::Result<Prover<Backend>> {
if self.proving_service.is_local() && self.cfg.prover.n_workers > 1 {
anyhow::bail!("cannot use multiple workers with local proving service");
}

Expand All @@ -55,17 +49,12 @@ impl ProverBuilder {
proof_types: self.cfg.prover.supported_proof_types.clone(),
circuit_version: self.cfg.prover.circuit_version.clone(),
};
let get_vk_response = self
.proving_service
.as_ref()
.unwrap()
.get_vks(get_vk_request)
.await;
let get_vk_response = self.proving_service.get_vks(get_vk_request).await;
if let Some(error) = get_vk_response.error {
anyhow::bail!("failed to get vk: {}", error);
}

let prover_provider_type = if self.proving_service.as_ref().unwrap().is_local() {
let prover_provider_type = if self.proving_service.is_local() {
ProverProviderType::Internal
} else {
ProverProviderType::External
Expand All @@ -82,7 +71,7 @@ impl ProverBuilder {

let coordinator_clients: Result<Vec<_>, _> = (0..self.cfg.prover.n_workers)
.map(|i| {
let prover_name = if self.proving_service.as_ref().unwrap().is_local() {
let prover_name = if self.proving_service.is_local() {
self.cfg.prover_name_prefix.clone()
} else {
format_cloud_prover_name(self.cfg.prover_name_prefix.clone(), i)
Expand Down Expand Up @@ -115,7 +104,7 @@ impl ProverBuilder {
circuit_version: self.cfg.prover.circuit_version,
coordinator_clients,
l2geth_client,
proving_service: self.proving_service.unwrap(),
proving_service: RwLock::new(self.proving_service),
n_workers: self.cfg.prover.n_workers,
health_listener_addr: self.cfg.health_listener_addr,
db: Db::new(&db_path)?,
Expand Down
23 changes: 17 additions & 6 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,30 @@ use ethers_providers::Middleware;
use proving_service::{ProveRequest, QueryTaskRequest, TaskStatus};
use std::net::SocketAddr;
use std::str::FromStr;
use tokio::task::JoinSet;
use tokio::time::{sleep, Duration};
use tokio::{sync::RwLock, task::JoinSet};
use tracing::{error, info, instrument};

pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};

const WORKER_SLEEP_SEC: u64 = 20;

pub struct Prover {
pub struct Prover<Backend: ProvingService + Send + Sync + 'static> {
circuit_type: CircuitType,
proof_types: Vec<ProofType>,
circuit_version: String,
coordinator_clients: Vec<CoordinatorClient>,
l2geth_client: Option<L2gethClient>,
proving_service: Box<dyn ProvingService + Send + Sync>,
proving_service: RwLock<Backend>,
n_workers: usize,
health_listener_addr: String,
db: Db,
}

impl Prover {
impl<Backend> Prover<Backend>
where
Backend: ProvingService + Send + Sync + 'static,
{
pub async fn run(self) {
assert!(self.n_workers == self.coordinator_clients.len());
if self.proof_types.contains(&ProofType::Chunk) {
Expand Down Expand Up @@ -94,7 +97,7 @@ impl Prover {
.db
.get_task(coordinator_client.key_signer.get_public_key())
{
if self.proving_service.is_local() {
if self.proving_service.read().await.is_local() {
let proving_task = self.request_proving(&coordinator_task).await?;
proving_task_id = proving_task.task_id
}
Expand Down Expand Up @@ -134,7 +137,12 @@ impl Prover {
coordinator_task: &GetTaskResponseData,
) -> anyhow::Result<proving_service::ProveResponse> {
let proving_input = self.build_proving_input(coordinator_task).await?;
let proving_task = self.proving_service.prove(proving_input).await;
let proving_task = self
.proving_service
.write()
.await
.prove(proving_input)
.await;

if let Some(error) = proving_task.error {
anyhow::bail!(
Expand Down Expand Up @@ -164,6 +172,8 @@ impl Prover {
loop {
let task = self
.proving_service
.read()
.await
.query_task(QueryTaskRequest {
task_id: proving_service_task_id.clone(),
})
Expand Down Expand Up @@ -366,6 +376,7 @@ impl Prover {
}
}

#[cfg(feature = "openvm")]
async fn build_block_witness(
&self,
hash: H256,
Expand Down
5 changes: 4 additions & 1 deletion src/prover/proving_service.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::sync::Arc;

use super::ProofType;
use async_trait::async_trait;
use tokio::sync::Mutex;

#[async_trait]
pub trait ProvingService {
fn is_local(&self) -> bool;
async fn get_vks(&self, req: GetVkRequest) -> GetVkResponse;
async fn prove(&self, req: ProveRequest) -> ProveResponse;
async fn prove(&mut self, req: ProveRequest) -> ProveResponse;
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse;
}

Expand Down

0 comments on commit 4a09883

Please sign in to comment.