Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interpret: adjust vtable validity check for higher-ranked types #135296

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
};
let erased_trait_ref =
ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
erased_trait_ref,
self.tcx.instantiate_bound_regions_with_erased(b)
)));
assert_eq!(
data_b.principal().map(|b| {
self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b)
}),
Some(erased_trait_ref),
);
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
Expand Down
40 changes: 1 addition & 39 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ use either::{Left, Right};
use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout};
use rustc_errors::DiagCtxtHandle;
use rustc_hir::def_id::DefId;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::infer::at::ToTrace;
use rustc_infer::traits::ObligationCause;
use rustc_middle::mir::interpret::{ErrorHandled, InvalidMetaKind, ReportedErrorInfo};
use rustc_middle::query::TyCtxtAt;
use rustc_middle::ty::layout::{
Expand All @@ -17,8 +14,7 @@ use rustc_middle::{mir, span_bug};
use rustc_session::Limit;
use rustc_span::Span;
use rustc_target::callconv::FnAbi;
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::{debug, instrument, trace};
use tracing::{debug, trace};

use super::{
Frame, FrameInfo, GlobalId, InterpErrorInfo, InterpErrorKind, InterpResult, MPlaceTy, Machine,
Expand Down Expand Up @@ -323,40 +319,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}

/// Check if the two things are equal in the current param_env, using an infcx to get proper
/// equality checks.
#[instrument(level = "trace", skip(self), ret)]
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
where
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
{
// Fast path: compare directly.
if a == b {
return true;
}
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env);
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let a = ocx.normalize(&cause, param_env, a);
let b = ocx.normalize(&cause, param_env, b);

if let Err(terr) = ocx.eq(&cause, param_env, a, b) {
trace!(?terr);
return false;
}

let errors = ocx.select_all_or_error();
if !errors.is_empty() {
trace!(?errors);
return false;
}

// All good.
true
}

/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
/// and is primarily intended for the panic machinery.
Expand Down
20 changes: 7 additions & 13 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,21 +86,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
}

// This checks whether there is a subtyping relation between the predicates in either direction.
// For example:
// - casting between `dyn for<'a> Trait<fn(&'a u8)>` and `dyn Trait<fn(&'static u8)>` is OK
// - casting between `dyn Trait<for<'a> fn(&'a u8)>` and either of the above is UB
for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
(
ty::ExistentialPredicate::Trait(a_data),
ty::ExistentialPredicate::Trait(b_data),
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
let a_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, a_pred);
let b_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b_pred);

(
ty::ExistentialPredicate::Projection(a_data),
ty::ExistentialPredicate::Projection(b_data),
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),

_ => false,
};
if !is_eq {
if a_pred != b_pred {
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
}
}
Expand Down
30 changes: 30 additions & 0 deletions src/tools/miri/tests/fail/validity/dyn-transmute-inner-binder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Test that transmuting from `&dyn Trait<fn(&'static ())>` to `&dyn Trait<for<'a> fn(&'a ())>` is UB.
//
// The vtable of `() as Trait<fn(&'static ())>` and `() as Trait<for<'a> fn(&'a ())>` can have
// different entries and, because in the former the entry for `foo` is vacant, this test will
// segfault at runtime.

trait Trait<U> {
fn foo(&self)
where
U: HigherRanked,
{
}
}
impl<T, U> Trait<U> for T {}

trait HigherRanked {}
impl HigherRanked for for<'a> fn(&'a ()) {}

// 2nd candidate is required so that selecting `(): Trait<fn(&'static ())>` will
// evaluate the candidates and fail the leak check instead of returning the
// only applicable candidate.
trait Unsatisfied {}
impl<T: Unsatisfied> HigherRanked for T {}

fn main() {
let x: &dyn Trait<fn(&'static ())> = &();
let y: &dyn Trait<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
//~^ ERROR: wrong trait in wide pointer vtable
y.foo();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
error: Undefined Behavior: constructing invalid value: wrong trait in wide pointer vtable: expected `Trait<for<'a> fn(&'a ())>`, but encountered `Trait<fn(&())>`
--> tests/fail/validity/dyn-transmute-inner-binder.rs:LL:CC
|
LL | let y: &dyn Trait<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
| ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: wrong trait in wide pointer vtable: expected `Trait<for<'a> fn(&'a ())>`, but encountered `Trait<fn(&())>`
|
= help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
= help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information
= note: BACKTRACE:
= note: inside `main` at tests/fail/validity/dyn-transmute-inner-binder.rs:LL:CC

note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace

error: aborting due to 1 previous error

30 changes: 30 additions & 0 deletions src/tools/miri/tests/pass/dyn-upcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ fn main() {
drop_principal();
modulo_binder();
modulo_assoc();
bidirectional_subtyping();
}

fn vtable_nop_cast() {
Expand Down Expand Up @@ -534,3 +535,32 @@ fn modulo_assoc() {

(&() as &dyn Trait as &dyn Middle<()>).say_hello(&0);
}

fn bidirectional_subtyping() {
// Test that transmuting between subtypes of dyn traits is fine, even in the
// "wrong direction", i.e. going from a lower-ranked to a higher-ranked dyn trait.
// Note that compared to the `dyn-transmute-inner-binder` test, the `for` is on the
// *outside* here!

trait Trait<U: ?Sized> {}
impl<T, U: ?Sized> Trait<U> for T {}

struct Wrapper<T: ?Sized>(T);

let x: &dyn Trait<fn(&'static ())> = &();
let _y: &dyn for<'a> Trait<fn(&'a ())> = unsafe { std::mem::transmute(x) };

let x: &dyn for<'a> Trait<fn(&'a ())> = &();
let _y: &dyn Trait<fn(&'static ())> = unsafe { std::mem::transmute(x) };

let x: &dyn Trait<dyn Trait<fn(&'static ())>> = &();
let _y: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = unsafe { std::mem::transmute(x) };

let x: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = &();
let _y: &dyn Trait<dyn Trait<fn(&'static ())>> = unsafe { std::mem::transmute(x) };

// This lowers to a ptr-to-ptr cast (which behaves like a transmute)
// and not an unsizing coercion:
let x: *const dyn for<'a> Trait<&'a ()> = &();
let _y: *const Wrapper<dyn Trait<&'static ()>> = x as _;
}
Loading