Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply EarlyOtherwiseBranch to scalar value #129047

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 175 additions & 85 deletions compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {

let mut patch = MirPatch::new(body);

// create temp to store second discriminant in, `_s` in example above
let second_discriminant_temp =
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
// create temp to store second discriminant in, `_s` in example above
let second_discriminant_temp =
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);

patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
patch.add_statement(
parent_end,
StatementKind::StorageLive(second_discriminant_temp),
);

// create assignment of discriminant
patch.add_assign(
parent_end,
Place::from(second_discriminant_temp),
Rvalue::Discriminant(opt_data.child_place),
);
// create assignment of discriminant
patch.add_assign(
parent_end,
Place::from(second_discriminant_temp),
Rvalue::Discriminant(opt_data.child_place),
);
(
Some(second_discriminant_temp),
Operand::Move(Place::from(second_discriminant_temp)),
)
} else {
(None, Operand::Copy(opt_data.child_place))
};

// create temp to store inequality comparison between the two discriminants, `_t` in
// example above
Expand All @@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));

// create inequality comparison between the two discriminants
let comp_rvalue = Rvalue::BinaryOp(
nequal,
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
);
// create inequality comparison
let comp_rvalue =
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
patch.add_statement(
parent_end,
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
Expand Down Expand Up @@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
);

// generate StorageDead for the second_discriminant_temp not in use anymore
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
if let Some(second_discriminant_temp) = second_discriminant_temp {
// generate StorageDead for the second_discriminant_temp not in use anymore
patch.add_statement(
parent_end,
StatementKind::StorageDead(second_discriminant_temp),
);
}

// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
// the switch
Expand Down Expand Up @@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
child_place: Place<'tcx>,
child_ty: Ty<'tcx>,
child_source: SourceInfo,
need_hoist_discriminant: bool,
}

fn evaluate_candidate<'tcx>(
Expand All @@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
return None;
};
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
// of an invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
let (_, child) = targets.iter().next()?;
let child_terminator = &bbs[child].terminator();
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
&child_terminator.kind

let Terminator {
kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
source_info,
} = bbs[child].terminator()
else {
return None;
};
let child_ty = child_discr.ty(body.local_decls(), tcx);
if child_ty != parent_ty {
return None;
}
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {

// We only handle:
// ```
// bb4: {
// _8 = discriminant((_3.1: Enum1));
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
// }
// ```
// and
// ```
// bb2: {
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
// }
// ```
if bbs[child].statements.len() > 1 {
return None;
}

// When thie BB has exactly one statement, this statement should be discriminant.
let need_hoist_discriminant = bbs[child].statements.len() == 1;
let child_place = if need_hoist_discriminant {
if !bbs[targets.otherwise()].is_empty_unreachable() {
// Someone could write code like this:
// ```rust
// let Q = val;
// if discriminant(P) == otherwise {
// let ptr = &mut Q as *mut _ as *mut u8;
// // It may be difficult for us to effectively determine whether values are valid.
// // Invalid values can come from all sorts of corners.
// unsafe { *ptr = 10; }
// }
//
// match P {
// A => match Q {
// A => {
// // code
// }
// _ => {
// // don't use Q
// }
// }
// _ => {
// // don't use Q
// }
// };
// ```
//
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
// invalid value, which is UB.
// In order to fix this, **we would either need to show that the discriminant computation of
// `place` is computed in all branches**.
// FIXME(#95162) For the moment, we adopt a conservative approach and
// consider only the `otherwise` branch has no statements and an unreachable terminator.
return None;
}
// Handle:
// ```
// bb4: {
// _8 = discriminant((_3.1: Enum1));
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
// }
// ```
let [
Statement {
kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
..
},
] = bbs[child].statements.as_slice()
else {
return None;
};
*child_place
} else {
// Handle:
// ```
// bb2: {
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
// }
// ```
let Operand::Copy(child_place) = child_discr else {
return None;
};
*child_place
};
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
return None;
let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
{
child_targets.otherwise()
} else {
targets.otherwise()
};
let destination = child_targets.otherwise();

// Verify that the optimization is legal for each branch
for (value, child) in targets.iter() {
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
if !verify_candidate_branch(
&bbs[child],
value,
child_place,
destination,
need_hoist_discriminant,
) {
return None;
}
}
Some(OptimizationData {
destination,
child_place: *child_place,
child_place,
child_ty,
child_source: child_terminator.source_info,
child_source: *source_info,
need_hoist_discriminant,
})
}

Expand All @@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
value: u128,
place: Place<'tcx>,
destination: BasicBlock,
need_hoist_discriminant: bool,
) -> bool {
// In order for the optimization to be correct, the branch must...
// ...have exactly one statement
if let [statement] = branch.statements.as_slice()
// ...assign the discriminant of `place` in that statement
&& let StatementKind::Assign(boxed) = &statement.kind
&& let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed
&& *from_place == place
// ...make that assignment to a local
&& discr_place.projection.is_empty()
// ...terminate on a `SwitchInt` that invalidates that local
&& let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } =
&branch.terminator().kind
&& *switch_op == Operand::Move(*discr_place)
// ...fall through to `destination` if the switch misses
&& destination == targets.otherwise()
// ...have a branch for value `value`
&& let mut iter = targets.iter()
&& let Some((target_value, _)) = iter.next()
&& target_value == value
// ...and have no more branches
&& iter.next().is_none()
{
true
// In order for the optimization to be correct, the terminator must be a `SwitchInt`.
let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
return false;
};
if need_hoist_discriminant {
// If we need hoist discriminant, the branch must have exactly one statement.
let [statement] = branch.statements.as_slice() else {
return false;
};
// The statement must assign the discriminant of `place`.
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
statement.kind
else {
return false;
};
if from_place != place {
return false;
}
// The assignment must invalidate a local that terminate on a `SwitchInt`.
if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
return false;
}
} else {
false
// If we don't need hoist discriminant, the branch must not have any statements.
if !branch.statements.is_empty() {
return false;
}
// The place on `SwitchInt` must be the same.
if *switch_op != Operand::Copy(place) {
return false;
}
}
// It must fall through to `destination` if the switch misses.
if destination != targets.otherwise() {
return false;
}
// It must have exactly one branch for value `value` and have no more branches.
let mut iter = targets.iter();
let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
return false;
};
target_value == value
}
Loading
Loading