diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs
index c81e83f853..289f985c38 100644
--- a/crates/vm/src/arch/extensions.rs
+++ b/crates/vm/src/arch/extensions.rs
@@ -7,6 +7,8 @@ use std::{
 
 use derive_more::derive::From;
 use getset::Getters;
+#[cfg(feature = "bench-metrics")]
+use metrics::counter;
 use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
 use openvm_circuit_primitives::{
     utils::next_power_of_two_or_zero,
@@ -32,6 +34,8 @@ use super::{
     vm_poseidon2_config, ExecutionBus, InstructionExecutor, PhantomSubExecutor, Streams,
     SystemConfig, SystemTraceHeights,
 };
+#[cfg(feature = "bench-metrics")]
+use crate::metrics::VmMetrics;
 use crate::system::{
     connector::VmConnectorChip,
     memory::{
@@ -718,16 +722,22 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
     where
         P: AnyEnum,
     {
-        let chip = self
-            .inventory
-            .periphery
-            .get_mut(VmChipComplex::<F, E, P>::POSEIDON2_PERIPHERY_IDX);
-        let hasher: Option<&mut Poseidon2PeripheryChip<F>> = chip.map(|chip| {
-            chip.as_any_kind_mut()
+        if self.config.continuation_enabled {
+            let chip = self
+                .inventory
+                .periphery
+                .get_mut(Self::POSEIDON2_PERIPHERY_IDX)
+                .expect("Poseidon2 chip required for persistent memory");
+            let hasher: &mut Poseidon2PeripheryChip<F> = chip
+                .as_any_kind_mut()
                 .downcast_mut()
-                .expect("Poseidon2 chip required for persistent memory")
-        });
-        self.base.memory_controller.finalize(hasher);
+                .expect("Poseidon2 chip required for persistent memory");
+            self.base.memory_controller.finalize(Some(hasher))
+        } else {
+            self.base
+                .memory_controller
+                .finalize(None::<&mut Poseidon2PeripheryChip<F>>)
+        };
     }
 
     pub(crate) fn set_program(&mut self, program: Program<F>) {
@@ -933,39 +943,29 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
     pub(crate) fn generate_proof_input<SC: StarkGenericConfig>(
         mut self,
         cached_program: Option<CommittedTraceData<SC>>,
+        #[cfg(feature = "bench-metrics")] metrics: &mut VmMetrics,
     ) -> ProofInput<SC>
     where
         Domain<SC>: PolynomialSpace<Val = F>,
         E: Chip<SC>,
         P: AnyEnum + Chip<SC>,
     {
+        // System: Finalize memory.
+        self.finalize_memory();
+        #[cfg(feature = "bench-metrics")]
+        self.finalize_metrics(metrics);
+
         let has_pv_chip = self.public_values_chip_idx().is_some();
         // ATTENTION: The order of AIR proof input generation MUST be consistent with `airs`.
         let mut builder = VmProofInputBuilder::new();
         let SystemBase {
             range_checker_chip,
-            mut memory_controller,
+            memory_controller,
             connector_chip,
             program_chip,
             ..
         } = self.base;
 
-        // System: Finalize memory.
-        if self.config.continuation_enabled {
-            let chip = self
-                .inventory
-                .periphery
-                .get_mut(VmChipComplex::<F, E, P>::POSEIDON2_PERIPHERY_IDX)
-                .expect("Poseidon2 chip required for persistent memory");
-            let hasher: &mut Poseidon2PeripheryChip<F> = chip
-                .as_any_kind_mut()
-                .downcast_mut()
-                .expect("Poseidon2 chip required for persistent memory");
-            memory_controller.finalize(Some(hasher))
-        } else {
-            memory_controller.finalize(None::<&mut Poseidon2PeripheryChip<F>>)
-        };
-
         // System: Program Chip
         debug_assert_eq!(builder.curr_air_id, PROGRAM_AIR_ID);
         builder.add_air_proof_input(program_chip.generate_air_proof_input(cached_program));
@@ -1025,6 +1025,23 @@ impl<F: PrimeField32, E, P> VmChipComplex<F, E, P> {
 
         builder.build()
     }
+
+    /// Publishes cycle/cell counters; when profiling, records chip heights and emits `metrics`.
+    #[cfg(feature = "bench-metrics")]
+    fn finalize_metrics(&mut self, metrics: &mut VmMetrics)
+    where
+        E: ChipUsageGetter,
+        P: ChipUsageGetter,
+    {
+        counter!("total_cycles").absolute(metrics.cycle_count as u64);
+        let cells_used: usize = self.current_trace_cells().into_iter().sum();
+        counter!("main_cells_used").absolute(cells_used as u64);
+        if self.config.profiling {
+            let per_chip = itertools::izip!(self.air_names(), self.current_trace_heights());
+            metrics.chip_heights = per_chip.collect();
+            metrics.emit();
+        }
+    }
 }
 
 struct VmProofInputBuilder<SC: StarkGenericConfig> {
diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs
index 05654f275d..a812f63499 100644
--- a/crates/vm/src/arch/segment.rs
+++ b/crates/vm/src/arch/segment.rs
@@ -205,7 +205,7 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
 
     /// Generate ProofInput to prove the segment. Should be called after ::execute
     pub fn generate_proof_input<SC: StarkGenericConfig>(
-        self,
+        #[allow(unused_mut)] mut self,
         cached_program: Option<CommittedTraceData<SC>>,
     ) -> ProofInput<SC>
     where
@@ -214,7 +214,11 @@ impl<F: PrimeField32, VC: VmConfig<F>> ExecutionSegment<F, VC> {
         VC::Periphery: Chip<SC>,
     {
         metrics_span("trace_gen_time_ms", || {
-            self.chip_complex.generate_proof_input(cached_program)
+            self.chip_complex.generate_proof_input(
+                cached_program,
+                #[cfg(feature = "bench-metrics")]
+                &mut self.metrics,
+            )
         })
     }
 
diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs
index 011d8bd96b..240e9b16ab 100644
--- a/crates/vm/src/arch/vm.rs
+++ b/crates/vm/src/arch/vm.rs
@@ -129,8 +129,6 @@ where
             // Used to add `segment` label to metrics
             let _span = info_span!("execute_segment", segment = segments.len()).entered();
             let state = metrics_span("execute_time_ms", || segment.execute_from_pc(pc))?;
-            #[cfg(feature = "bench-metrics")]
-            segment.finalize_metrics();
             pc = state.pc;
 
             if state.is_terminated {
@@ -354,8 +352,6 @@ where
             segment.set_override_trace_heights(overridden_heights.clone());
         }
         metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?;
-        #[cfg(feature = "bench-metrics")]
-        segment.finalize_metrics();
         Ok(segment)
     }
 }
diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs
index 3255cc57b4..c9815f5c32 100644
--- a/crates/vm/src/metrics/mod.rs
+++ b/crates/vm/src/metrics/mod.rs
@@ -59,20 +59,6 @@ where
             self.metrics.update_current_fn(pc);
         }
     }
-
-    pub fn finalize_metrics(&mut self) {
-        self.chip_complex.finalize_memory();
-
-        counter!("total_cycles").absolute(self.metrics.cycle_count as u64);
-        counter!("main_cells_used")
-            .absolute(self.current_trace_cells().into_iter().sum::<usize>() as u64);
-
-        if self.system_config().profiling {
-            self.metrics.chip_heights =
-                itertools::izip!(self.air_names.clone(), self.current_trace_heights()).collect();
-            self.metrics.emit();
-        }
-    }
 }
 
 impl VmMetrics {