From f6793873c29dd9602501ee9c3d1eb4f04c7061df Mon Sep 17 00:00:00 2001 From: Zach Langley Date: Mon, 6 Jan 2025 12:35:52 -0500 Subject: [PATCH] move finalize_metrics to generate_proof_input --- crates/vm/src/arch/extensions.rs | 69 ++++++++++++++++++++------------ crates/vm/src/arch/segment.rs | 8 +++- crates/vm/src/arch/vm.rs | 4 -- crates/vm/src/metrics/mod.rs | 14 ------- 4 files changed, 49 insertions(+), 46 deletions(-) diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index c81e83f85..289f985c3 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -7,6 +7,8 @@ use std::{ use derive_more::derive::From; use getset::Getters; +#[cfg(feature = "bench-metrics")] +use metrics::counter; use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; use openvm_circuit_primitives::{ utils::next_power_of_two_or_zero, @@ -32,6 +34,8 @@ use super::{ vm_poseidon2_config, ExecutionBus, InstructionExecutor, PhantomSubExecutor, Streams, SystemConfig, SystemTraceHeights, }; +#[cfg(feature = "bench-metrics")] +use crate::metrics::VmMetrics; use crate::system::{ connector::VmConnectorChip, memory::{ @@ -718,16 +722,22 @@ impl VmChipComplex { where P: AnyEnum, { - let chip = self - .inventory - .periphery - .get_mut(VmChipComplex::::POSEIDON2_PERIPHERY_IDX); - let hasher: Option<&mut Poseidon2PeripheryChip> = chip.map(|chip| { - chip.as_any_kind_mut() + if self.config.continuation_enabled { + let chip = self + .inventory + .periphery + .get_mut(Self::POSEIDON2_PERIPHERY_IDX) + .expect("Poseidon2 chip required for persistent memory"); + let hasher: &mut Poseidon2PeripheryChip = chip + .as_any_kind_mut() .downcast_mut() - .expect("Poseidon2 chip required for persistent memory") - }); - self.base.memory_controller.finalize(hasher); + .expect("Poseidon2 chip required for persistent memory"); + self.base.memory_controller.finalize(Some(hasher)) + } else { + self.base + .memory_controller + .finalize(None::<&mut Poseidon2PeripheryChip>) + }; } pub(crate) fn set_program(&mut self, program: Program) { @@ -933,39 +943,29 @@ impl VmChipComplex { pub(crate) fn generate_proof_input( mut self, cached_program: Option>, + #[cfg(feature = "bench-metrics")] metrics: &mut VmMetrics, ) -> ProofInput where Domain: PolynomialSpace, E: Chip, P: AnyEnum + Chip, { + // System: Finalize memory. + self.finalize_memory(); + #[cfg(feature = "bench-metrics")] + self.finalize_metrics(metrics); + let has_pv_chip = self.public_values_chip_idx().is_some(); // ATTENTION: The order of AIR proof input generation MUST be consistent with `airs`. let mut builder = VmProofInputBuilder::new(); let SystemBase { range_checker_chip, - mut memory_controller, + memory_controller, connector_chip, program_chip, .. } = self.base; - // System: Finalize memory. - if self.config.continuation_enabled { - let chip = self - .inventory - .periphery - .get_mut(VmChipComplex::::POSEIDON2_PERIPHERY_IDX) - .expect("Poseidon2 chip required for persistent memory"); - let hasher: &mut Poseidon2PeripheryChip = chip - .as_any_kind_mut() - .downcast_mut() - .expect("Poseidon2 chip required for persistent memory"); - memory_controller.finalize(Some(hasher)) - } else { - memory_controller.finalize(None::<&mut Poseidon2PeripheryChip>) - }; - // System: Program Chip debug_assert_eq!(builder.curr_air_id, PROGRAM_AIR_ID); builder.add_air_proof_input(program_chip.generate_air_proof_input(cached_program)); @@ -1025,6 +1025,23 @@ impl VmChipComplex { builder.build() } + + #[cfg(feature = "bench-metrics")] + fn finalize_metrics(&mut self, metrics: &mut VmMetrics) + where + E: ChipUsageGetter, + P: ChipUsageGetter, + { + counter!("total_cycles").absolute(metrics.cycle_count as u64); + counter!("main_cells_used") + .absolute(self.current_trace_cells().into_iter().sum::() as u64); + + if self.config.profiling { + metrics.chip_heights = + itertools::izip!(self.air_names(), self.current_trace_heights()).collect(); + metrics.emit(); + } + } } struct VmProofInputBuilder { diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 05654f275..a812f6349 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -205,7 +205,7 @@ impl> ExecutionSegment { /// Generate ProofInput to prove the segment. Should be called after ::execute pub fn generate_proof_input( - self, + #[allow(unused_mut)] mut self, cached_program: Option>, ) -> ProofInput where @@ -214,7 +214,11 @@ impl> ExecutionSegment { VC::Periphery: Chip, { metrics_span("trace_gen_time_ms", || { - self.chip_complex.generate_proof_input(cached_program) + self.chip_complex.generate_proof_input( + cached_program, + #[cfg(feature = "bench-metrics")] + &mut self.metrics, + ) }) } diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index 011d8bd96..240e9b16a 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -129,8 +129,6 @@ where // Used to add `segment` label to metrics let _span = info_span!("execute_segment", segment = segments.len()).entered(); let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?; - #[cfg(feature = "bench-metrics")] - segment.finalize_metrics(); pc = state.pc; if state.is_terminated { @@ -354,8 +352,6 @@ where segment.set_override_trace_heights(overridden_heights.clone()); } metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?; - #[cfg(feature = "bench-metrics")] - segment.finalize_metrics(); Ok(segment) } } diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 3255cc57b..c9815f5c3 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -59,20 +59,6 @@ where self.metrics.update_current_fn(pc); } } - - pub fn finalize_metrics(&mut self) { - self.chip_complex.finalize_memory(); - - counter!("total_cycles").absolute(self.metrics.cycle_count as u64); - counter!("main_cells_used") - .absolute(self.current_trace_cells().into_iter().sum::() as u64); - - if self.system_config().profiling { - self.metrics.chip_heights = - itertools::izip!(self.air_names.clone(), self.current_trace_heights()).collect(); - self.metrics.emit(); - } - } } impl VmMetrics {