Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

miri: optimize zeroed alloc #136035

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use super::error::*;
use crate::errors::{LongRunning, LongRunningWarn};
use crate::fluent_generated as fluent;
use crate::interpret::{
self, AllocId, AllocRange, ConstAllocation, CtfeProvenance, FnArg, Frame, GlobalAlloc, ImmTy,
InterpCx, InterpResult, MPlaceTy, OpTy, RangeSet, Scalar, compile_time_machine, interp_ok,
throw_exhaust, throw_inval, throw_ub, throw_ub_custom, throw_unsup, throw_unsup_format,
self, AllocId, AllocInit, AllocRange, ConstAllocation, CtfeProvenance, FnArg, Frame,
GlobalAlloc, ImmTy, InterpCx, InterpResult, MPlaceTy, OpTy, RangeSet, Scalar,
compile_time_machine, interp_ok, throw_exhaust, throw_inval, throw_ub, throw_ub_custom,
throw_unsup, throw_unsup_format,
};

/// When hitting this many interpreted terminators we emit a deny by default lint
Expand Down Expand Up @@ -420,6 +421,7 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> {
Size::from_bytes(size),
align,
interpret::MemoryKind::Machine(MemoryKind::Heap),
AllocInit::Uninit,
)?;
ecx.write_pointer(ptr, dest)?;
}
Expand Down
21 changes: 14 additions & 7 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
use tracing::{debug, instrument, trace};

use super::{
AllocBytes, AllocId, AllocMap, AllocRange, Allocation, CheckAlignMsg, CheckInAllocMsg,
CtfeProvenance, GlobalAlloc, InterpCx, InterpResult, Machine, MayLeak, Misalignment, Pointer,
PointerArithmetic, Provenance, Scalar, alloc_range, err_ub, err_ub_custom, interp_ok, throw_ub,
throw_ub_custom, throw_unsup, throw_unsup_format,
AllocBytes, AllocId, AllocInit, AllocMap, AllocRange, Allocation, CheckAlignMsg,
CheckInAllocMsg, CtfeProvenance, GlobalAlloc, InterpCx, InterpResult, Machine, MayLeak,
Misalignment, Pointer, PointerArithmetic, Provenance, Scalar, alloc_range, err_ub,
err_ub_custom, interp_ok, throw_ub, throw_ub_custom, throw_unsup, throw_unsup_format,
};
use crate::fluent_generated as fluent;

Expand Down Expand Up @@ -230,11 +230,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
size: Size,
align: Align,
kind: MemoryKind<M::MemoryKind>,
init: AllocInit,
) -> InterpResult<'tcx, Pointer<M::Provenance>> {
let alloc = if M::PANIC_ON_ALLOC_FAIL {
Allocation::uninit(size, align)
Allocation::new(size, align, init)
} else {
Allocation::try_uninit(size, align)?
Allocation::try_new(size, align, init)?
};
self.insert_allocation(alloc, kind)
}
Expand Down Expand Up @@ -270,13 +271,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
M::adjust_alloc_root_pointer(self, Pointer::from(id), Some(kind))
}

/// If this grows the allocation, `init_growth` determines
/// whether the additional space will be initialized.
pub fn reallocate_ptr(
&mut self,
ptr: Pointer<Option<M::Provenance>>,
old_size_and_align: Option<(Size, Align)>,
new_size: Size,
new_align: Align,
kind: MemoryKind<M::MemoryKind>,
init_growth: AllocInit,
) -> InterpResult<'tcx, Pointer<M::Provenance>> {
let (alloc_id, offset, _prov) = self.ptr_get_alloc_id(ptr, 0)?;
if offset.bytes() != 0 {
Expand All @@ -289,7 +293,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {

// For simplicities' sake, we implement reallocate as "alloc, copy, dealloc".
// This happens so rarely, the perf advantage is outweighed by the maintenance cost.
let new_ptr = self.allocate_ptr(new_size, new_align, kind)?;
// If requested, we zero-init the entire allocation, to ensure that a growing
// allocation has its new bytes properly set. For the part that is copied,
// `mem_copy` below will de-initialize things as necessary.
let new_ptr = self.allocate_ptr(new_size, new_align, kind, init_growth)?;
let old_size = match old_size_and_align {
Some((size, _align)) => size,
None => self.get_alloc_raw(alloc_id)?.size(),
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use rustc_middle::{bug, mir, span_bug};
use tracing::{instrument, trace};

use super::{
AllocRef, AllocRefMut, CheckAlignMsg, CtfeProvenance, ImmTy, Immediate, InterpCx, InterpResult,
Machine, MemoryKind, Misalignment, OffsetMode, OpTy, Operand, Pointer, Projectable, Provenance,
Scalar, alloc_range, interp_ok, mir_assign_valid_types,
AllocInit, AllocRef, AllocRefMut, CheckAlignMsg, CtfeProvenance, ImmTy, Immediate, InterpCx,
InterpResult, Machine, MemoryKind, Misalignment, OffsetMode, OpTy, Operand, Pointer,
Projectable, Provenance, Scalar, alloc_range, interp_ok, mir_assign_valid_types,
};

#[derive(Copy, Clone, Hash, PartialEq, Eq, Debug)]
Expand Down Expand Up @@ -983,7 +983,7 @@ where
let Some((size, align)) = self.size_and_align_of(&meta, &layout)? else {
span_bug!(self.cur_span(), "cannot allocate space for `extern` type, size is not known")
};
let ptr = self.allocate_ptr(size, align, kind)?;
let ptr = self.allocate_ptr(size, align, kind, AllocInit::Uninit)?;
interp_ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout, /*unaligned*/ false))
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::ControlFlow;

use rustc_hir::def_id::LocalDefId;
use rustc_middle::mir;
use rustc_middle::mir::interpret::{Allocation, InterpResult, Pointer};
use rustc_middle::mir::interpret::{AllocInit, Allocation, InterpResult, Pointer};
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{
self, Ty, TyCtxt, TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor,
Expand Down Expand Up @@ -76,7 +76,7 @@ pub(crate) fn create_static_alloc<'tcx>(
static_def_id: LocalDefId,
layout: TyAndLayout<'tcx>,
) -> InterpResult<'tcx, MPlaceTy<'tcx>> {
let alloc = Allocation::try_uninit(layout.size, layout.align.abi)?;
let alloc = Allocation::try_new(layout.size, layout.align.abi, AllocInit::Uninit)?;
let alloc_id = ecx.tcx.reserve_and_set_static_alloc(static_def_id.into());
assert_eq!(ecx.machine.static_root_ids, None);
ecx.machine.static_root_ids = Some((alloc_id, static_def_id));
Expand Down
26 changes: 20 additions & 6 deletions compiler/rustc_middle/src/mir/interpret/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,12 @@ impl AllocRange {
}
}

/// Whether a new allocation should be initialized with zero-bytes.
pub enum AllocInit {
Uninit,
Zero,
}

// The constructors are all without extra; the extra gets added by a machine hook later.
impl<Prov: Provenance, Bytes: AllocBytes> Allocation<Prov, (), Bytes> {
/// Creates an allocation initialized by the given bytes
Expand All @@ -294,7 +300,12 @@ impl<Prov: Provenance, Bytes: AllocBytes> Allocation<Prov, (), Bytes> {
Allocation::from_bytes(slice, Align::ONE, Mutability::Not)
}

fn uninit_inner<R>(size: Size, align: Align, fail: impl FnOnce() -> R) -> Result<Self, R> {
fn new_inner<R>(
size: Size,
align: Align,
init: AllocInit,
fail: impl FnOnce() -> R,
) -> Result<Self, R> {
// We raise an error if we cannot create the allocation on the host.
// This results in an error that can happen non-deterministically, since the memory
// available to the compiler can change between runs. Normally queries are always
Expand All @@ -306,7 +317,10 @@ impl<Prov: Provenance, Bytes: AllocBytes> Allocation<Prov, (), Bytes> {
Ok(Allocation {
bytes,
provenance: ProvenanceMap::new(),
init_mask: InitMask::new(size, false),
init_mask: InitMask::new(size, match init {
AllocInit::Uninit => false,
AllocInit::Zero => true,
}),
align,
mutability: Mutability::Mut,
extra: (),
Expand All @@ -315,8 +329,8 @@ impl<Prov: Provenance, Bytes: AllocBytes> Allocation<Prov, (), Bytes> {

/// Try to create an Allocation of `size` bytes, failing if there is not enough memory
/// available to the compiler to do so.
pub fn try_uninit<'tcx>(size: Size, align: Align) -> InterpResult<'tcx, Self> {
Self::uninit_inner(size, align, || {
pub fn try_new<'tcx>(size: Size, align: Align, init: AllocInit) -> InterpResult<'tcx, Self> {
Self::new_inner(size, align, init, || {
ty::tls::with(|tcx| tcx.dcx().delayed_bug("exhausted memory during interpretation"));
InterpErrorKind::ResourceExhaustion(ResourceExhaustionInfo::MemoryExhausted)
})
Expand All @@ -328,8 +342,8 @@ impl<Prov: Provenance, Bytes: AllocBytes> Allocation<Prov, (), Bytes> {
///
/// Example use case: To obtain an Allocation filled with specific data,
/// first call this function and then call write_scalar to fill in the right data.
pub fn uninit(size: Size, align: Align) -> Self {
match Self::uninit_inner(size, align, || {
pub fn new(size: Size, align: Align, init: AllocInit) -> Self {
match Self::new_inner(size, align, init, || {
panic!(
"interpreter ran out of memory: cannot create allocation of {} bytes",
size.bytes()
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/mir/interpret/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ pub use {
};

pub use self::allocation::{
AllocBytes, AllocError, AllocRange, AllocResult, Allocation, ConstAllocation, InitChunk,
InitChunkIter, alloc_range,
AllocBytes, AllocError, AllocInit, AllocRange, AllocResult, Allocation, ConstAllocation,
InitChunk, InitChunkIter, alloc_range,
};
pub use self::error::{
BadBytesAccess, CheckAlignMsg, CheckInAllocMsg, ErrorHandled, EvalStaticInitializerRawResult,
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_middle/src/ty/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use rustc_ast::Mutability;
use rustc_macros::HashStable;
use rustc_type_ir::elaborate;

use crate::mir::interpret::{AllocId, Allocation, CTFE_ALLOC_SALT, Pointer, Scalar, alloc_range};
use crate::mir::interpret::{
AllocId, AllocInit, Allocation, CTFE_ALLOC_SALT, Pointer, Scalar, alloc_range,
};
use crate::ty::{self, Instance, PolyTraitRef, Ty, TyCtxt};

#[derive(Clone, Copy, PartialEq, HashStable)]
Expand Down Expand Up @@ -108,7 +110,7 @@ pub(super) fn vtable_allocation_provider<'tcx>(
let ptr_align = tcx.data_layout.pointer_align.abi;

let vtable_size = ptr_size * u64::try_from(vtable_entries.len()).unwrap();
let mut vtable = Allocation::uninit(vtable_size, ptr_align);
let mut vtable = Allocation::new(vtable_size, ptr_align, AllocInit::Uninit);

// No need to do any alignment checks on the memory accesses below, because we know the
// allocation is correctly aligned as we created it above. Also we're only offsetting by
Expand Down
12 changes: 8 additions & 4 deletions compiler/rustc_smir/src/rustc_smir/alloc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_abi::{Align, Size};
use rustc_middle::mir::ConstValue;
use rustc_middle::mir::interpret::{AllocRange, Pointer, alloc_range};
use rustc_middle::mir::interpret::{AllocInit, AllocRange, Pointer, alloc_range};
use stable_mir::Error;
use stable_mir::mir::Mutability;
use stable_mir::ty::{Allocation, ProvenanceMap};
Expand Down Expand Up @@ -44,7 +44,8 @@ pub(crate) fn try_new_allocation<'tcx>(
.layout_of(rustc_middle::ty::TypingEnv::fully_monomorphized().as_query_input(ty))
.map_err(|e| e.stable(tables))?
.align;
let mut allocation = rustc_middle::mir::interpret::Allocation::uninit(size, align.abi);
let mut allocation =
rustc_middle::mir::interpret::Allocation::new(size, align.abi, AllocInit::Uninit);
allocation
.write_scalar(&tables.tcx, alloc_range(Size::ZERO, size), scalar)
.map_err(|e| e.stable(tables))?;
Expand All @@ -68,8 +69,11 @@ pub(crate) fn try_new_allocation<'tcx>(
.tcx
.layout_of(rustc_middle::ty::TypingEnv::fully_monomorphized().as_query_input(ty))
.map_err(|e| e.stable(tables))?;
let mut allocation =
rustc_middle::mir::interpret::Allocation::uninit(layout.size, layout.align.abi);
let mut allocation = rustc_middle::mir::interpret::Allocation::new(
layout.size,
layout.align.abi,
AllocInit::Uninit,
);
allocation
.write_scalar(
&tables.tcx,
Expand Down
19 changes: 6 additions & 13 deletions src/tools/miri/src/shims/alloc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::iter;

use rustc_abi::{Align, Size};
use rustc_ast::expand::allocator::AllocatorKind;

Expand Down Expand Up @@ -80,18 +78,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
}

fn malloc(&mut self, size: u64, zero_init: bool) -> InterpResult<'tcx, Pointer> {
fn malloc(&mut self, size: u64, init: AllocInit) -> InterpResult<'tcx, Pointer> {
let this = self.eval_context_mut();
let align = this.malloc_align(size);
let ptr = this.allocate_ptr(Size::from_bytes(size), align, MiriMemoryKind::C.into())?;
if zero_init {
// We just allocated this, the access is definitely in-bounds and fits into our address space.
this.write_bytes_ptr(
ptr.into(),
iter::repeat(0u8).take(usize::try_from(size).unwrap()),
)
.unwrap();
}
let ptr = this.allocate_ptr(Size::from_bytes(size), align, MiriMemoryKind::C.into(), init)?;
interp_ok(ptr.into())
}

Expand All @@ -115,6 +105,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(size),
Align::from_bytes(align).unwrap(),
MiriMemoryKind::C.into(),
AllocInit::Uninit
)?;
this.write_pointer(ptr, &memptr)?;
interp_ok(Scalar::from_i32(0))
Expand All @@ -134,7 +125,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let new_align = this.malloc_align(new_size);
if this.ptr_is_null(old_ptr)? {
// Here we must behave like `malloc`.
self.malloc(new_size, /*zero_init*/ false)
self.malloc(new_size, AllocInit::Uninit)
} else {
if new_size == 0 {
// C, in their infinite wisdom, made this UB.
Expand All @@ -147,6 +138,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(new_size),
new_align,
MiriMemoryKind::C.into(),
AllocInit::Uninit
)?;
interp_ok(new_ptr.into())
}
Expand Down Expand Up @@ -187,6 +179,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(size),
Align::from_bytes(align).unwrap(),
MiriMemoryKind::C.into(),
AllocInit::Uninit
)?;
interp_ok(ptr.into())
}
Expand Down
16 changes: 6 additions & 10 deletions src/tools/miri/src/shims/foreign_items.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::hash_map::Entry;
use std::io::Write;
use std::iter;
use std::path::Path;

use rustc_abi::{Align, AlignFromBytesError, Size};
Expand All @@ -9,6 +8,7 @@ use rustc_ast::expand::allocator::alloc_error_handler_name;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::CrateNum;
use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrFlags;
use rustc_middle::mir::interpret::AllocInit;
use rustc_middle::ty::Ty;
use rustc_middle::{mir, ty};
use rustc_span::Symbol;
Expand Down Expand Up @@ -442,7 +442,7 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
let [size] = this.check_shim(abi, Conv::C, link_name, args)?;
let size = this.read_target_usize(size)?;
if size <= this.max_size_of_val().bytes() {
let res = this.malloc(size, /*zero_init:*/ false)?;
let res = this.malloc(size, AllocInit::Uninit)?;
this.write_pointer(res, dest)?;
} else {
// If this does not fit in an isize, return null and, on Unix, set errno.
Expand All @@ -457,7 +457,7 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
let items = this.read_target_usize(items)?;
let elem_size = this.read_target_usize(elem_size)?;
if let Some(size) = this.compute_size_in_bytes(Size::from_bytes(elem_size), items) {
let res = this.malloc(size.bytes(), /*zero_init:*/ true)?;
let res = this.malloc(size.bytes(), AllocInit::Zero)?;
this.write_pointer(res, dest)?;
} else {
// On size overflow, return null and, on Unix, set errno.
Expand Down Expand Up @@ -509,6 +509,7 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(size),
Align::from_bytes(align).unwrap(),
memory_kind.into(),
AllocInit::Uninit
)?;

ecx.write_pointer(ptr, dest)
Expand Down Expand Up @@ -537,14 +538,8 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(size),
Align::from_bytes(align).unwrap(),
MiriMemoryKind::Rust.into(),
AllocInit::Zero
)?;

// We just allocated this, the access is definitely in-bounds.
this.write_bytes_ptr(
ptr.into(),
iter::repeat(0u8).take(usize::try_from(size).unwrap()),
)
.unwrap();
this.write_pointer(ptr, dest)
});
}
Expand Down Expand Up @@ -604,6 +599,7 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(new_size),
align,
MiriMemoryKind::Rust.into(),
AllocInit::Uninit
)?;
this.write_pointer(new_ptr, dest)
});
Expand Down
1 change: 1 addition & 0 deletions src/tools/miri/src/shims/unix/fs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(size),
dirent_layout.align.abi,
MiriMemoryKind::Runtime.into(),
AllocInit::Uninit
)?;
let entry: Pointer = entry.into();

Expand Down
10 changes: 1 addition & 9 deletions src/tools/miri/src/shims/unix/linux/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
Size::from_bytes(new_size),
align,
MiriMemoryKind::Mmap.into(),
AllocInit::Zero
)?;
if let Some(increase) = new_size.checked_sub(old_size) {
// We just allocated this, the access is definitely in-bounds and fits into our address space.
// mmap guarantees new mappings are zero-init.
this.write_bytes_ptr(
ptr.wrapping_offset(Size::from_bytes(old_size), this).into(),
std::iter::repeat(0u8).take(usize::try_from(increase).unwrap()),
)
.unwrap();
}

interp_ok(Scalar::from_pointer(ptr, this))
}
Expand Down
Loading
Loading