Skip to content

Commit

Permalink
feat: new device proving key API (#1737)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamirhemo authored Nov 5, 2024
1 parent 97bb83d commit da8a772
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 34 deletions.
9 changes: 7 additions & 2 deletions crates/core/machine/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,8 @@ where
{
let all_records = all_records_rx.iter().flatten().collect::<Vec<_>>();
let mut challenger = prover.machine().config().challenger();
prover.machine().debug_constraints(&pk.to_host(), all_records, &mut challenger);
let pk_host = prover.pk_to_host(pk);
prover.machine().debug_constraints(&pk_host, all_records, &mut challenger);
}

Ok((proof, public_values_stream, cycles))
Expand Down Expand Up @@ -798,7 +799,11 @@ where
let prove_span = tracing::debug_span!("prove").entered();

#[cfg(feature = "debug")]
prover.machine().debug_constraints(&pk.to_host(), records.clone(), &mut challenger.clone());
prover.machine().debug_constraints(
&prover.pk_to_host(&pk),
records.clone(),
&mut challenger.clone(),
);

let proof = prover.prove(&pk, records, &mut challenger, SP1CoreOpts::default()).unwrap();
prove_span.exit();
Expand Down
33 changes: 17 additions & 16 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,11 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
let program = self.get_program(elf).unwrap();
let (pk, vk) = self.core_prover.setup(&program);
let vk = SP1VerifyingKey { vk };
let pk = SP1ProvingKey { pk: pk.to_host(), elf: elf.to_vec(), vk: vk.clone() };
let pk = SP1ProvingKey {
pk: self.core_prover.pk_to_host(&pk),
elf: elf.to_vec(),
vk: vk.clone(),
};
(pk, vk)
}

Expand Down Expand Up @@ -292,20 +296,17 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
) -> Result<SP1CoreProof, SP1CoreProverError> {
context.subproof_verifier.replace(Arc::new(self));
let program = self.get_program(&pk.elf).unwrap();
let (proof, public_values_stream, cycles) = sp1_core_machine::utils::prove_with_context::<
_,
C::CoreProver,
>(
&self.core_prover,
&<C::CoreProver as MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>::DeviceProvingKey::from_host(
&pk.pk,
),
program,
stdin,
opts.core_opts,
context,
self.core_shape_config.as_ref(),
)?;
let pk = self.core_prover.pk_to_device(&pk.pk);
let (proof, public_values_stream, cycles) =
sp1_core_machine::utils::prove_with_context::<_, C::CoreProver>(
&self.core_prover,
&pk,
program,
stdin,
opts.core_opts,
context,
self.core_shape_config.as_ref(),
)?;
Self::check_for_high_cycles(cycles);
let public_values = SP1PublicValues::from(&public_values_stream);
Ok(SP1CoreProof {
Expand Down Expand Up @@ -819,7 +820,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {

#[cfg(feature = "debug")]
self.compress_prover.debug_constraints(
&pk.to_host(),
&self.compress_prover.pk_to_host(&pk),
vec![record.clone()],
&mut challenger.clone(),
);
Expand Down
3 changes: 1 addition & 2 deletions crates/recursion/circuit/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ pub(crate) mod tests {
use sp1_recursion_core::{machine::RecursionAir, Runtime};
use sp1_stark::{
baby_bear_poseidon2::BabyBearPoseidon2, CpuProver, InnerChallenge, InnerVal, MachineProver,
MachineProvingKey,
};

use crate::witness::WitnessBlock;
Expand Down Expand Up @@ -145,8 +144,8 @@ pub(crate) mod tests {
let proof_wide_span = tracing::debug_span!("Run test with wide machine").entered();
let wide_machine = RecursionAir::<_, 3>::compress_machine(SC::default());
let (pk, vk) = wide_machine.setup(&program);
let pk = P::DeviceProvingKey::from_host(&pk);
let prover = P::new(wide_machine);
let pk = prover.pk_to_device(&pk);
let result = run_test_machine_with_prover::<_, _, P>(&prover, records.clone(), pk, vk);
proof_wide_span.exit();

Expand Down
28 changes: 14 additions & 14 deletions crates/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ pub trait MachineProver<SC: StarkGenericConfig, A: MachineAir<SC::Val>>:
/// Setup the preprocessed data into a proving and verifying key.
fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>);

/// Copy the proving key from the host to the device.
fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey;

/// Copy the proving key from the device to the host.
fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC>;

/// Generate the main traces.
fn generate_traces(
&self,
Expand Down Expand Up @@ -273,12 +279,6 @@ pub trait MachineProvingKey<SC: StarkGenericConfig>: Send + Sync {
/// The start pc.
fn pc_start(&self) -> Val<SC>;

/// The proving key on the host.
fn to_host(&self) -> StarkProvingKey<SC>;

/// The proving key on the device.
fn from_host(host: &StarkProvingKey<SC>) -> Self;

/// Observe itself in the challenger.
fn observe_into(&self, challenger: &mut Challenger<SC>);
}
Expand Down Expand Up @@ -323,6 +323,14 @@ where
self.machine().setup(program)
}

fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey {
pk.clone()
}

fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC> {
pk.clone()
}

fn commit(
&self,
record: &A::Record,
Expand Down Expand Up @@ -889,14 +897,6 @@ where
self.pc_start
}

fn to_host(&self) -> StarkProvingKey<SC> {
self.clone()
}

fn from_host(host: &StarkProvingKey<SC>) -> Self {
host.clone()
}

fn observe_into(&self, challenger: &mut Challenger<SC>) {
challenger.observe(self.commit.clone());
challenger.observe(self.pc_start);
Expand Down

0 comments on commit da8a772

Please sign in to comment.