Skip to content

Commit

Permalink
move finalize_metrics to generate_proof_input
Browse files Browse the repository at this point in the history
  • Loading branch information
zlangley committed Jan 6, 2025
1 parent 5205010 commit f679387
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 46 deletions.
69 changes: 43 additions & 26 deletions crates/vm/src/arch/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use std::{

use derive_more::derive::From;
use getset::Getters;
#[cfg(feature = "bench-metrics")]
use metrics::counter;
use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
use openvm_circuit_primitives::{
utils::next_power_of_two_or_zero,
Expand All @@ -32,6 +34,8 @@ use super::{
vm_poseidon2_config, ExecutionBus, InstructionExecutor, PhantomSubExecutor, Streams,
SystemConfig, SystemTraceHeights,
};
#[cfg(feature = "bench-metrics")]
use crate::metrics::VmMetrics;
use crate::system::{
connector::VmConnectorChip,
memory::{
Expand Down Expand Up @@ -718,16 +722,22 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
where
P: AnyEnum,
{
let chip = self
.inventory
.periphery
.get_mut(VmChipComplex::<F, E, P>::POSEIDON2_PERIPHERY_IDX);
let hasher: Option<&mut Poseidon2PeripheryChip<F>> = chip.map(|chip| {
chip.as_any_kind_mut()
if self.config.continuation_enabled {
let chip = self
.inventory
.periphery
.get_mut(Self::POSEIDON2_PERIPHERY_IDX)
.expect("Poseidon2 chip required for persistent memory");
let hasher: &mut Poseidon2PeripheryChip<F> = chip
.as_any_kind_mut()
.downcast_mut()
.expect("Poseidon2 chip required for persistent memory")
});
self.base.memory_controller.finalize(hasher);
.expect("Poseidon2 chip required for persistent memory");
self.base.memory_controller.finalize(Some(hasher))
} else {
self.base
.memory_controller
.finalize(None::<&mut Poseidon2PeripheryChip<F>>)
};
}

pub(crate) fn set_program(&mut self, program: Program<F>) {
Expand Down Expand Up @@ -933,39 +943,29 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
pub(crate) fn generate_proof_input<SC: StarkGenericConfig>(
mut self,
cached_program: Option<CommittedTraceData<SC>>,
#[cfg(feature = "bench-metrics")] metrics: &mut VmMetrics,
) -> ProofInput<SC>
where
Domain<SC>: PolynomialSpace<Val = F>,
E: Chip<SC>,
P: AnyEnum + Chip<SC>,
{
// System: Finalize memory.
self.finalize_memory();
#[cfg(feature = "bench-metrics")]
self.finalize_metrics(metrics);

let has_pv_chip = self.public_values_chip_idx().is_some();
// ATTENTION: The order of AIR proof input generation MUST be consistent with `airs`.
let mut builder = VmProofInputBuilder::new();
let SystemBase {
range_checker_chip,
mut memory_controller,
memory_controller,
connector_chip,
program_chip,
..
} = self.base;

// System: Finalize memory.
if self.config.continuation_enabled {
let chip = self
.inventory
.periphery
.get_mut(VmChipComplex::<F, E, P>::POSEIDON2_PERIPHERY_IDX)
.expect("Poseidon2 chip required for persistent memory");
let hasher: &mut Poseidon2PeripheryChip<F> = chip
.as_any_kind_mut()
.downcast_mut()
.expect("Poseidon2 chip required for persistent memory");
memory_controller.finalize(Some(hasher))
} else {
memory_controller.finalize(None::<&mut Poseidon2PeripheryChip<F>>)
};

// System: Program Chip
debug_assert_eq!(builder.curr_air_id, PROGRAM_AIR_ID);
builder.add_air_proof_input(program_chip.generate_air_proof_input(cached_program));
Expand Down Expand Up @@ -1025,6 +1025,23 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {

builder.build()
}

#[cfg(feature = "bench-metrics")]
fn finalize_metrics(&mut self, metrics: &mut VmMetrics)
where
E: ChipUsageGetter,
P: ChipUsageGetter,
{
counter!("total_cycles").absolute(metrics.cycle_count as u64);
counter!("main_cells_used")
.absolute(self.current_trace_cells().into_iter().sum::<usize>() as u64);

if self.config.profiling {
metrics.chip_heights =
itertools::izip!(self.air_names(), self.current_trace_heights()).collect();
metrics.emit();
}
}
}

struct VmProofInputBuilder<SC: StarkGenericConfig> {
Expand Down
8 changes: 6 additions & 2 deletions crates/vm/src/arch/segment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {

/// Generate ProofInput to prove the segment. Should be called after ::execute
pub fn generate_proof_input<SC: StarkGenericConfig>(
self,
#[allow(unused_mut)] mut self,
cached_program: Option<CommittedTraceData<SC>>,
) -> ProofInput<SC>
where
Expand All @@ -214,7 +214,11 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
VC::Periphery: Chip<SC>,
{
metrics_span("trace_gen_time_ms", || {
self.chip_complex.generate_proof_input(cached_program)
self.chip_complex.generate_proof_input(
cached_program,
#[cfg(feature = "bench-metrics")]
&mut self.metrics,
)
})
}

Expand Down
4 changes: 0 additions & 4 deletions crates/vm/src/arch/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,6 @@ where
// Used to add `segment` label to metrics
let _span = info_span!("execute_segment", segment = segments.len()).entered();
let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?;
#[cfg(feature = "bench-metrics")]
segment.finalize_metrics();
pc = state.pc;

if state.is_terminated {
Expand Down Expand Up @@ -354,8 +352,6 @@ where
segment.set_override_trace_heights(overridden_heights.clone());
}
metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?;
#[cfg(feature = "bench-metrics")]
segment.finalize_metrics();
Ok(segment)
}
}
Expand Down
14 changes: 0 additions & 14 deletions crates/vm/src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ where
self.metrics.update_current_fn(pc);
}
}

pub fn finalize_metrics(&mut self) {
self.chip_complex.finalize_memory();

counter!("total_cycles").absolute(self.metrics.cycle_count as u64);
counter!("main_cells_used")
.absolute(self.current_trace_cells().into_iter().sum::<usize>() as u64);

if self.system_config().profiling {
self.metrics.chip_heights =
itertools::izip!(self.air_names.clone(), self.current_trace_heights()).collect();
self.metrics.emit();
}
}
}

impl VmMetrics {
Expand Down

0 comments on commit f679387

Please sign in to comment.