Skip to content

Commit

Permalink
Auto merge of rust-lang#128776 - Bryanskiy:deep-reject-ctxt, r=lcnr
Browse files Browse the repository at this point in the history
Use `DeepRejectCtxt` to quickly reject `ParamEnv` candidates

The description is on the [zulip thread](https://rust-lang.zulipchat.com/#narrow/stream/144729-t-types/topic/.5Basking.20for.20help.5D.20.60DeepRejectCtxt.60.20for.20param.20env.20candidates)

r? `@lcnr`
  • Loading branch information
bors committed Sep 6, 2024
2 parents 17b322f + c51953f commit 26b5599
Show file tree
Hide file tree
Showing 16 changed files with 272 additions and 152 deletions.
5 changes: 3 additions & 2 deletions compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ impl<'tcx> InherentCollect<'tcx> {
}
}

if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) {
if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer)
{
self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id);
} else {
bug!("unexpected self type: {:?}", self_ty);
Expand Down Expand Up @@ -132,7 +133,7 @@ impl<'tcx> InherentCollect<'tcx> {
}
}

if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::AsCandidateKey) {
if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::InstantiateWithInfer) {
self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id);
} else {
bug!("unexpected primitive type: {:?}", ty);
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/method/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
}

fn assemble_inherent_candidates_for_incoherent_ty(&mut self, self_ty: Ty<'tcx>) {
let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) else {
let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer) else {
bug!("unexpected incoherent type: {:?}", self_ty)
};
for &impl_def_id in self.tcx.incoherent_impls(simp).into_iter().flatten() {
Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_hir_typeck/src/method/suggest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2235,8 +2235,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let target_ty = self
.autoderef(sugg_span, rcvr_ty)
.find(|(rcvr_ty, _)| {
DeepRejectCtxt::new(self.tcx, TreatParams::ForLookup)
.types_may_unify(*rcvr_ty, impl_ty)
DeepRejectCtxt::relate_rigid_infer(self.tcx).types_may_unify(*rcvr_ty, impl_ty)
})
.map_or(impl_ty, |(ty, _)| ty)
.peel_refs();
Expand Down Expand Up @@ -2498,7 +2497,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
.into_iter()
.any(|info| self.associated_value(info.def_id, item_name).is_some());
let found_assoc = |ty: Ty<'tcx>| {
simplify_type(tcx, ty, TreatParams::AsCandidateKey)
simplify_type(tcx, ty, TreatParams::InstantiateWithInfer)
.and_then(|simp| {
tcx.incoherent_impls(simp)
.into_iter()
Expand Down Expand Up @@ -3963,7 +3962,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// cases where a positive bound implies a negative impl.
(candidates, Vec::new())
} else if let Some(simp_rcvr_ty) =
simplify_type(self.tcx, rcvr_ty, TreatParams::ForLookup)
simplify_type(self.tcx, rcvr_ty, TreatParams::AsRigid)
{
let mut potential_candidates = Vec::new();
let mut explicitly_negative = Vec::new();
Expand All @@ -3981,7 +3980,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
.any(|header| {
let imp = header.trait_ref.instantiate_identity();
let imp_simp =
simplify_type(self.tcx, imp.self_ty(), TreatParams::ForLookup);
simplify_type(self.tcx, imp.self_ty(), TreatParams::AsRigid);
imp_simp.is_some_and(|s| s == simp_rcvr_ty)
})
{
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_metadata/src/rmeta/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2033,7 +2033,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> {
let simplified_self_ty = fast_reject::simplify_type(
self.tcx,
trait_ref.self_ty(),
TreatParams::AsCandidateKey,
TreatParams::InstantiateWithInfer,
);
trait_impls
.entry(trait_ref.def_id)
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
let simp = ty::fast_reject::simplify_type(
tcx,
self_ty,
ty::fast_reject::TreatParams::ForLookup,
ty::fast_reject::TreatParams::AsRigid,
)
.unwrap();
consider_impls_for_simplified_type(simp);
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_middle/src/ty/fast_reject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ pub use rustc_type_ir::fast_reject::*;

use super::TyCtxt;

pub type DeepRejectCtxt<'tcx> = rustc_type_ir::fast_reject::DeepRejectCtxt<TyCtxt<'tcx>>;
pub type DeepRejectCtxt<
'tcx,
const INSTANTIATE_LHS_WITH_INFER: bool,
const INSTANTIATE_RHS_WITH_INFER: bool,
> = rustc_type_ir::fast_reject::DeepRejectCtxt<
TyCtxt<'tcx>,
INSTANTIATE_LHS_WITH_INFER,
INSTANTIATE_RHS_WITH_INFER,
>;

pub type SimplifiedType = rustc_type_ir::fast_reject::SimplifiedType<DefId>;
12 changes: 7 additions & 5 deletions compiler/rustc_middle/src/ty/trait_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ impl<'tcx> TyCtxt<'tcx> {
// whose outer level is not a parameter or projection. Especially for things like
// `T: Clone` this is incredibly useful as we would otherwise look at all the impls
// of `Clone` for `Option<T>`, `Vec<T>`, `ConcreteType` and so on.
// Note that we're using `TreatParams::ForLookup` to query `non_blanket_impls` while using
// `TreatParams::AsCandidateKey` while actually adding them.
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::ForLookup) {
// Note that we're using `TreatParams::AsRigid` to query `non_blanket_impls` while using
// `TreatParams::InstantiateWithInfer` while actually adding them.
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsRigid) {
if let Some(impls) = impls.non_blanket_impls.get(&simp) {
for &impl_def_id in impls {
f(impl_def_id);
Expand All @@ -190,7 +190,9 @@ impl<'tcx> TyCtxt<'tcx> {
self_ty: Ty<'tcx>,
) -> impl Iterator<Item = DefId> + 'tcx {
let impls = self.trait_impls_of(trait_def_id);
if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsCandidateKey) {
if let Some(simp) =
fast_reject::simplify_type(self, self_ty, TreatParams::InstantiateWithInfer)
{
if let Some(impls) = impls.non_blanket_impls.get(&simp) {
return impls.iter().copied();
}
Expand Down Expand Up @@ -239,7 +241,7 @@ pub(super) fn trait_impls_of_provider(tcx: TyCtxt<'_>, trait_id: DefId) -> Trait
let impl_self_ty = tcx.type_of(impl_def_id).instantiate_identity();

if let Some(simplified_self_ty) =
fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::InstantiateWithInfer)
{
impls.non_blanket_impls.entry(simplified_self_ty).or_default().push(impl_def_id);
} else {
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod inherent;
mod opaque_types;
mod weak_types;

use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::fast_reject::DeepRejectCtxt;
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::{self as ty, Interner, NormalizesTo, Upcast as _};
Expand Down Expand Up @@ -106,6 +106,12 @@ where
if let Some(projection_pred) = assumption.as_projection_clause() {
if projection_pred.projection_def_id() == goal.predicate.def_id() {
let cx = ecx.cx();
if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify(
goal.predicate.alias.args,
projection_pred.skip_binder().projection_term.args,
) {
return Err(NoSolution);
}
ecx.probe_trait_candidate(source).enter(|ecx| {
let assumption_projection_pred =
ecx.instantiate_binder_with_infer(projection_pred);
Expand Down Expand Up @@ -144,7 +150,7 @@ where

let goal_trait_ref = goal.predicate.alias.trait_ref(cx);
let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup).args_may_unify(
if !DeepRejectCtxt::relate_rigid_infer(ecx.cx()).args_may_unify(
goal.predicate.alias.trait_ref(cx).args,
impl_trait_ref.skip_binder().args,
) {
Expand Down
11 changes: 9 additions & 2 deletions compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use rustc_ast_ir::Movability;
use rustc_type_ir::data_structures::IndexSet;
use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::fast_reject::DeepRejectCtxt;
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _;
Expand Down Expand Up @@ -47,7 +47,7 @@ where
let cx = ecx.cx();

let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup)
if !DeepRejectCtxt::relate_rigid_infer(ecx.cx())
.args_may_unify(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args)
{
return Err(NoSolution);
Expand Down Expand Up @@ -124,6 +124,13 @@ where
if trait_clause.def_id() == goal.predicate.def_id()
&& trait_clause.polarity() == goal.predicate.polarity
{
if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify(
goal.predicate.trait_ref.args,
trait_clause.skip_binder().trait_ref.args,
) {
return Err(NoSolution);
}

ecx.probe_trait_candidate(source).enter(|ecx| {
let assumption_trait_pred = ecx.instantiate_binder_with_infer(trait_clause);
ecx.eq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::print::{FmtPrinter, Printer};
use rustc_middle::ty::{self, suggest_constraining_type_param, Ty};
use rustc_span::def_id::DefId;
Expand Down Expand Up @@ -317,7 +317,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand All @@ -338,7 +338,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.expected, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.expected, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand All @@ -358,7 +358,7 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(tcx, TreatParams::ForLookup)
if DeepRejectCtxt::relate_rigid_infer(tcx)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_middle::bug;
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::solve::{CandidateSource, Certainty, Goal};
use rustc_middle::traits::specialization_graph::OverlapMode;
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor};
use rustc_middle::ty::{self, Ty, TyCtxt};
pub use rustc_next_trait_solver::coherence::*;
Expand Down Expand Up @@ -96,7 +96,7 @@ pub fn overlapping_impls(
// Before doing expensive operations like entering an inference context, do
// a quick check via fast_reject to tell if the impl headers could possibly
// unify.
let drcx = DeepRejectCtxt::new(tcx, TreatParams::AsCandidateKey);
let drcx = DeepRejectCtxt::relate_infer_infer(tcx);
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
let may_overlap = match (impl1_ref, impl2_ref) {
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use rustc_infer::traits::ObligationCauseCode;
use rustc_middle::traits::select::OverflowError;
pub use rustc_middle::traits::Reveal;
use rustc_middle::traits::{BuiltinImplSource, ImplSource, ImplSourceUserDefinedData};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::visit::{MaxUniverse, TypeVisitable, TypeVisitableExt};
use rustc_middle::ty::{self, Term, Ty, TyCtxt, Upcast};
Expand Down Expand Up @@ -886,6 +887,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
potentially_unnormalized_candidates: bool,
) {
let infcx = selcx.infcx;
let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
for predicate in env_predicates {
let bound_predicate = predicate.kind();
if let ty::ClauseKind::Projection(data) = predicate.kind().skip_binder() {
Expand All @@ -894,6 +896,12 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
continue;
}

if !drcx
.args_may_unify(obligation.predicate.args, data.skip_binder().projection_term.args)
{
continue;
}

let is_match = infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use hir::LangItem;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_hir as hir;
use rustc_infer::traits::{Obligation, ObligationCause, PolyTraitObligation, SelectionError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::DeepRejectCtxt;
use rustc_middle::ty::{self, ToPolyTraitRef, Ty, TypeVisitableExt};
use rustc_middle::{bug, span_bug};
use tracing::{debug, instrument, trace};
Expand Down Expand Up @@ -248,11 +248,17 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
.filter(|p| p.def_id() == stack.obligation.predicate.def_id())
.filter(|p| p.polarity() == stack.obligation.predicate.polarity());

let drcx = DeepRejectCtxt::relate_rigid_rigid(self.tcx());
let obligation_args = stack.obligation.predicate.skip_binder().trait_ref.args;
// Keep only those bounds which may apply, and propagate overflow if it occurs.
for bound in bounds {
let bound_trait_ref = bound.map_bound(|t| t.trait_ref);
if !drcx.args_may_unify(obligation_args, bound_trait_ref.skip_binder().args) {
continue;
}
// FIXME(oli-obk): it is suspicious that we are dropping the constness and
// polarity here.
let wc = self.where_clause_may_apply(stack, bound.map_bound(|t| t.trait_ref))?;
let wc = self.where_clause_may_apply(stack, bound_trait_ref)?;
if wc.may_apply() {
candidates.vec.push(ParamCandidate(bound));
}
Expand Down Expand Up @@ -581,7 +587,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
return;
}

let drcx = DeepRejectCtxt::new(self.tcx(), TreatParams::ForLookup);
let drcx = DeepRejectCtxt::relate_rigid_infer(self.tcx());
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
self.tcx().for_each_relevant_impl(
obligation.predicate.def_id(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl<'tcx> Children {
fn insert_blindly(&mut self, tcx: TyCtxt<'tcx>, impl_def_id: DefId) {
let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder();
if let Some(st) =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer)
{
debug!("insert_blindly: impl_def_id={:?} st={:?}", impl_def_id, st);
self.non_blanket_impls.entry(st).or_default().push(impl_def_id)
Expand All @@ -58,7 +58,7 @@ impl<'tcx> Children {
let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder();
let vec: &mut Vec<DefId>;
if let Some(st) =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey)
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer)
{
debug!("remove_existing: impl_def_id={:?} st={:?}", impl_def_id, st);
vec = self.non_blanket_impls.get_mut(&st).unwrap();
Expand Down Expand Up @@ -279,7 +279,7 @@ impl<'tcx> Graph {
let mut parent = trait_def_id;
let mut last_lint = None;
let simplified =
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey);
fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer);

// Descend the specialization tree, where `parent` is the current parent node.
loop {
Expand Down
Loading

0 comments on commit 26b5599

Please sign in to comment.