Skip to content

Commit

Permalink
Introduce async Prover (#280)
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer authored and irakliyk committed May 9, 2024
1 parent f8eaec4 commit 652a407
Show file tree
Hide file tree
Showing 12 changed files with 411 additions and 66 deletions.
21 changes: 18 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,38 @@ jobs:
run: |
rustup update --no-self-update nightly
rustup +nightly component add clippy
cargo +nightly clippy --workspace --all-targets --all-features -- -D clippy::all -D warnings
# Specifically, don't enable `async` feature in winter-prover crate
cargo +nightly clippy --workspace --all-targets --features std,concurrent -- -D clippy::all -D warnings
- name: Rustfmt
run: |
rustup +nightly component add rustfmt
cargo +nightly fmt --all --check
# Note: the examples won't compile when the prover is built with the `async` feature, since they're designed to be sync only.
# Hence, we avoid this scenario explicitly.
check:
name: Check all features and all targets against the MSRV
name: Check all features and all targets against the MSRV, except for winter-prover and winterfell
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Perform checks
run: |
rustup update --no-self-update stable
cargo +stable install cargo-hack --locked
RUSTFLAGS=-Dwarnings cargo +stable hack --no-private --feature-powerset --keep-going check --rust-version --workspace --all-targets --verbose
RUSTFLAGS=-Dwarnings cargo +stable hack --no-private --feature-powerset --keep-going check --rust-version --verbose --all-targets --workspace --exclude winter-prover --exclude winterfell
# Check prover and winterfell alone; specifically, that the `async` feature builds correctly
check-prover:
name: Check prover and winterfell packages
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@main
- name: Perform check
run: |
rustup update --no-self-update stable
cargo +stable install cargo-hack --locked
RUSTFLAGS=-Dwarnings cargo +stable hack --package winter-prover --package winterfell --no-private --feature-powerset --keep-going check --rust-version --verbose
test:
name: Test Rust ${{matrix.toolchain}} on ${{matrix.os}}
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]
members = [
"utils/core",
"utils/maybe_async",
"utils/rand",
"math",
"crypto",
Expand Down
4 changes: 2 additions & 2 deletions air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ pub trait Air: Send + Sync {

/// A type defining shape of public inputs for the computation described by this protocol.
/// This could be any type as long as it can be serialized into a sequence of field elements.
type PublicInputs: ToElements<Self::BaseField>;
type PublicInputs: ToElements<Self::BaseField> + Send;

/// An GKR proof object. If not needed, set to `()`.
type GkrProof: Serializable + Deserializable;
type GkrProof: Serializable + Deserializable + Send;

/// A verifier for verifying GKR proofs. If not needed, set to `()`.
type GkrVerifier: GkrVerifier<GkrProof = Self::GkrProof>;
Expand Down
5 changes: 4 additions & 1 deletion prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@ name = "lagrange_kernel"
harness = false

[features]
async = ["async-trait", "maybe_async/async"]
concurrent = ["crypto/concurrent", "math/concurrent", "fri/concurrent", "utils/concurrent", "std"]
default = ["std"]
std = ["air/std", "crypto/std", "fri/std", "math/std", "utils/std"]

[dependencies]
air = { version = "0.8", path = "../air", package = "winter-air", default-features = false }
async-trait = { version = "0.1.80", optional = true }
crypto = { version = "0.8", path = "../crypto", package = "winter-crypto", default-features = false }
fri = { version = "0.8", path = '../fri', package = "winter-fri", default-features = false }
math = { version = "0.8", path = "../math", package = "winter-math", default-features = false }
tracing = { version = "0.1", default-features = false }
maybe_async = { version = "0.8", path = "../utils/maybe_async", package = "winter-maybe-async"}
tracing = { version = "0.1", default-features = false, features = ["attributes"]}
utils = { version = "0.8", path = "../utils/core", package = "winter-utils", default-features = false }

[dev-dependencies]
Expand Down
8 changes: 8 additions & 0 deletions prover/src/constraints/evaluator/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use air::{
Air, AuxRandElements, ConstraintCompositionCoefficients, EvaluationFrame, TransitionConstraints,
};
use math::FieldElement;
use tracing::instrument;
use utils::iter_mut;

#[cfg(feature = "concurrent")]
Expand Down Expand Up @@ -51,6 +52,13 @@ where
{
type Air = A;

#[instrument(
skip_all,
name = "evaluate_constraints",
fields(
ce_domain_size = %domain.ce_domain_size()
)
)]
fn evaluate<T: TraceLde<E>>(
self,
trace: &T,
Expand Down
150 changes: 90 additions & 60 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@
#[macro_use]
extern crate alloc;

#[cfg(feature = "async")]
use alloc::boxed::Box;

use air::AuxRandElements;
pub use air::{
proof, proof::Proof, Air, AirContext, Assertion, BoundaryConstraint, BoundaryConstraintGroup,
ConstraintCompositionCoefficients, ConstraintDivisor, DeepCompositionCoefficients,
EvaluationFrame, FieldExtension, LagrangeKernelRandElements, ProofOptions, TraceInfo,
TransitionConstraintDegree,
};
use tracing::{event, info_span, Level};
use maybe_async::maybe_async;
use tracing::{event, info_span, instrument, Level};
pub use utils::{
iterators, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
SliceReader,
Expand Down Expand Up @@ -128,6 +132,7 @@ pub type ProverGkrProof<P> = <<P as Prover>::Air as Air>::GkrProof;
/// of these types are provided with the prover). For example, providing custom implementations
/// of [TraceLde] and/or [ConstraintEvaluator] can be beneficial when some steps of proof
/// generation can be delegated to non-CPU hardware (e.g., GPUs).
#[maybe_async]
pub trait Prover {
/// Base field for the computation described by this prover.
type BaseField: StarkField + ExtensibleField<2> + ExtensibleField<3>;
Expand All @@ -136,16 +141,16 @@ pub trait Prover {
type Air: Air<BaseField = Self::BaseField>;

/// Execution trace of the computation described by this prover.
type Trace: Trace<BaseField = Self::BaseField>;
type Trace: Trace<BaseField = Self::BaseField> + Send + Sync;

/// Hash function to be used.
type HashFn: ElementHasher<BaseField = Self::BaseField>;

/// PRNG to be used for generating random field elements.
type RandomCoin: RandomCoin<BaseField = Self::BaseField, Hasher = Self::HashFn>;
type RandomCoin: RandomCoin<BaseField = Self::BaseField, Hasher = Self::HashFn> + Send + Sync;

/// Trace low-degree extension for building the LDEs of trace segments and their commitments.
type TraceLde<E>: TraceLde<E, HashFn = Self::HashFn>
type TraceLde<E>: TraceLde<E, HashFn = Self::HashFn> + Send + Sync
where
E: FieldElement<BaseField = Self::BaseField>;

Expand Down Expand Up @@ -175,7 +180,7 @@ pub trait Prover {
///
/// Returns a tuple containing a [TracePolyTable] with the trace polynomials for the main trace
/// and a new [TraceLde] instance from which the LDE and trace commitments can be obtained.
fn new_trace_lde<E>(
async fn new_trace_lde<E>(
&self,
trace_info: &TraceInfo,
main_trace: &ColMatrix<Self::BaseField>,
Expand All @@ -186,7 +191,7 @@ pub trait Prover {

/// Returns a new constraint evaluator which can be used to evaluate transition and boundary
/// constraints over the extended execution trace.
fn new_evaluator<'a, E>(
async fn new_evaluator<'a, E>(
&self,
air: &'a Self::Air,
aux_rand_elements: Option<AuxRandElements<E>>,
Expand All @@ -200,7 +205,7 @@ pub trait Prover {

/// Builds the GKR proof. If the [`Air`] doesn't use a GKR proof, leave unimplemented.
#[allow(unused_variables)]
fn generate_gkr_proof<E>(
async fn generate_gkr_proof<E>(
&self,
main_trace: &Self::Trace,
public_coin: &mut Self::RandomCoin,
Expand All @@ -213,7 +218,7 @@ pub trait Prover {

/// Builds and returns the auxiliary trace.
#[allow(unused_variables)]
fn build_aux_trace<E>(
async fn build_aux_trace<E>(
&self,
main_trace: &Self::Trace,
aux_rand_elements: &AuxRandElements<E>,
Expand All @@ -232,23 +237,27 @@ pub trait Prover {
/// public inputs. It may also contain a GKR proof, further documented in [`Proof`].
/// Public inputs must match the value returned from
/// [Self::get_pub_inputs()](Prover::get_pub_inputs) for the provided trace.
fn prove(&self, trace: Self::Trace) -> Result<Proof, ProverError> {
async fn prove(&self, trace: Self::Trace) -> Result<Proof, ProverError>
where
<Self::Air as Air>::PublicInputs: Send,
<Self::Air as Air>::GkrProof: Send,
{
// figure out which version of the generic proof generation procedure to run. this is a sort
// of static dispatch for selecting two generic parameter: extension field and hash
// function.
match self.options().field_extension() {
FieldExtension::None => self.generate_proof::<Self::BaseField>(trace),
FieldExtension::None => self.generate_proof::<Self::BaseField>(trace).await,
FieldExtension::Quadratic => {
if !<QuadExtension<Self::BaseField>>::is_supported() {
return Err(ProverError::UnsupportedFieldExtension(2));
}
self.generate_proof::<QuadExtension<Self::BaseField>>(trace)
self.generate_proof::<QuadExtension<Self::BaseField>>(trace).await
}
FieldExtension::Cubic => {
if !<CubeExtension<Self::BaseField>>::is_supported() {
return Err(ProverError::UnsupportedFieldExtension(3));
}
self.generate_proof::<CubeExtension<Self::BaseField>>(trace)
self.generate_proof::<CubeExtension<Self::BaseField>>(trace).await
}
}
}
Expand All @@ -260,9 +269,11 @@ pub trait Prover {
/// execution `trace` is valid against this prover's AIR.
/// TODO: make this function un-callable externally?
#[doc(hidden)]
fn generate_proof<E>(&self, trace: Self::Trace) -> Result<Proof, ProverError>
async fn generate_proof<E>(&self, trace: Self::Trace) -> Result<Proof, ProverError>
where
E: FieldElement<BaseField = Self::BaseField>,
<Self::Air as Air>::PublicInputs: Send,
<Self::Air as Air>::GkrProof: Send,
{
// 0 ----- instantiate AIR and prover channel ---------------------------------------------

Expand Down Expand Up @@ -294,31 +305,16 @@ pub trait Prover {
assert_eq!(domain.trace_length(), trace_length);

// commit to the main trace segment
let (mut trace_lde, mut trace_polys) = {
// extend the main execution trace and build a Merkle tree from the extended trace
let span = info_span!("commit_to_main_trace_segment").entered();
let (trace_lde, trace_polys) =
self.new_trace_lde(trace.info(), trace.main_segment(), &domain);

// get the commitment to the main trace segment LDE
let main_trace_root = trace_lde.get_main_trace_commitment();

// commit to the LDE of the main trace by writing the root of its Merkle tree into
// the channel
channel.commit_trace(main_trace_root);

drop(span);

(trace_lde, trace_polys)
};
let (mut trace_lde, mut trace_polys) =
self.commit_to_main_trace_segment(&trace, &domain, &mut channel).await;

// build the auxiliary trace segment, and append the resulting segments to trace commitment
// and trace polynomial table structs
let aux_trace_with_metadata = if air.trace_info().is_multi_segment() {
let (gkr_proof, lagrange_rand_elements) =
if air.context().has_lagrange_kernel_aux_column() {
let (gkr_proof, lagrange_rand_elements) =
self.generate_gkr_proof(&trace, channel.public_coin());
self.generate_gkr_proof(&trace, channel.public_coin()).await;

(Some(gkr_proof), Some(lagrange_rand_elements))
} else {
Expand All @@ -333,7 +329,7 @@ pub trait Prover {
AuxRandElements::new_with_lagrange(rand_elements, lagrange_rand_elements)
};

let aux_trace = self.build_aux_trace(&trace, &aux_rand_elements);
let aux_trace = self.build_aux_trace(&trace, &aux_rand_elements).await;

// commit to the auxiliary trace segment
let aux_segment_polys = {
Expand Down Expand Up @@ -384,36 +380,16 @@ pub trait Prover {
// compute random linear combinations of these evaluations using coefficients drawn from
// the channel
let ce_domain_size = air.ce_domain_size();
let composition_poly_trace =
info_span!("evaluate_constraints", ce_domain_size).in_scope(|| {
self.new_evaluator(
&air,
aux_rand_elements,
channel.get_constraint_composition_coeffs(),
)
.evaluate(&trace_lde, &domain)
});
let composition_poly_trace = self
.new_evaluator(&air, aux_rand_elements, channel.get_constraint_composition_coeffs())
.await
.evaluate(&trace_lde, &domain);
assert_eq!(composition_poly_trace.num_rows(), ce_domain_size);

// 3 ----- commit to constraint evaluations -----------------------------------------------
let (constraint_commitment, composition_poly) = {
let span = info_span!("commit_to_constraint_evaluations").entered();

// first, build a commitment to the evaluations of the constraint composition
// polynomial columns
let (constraint_commitment, composition_poly) = self.build_constraint_commitment::<E>(
composition_poly_trace,
air.context().num_constraint_composition_columns(),
&domain,
);

// then, commit to the evaluations of constraints by writing the root of the constraint
// Merkle tree into the channel
channel.commit_constraints(constraint_commitment.root());

drop(span);
(constraint_commitment, composition_poly)
};
let (constraint_commitment, composition_poly) = self
.commit_to_constraint_evaluations(&air, composition_poly_trace, &domain, &mut channel)
.await;

// 4 ----- build DEEP composition polynomial ----------------------------------------------
let deep_composition_poly = {
Expand Down Expand Up @@ -540,7 +516,7 @@ pub trait Prover {
///
/// The commitment is computed by hashing each row in the evaluation matrix, and then building
/// a Merkle tree from the resulting hashes.
fn build_constraint_commitment<E>(
async fn build_constraint_commitment<E>(
&self,
composition_poly_trace: CompositionPolyTrace<E>,
num_constraint_composition_columns: usize,
Expand Down Expand Up @@ -584,4 +560,58 @@ pub trait Prover {

(constraint_commitment, composition_poly)
}

#[doc(hidden)]
#[instrument(skip_all)]
async fn commit_to_main_trace_segment<E>(
&self,
trace: &Self::Trace,
domain: &StarkDomain<Self::BaseField>,
channel: &mut ProverChannel<Self::Air, E, Self::HashFn, Self::RandomCoin>,
) -> (Self::TraceLde<E>, TracePolyTable<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
// extend the main execution trace and build a Merkle tree from the extended trace
let (trace_lde, trace_polys) =
self.new_trace_lde(trace.info(), trace.main_segment(), domain).await;

// get the commitment to the main trace segment LDE
let main_trace_root = trace_lde.get_main_trace_commitment();

// commit to the LDE of the main trace by writing the root of its Merkle tree into
// the channel
channel.commit_trace(main_trace_root);

(trace_lde, trace_polys)
}

#[doc(hidden)]
#[instrument(skip_all)]
async fn commit_to_constraint_evaluations<E>(
&self,
air: &Self::Air,
composition_poly_trace: CompositionPolyTrace<E>,
domain: &StarkDomain<Self::BaseField>,
channel: &mut ProverChannel<Self::Air, E, Self::HashFn, Self::RandomCoin>,
) -> (ConstraintCommitment<E, Self::HashFn>, CompositionPoly<E>)
where
E: FieldElement<BaseField = Self::BaseField>,
{
// first, build a commitment to the evaluations of the constraint composition polynomial
// columns
let (constraint_commitment, composition_poly) = self
.build_constraint_commitment::<E>(
composition_poly_trace,
air.context().num_constraint_composition_columns(),
domain,
)
.await;

// then, commit to the evaluations of constraints by writing the root of the constraint
// Merkle tree into the channel
channel.commit_constraints(constraint_commitment.root());

(constraint_commitment, composition_poly)
}
}
Loading

0 comments on commit 652a407

Please sign in to comment.