Skip to content

Commit

Permalink
get in aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
Medowhill committed May 30, 2024
1 parent 372f4cf commit d43b286
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 77 deletions.
3 changes: 1 addition & 2 deletions src/compile_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,11 @@ pub fn span_to_path(span: Span, source_map: &SourceMap) -> Option<PathBuf> {

pub type Suggestions = HashMap<PathBuf, Vec<Suggestion>>;

pub fn apply_suggestions(suggestions: &mut Suggestions) {
pub fn apply_suggestions(suggestions: &Suggestions) {
for (path, suggestions) in suggestions {
if suggestions.is_empty() {
continue;
}
suggestions.sort_by_key(|s| s.snippets[0].range.start);
let code = String::from_utf8(fs::read(path).unwrap()).unwrap();
let fixed = rustfix::apply_suggestions(&code, suggestions).unwrap();
fs::write(path, fixed.as_bytes()).unwrap();
Expand Down
11 changes: 7 additions & 4 deletions src/must_analysis/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,18 @@ impl Analyzer<'_, '_, '_> {

while let Some(location) = work_list.pop() {
let state = states.get(&location).unwrap_or(&bot);
let nexts = self.body.stmt_at(location).either(
let (nexts, is_call) = self.body.stmt_at(location).either(
|stmt| {
let mut next_state = state.clone();
self.transfer_stmt(stmt, location, &mut next_state);
vec![(location.successor_within_block(), next_state)]
(vec![(location.successor_within_block(), next_state)], false)
},
|terminator| {
let v = self.discriminant_values.get(&location.block);
self.transfer_term(terminator, v, location, state)
(
self.transfer_term(terminator, v, location, state),
matches!(terminator.kind, TerminatorKind::Call { .. }),
)
},
);
// println!("{:?}", state);
Expand All @@ -160,7 +163,7 @@ impl Analyzer<'_, '_, '_> {
// println!("{:?}", nexts);
// println!("-----------------");
for (next_location, new_next_state) in nexts {
if self.join_terminators.contains(&location) {
if self.join_terminators.contains(&location) || is_call {
let out_state = out_states.get(&location).unwrap_or(&bot);
let joined = out_state.join(&new_next_state);
out_states.insert(location, joined);
Expand Down
160 changes: 89 additions & 71 deletions src/tag_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,18 @@ pub fn analyze(tcx: TyCtxt<'_>, conf: &Config) {
visitor.visit_body(body);
let mut hvisitor = HBodyVisitor::new(tcx);
hvisitor.visit_body(hbody);
if visitor.accesses.is_empty()
&& visitor.struct_accesses.is_empty()
&& visitor.aggregates.is_empty()
{
continue;
}
basic_blocks.insert(local_def_id, visitor.basic_blocks);
let locals = locals.entry(local_def_id).or_default();
for (local, local_def) in body.local_decls.iter_enumerated() {
let hir_id = some_or!(hvisitor.bindings.get(&local_def.source_info.span), continue);
locals.entry(*hir_id).or_insert(local);
}
if visitor.accesses.is_empty()
&& visitor.struct_accesses.is_empty()
&& visitor.aggregates.is_empty()
{
continue;
}
if !hvisitor.inits.is_empty() {
for (local, location) in &visitor.inits {
let span = body
Expand Down Expand Up @@ -810,14 +810,14 @@ impl {} {{

let mut match_targets = 0;
let mut if_targets = 0;
let mut aggregates_num = 0;
for item_id in hir.items() {
let item = hir.item(item_id);
let (ItemKind::Fn(_, _, body_id) | ItemKind::Static(_, _, body_id)) = item.kind else {
continue;
};
let hir_body = hir.body(body_id);
let local_def_id = item_id.owner_id.def_id;
let basic_blocks = some_or!(basic_blocks.get(&local_def_id), continue);
let typeck = tcx.typeck(local_def_id);
let mut visitor = SuggestingVisitor {
tcx,
Expand All @@ -829,34 +829,38 @@ impl {} {{
unions: &tagged_unions,
access_in_matches: &access_in_matches,
access_in_ifs: &access_in_ifs,
basic_blocks,
basic_blocks: &basic_blocks[&local_def_id],
hir_id_to_locals: &locals[&local_def_id],
suggestions: &mut suggestions,

locals: HashMap::new(),
match_targets: HashMap::new(),
if_targets: HashMap::new(),
aggregates_num: 0,
aggregate_spans: vec![],
};
visitor.visit_body(hir_body);

match_targets += visitor.match_targets.len();
if_targets += visitor.if_targets.len();
aggregates_num += visitor.aggregates_num;
}

println!("match_targets: {}", match_targets);
println!("if_targets: {}", if_targets);
println!("aggregates_num: {}", aggregates_num);

let mut suggestions = suggestions.suggestions;
for (path, suggestions) in &suggestions {
for (path, suggestions) in &mut suggestions {
tracing::info!("{:?}", path);
suggestions.sort_by_key(|s| s.snippets[0].range.start);
for suggestion in suggestions {
tracing::info!("{:?}", suggestion);
}
}

if conf.transform {
compile_util::apply_suggestions(&mut suggestions);
compile_util::apply_suggestions(&suggestions);
}
}

Expand Down Expand Up @@ -1237,16 +1241,20 @@ impl<'tcx> MVisitor<'tcx> for MBodyVisitor<'tcx, '_> {
}

fn visit_basic_block_data(&mut self, block: BasicBlock, data: &BasicBlockData<'tcx>) {
if let Some(stmt) = data.statements.get(0) {
let span = stmt.source_info.span;
let mut lo = span.lo();
let mut hi = span.hi();
for stmt in &data.statements[1..] {
let span = stmt.source_info.span;
lo = lo.min(span.lo());
hi = hi.max(span.hi());
}
let span = span.with_lo(lo).with_hi(hi);
let spans = data.statements.iter().map(|stmt| stmt.source_info.span);
let term = data.terminator();
let spans: Box<dyn Iterator<Item = Span>> =
if matches!(term.kind, TerminatorKind::Call { .. }) {
Box::new(spans.chain(std::iter::once(term.source_info.span)))
} else {
Box::new(spans)
};
if let Some(span) = spans.reduce(|span1, span2| {
let lo = span1.lo().min(span2.lo());
let hi = span1.hi().max(span2.hi());
span1.with_lo(lo).with_hi(hi)
}) {
let span = span.with_hi(span.hi() + BytePos(1));
let location = Location {
block,
statement_index: data.statements.len(),
Expand Down Expand Up @@ -1339,6 +1347,7 @@ struct SuggestingVisitor<'a, 'tcx> {
locals: HashMap<HirId, &'tcx Expr<'tcx>>,
match_targets: HashMap<Span, String>,
if_targets: HashMap<Span, String>,
aggregates_num: usize,
aggregate_spans: Vec<Span>,
}

Expand Down Expand Up @@ -1581,17 +1590,17 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
self.suggestions.add(expr.span, "(*__v)".to_string());
} else {
let (ctx, _) = get_expr_context(expr, self.tcx);
match ctx {
ExprContext::Value => {
let call = format!("get_{}()", field.name);
self.suggestions.add(field.span, call);
}
ExprContext::Store(_) | ExprContext::Address => {
if !self
.aggregate_spans
.iter()
.any(|span| span.contains(expr.span))
{
if !self
.aggregate_spans
.iter()
.any(|span| span.contains(expr.span))
{
match ctx {
ExprContext::Value => {
let call = format!("get_{}()", field.name);
self.suggestions.add(field.span, call);
}
ExprContext::Store(_) | ExprContext::Address => {
let span = expr.span.shrink_to_lo();
self.suggestions.add(span, "(*".to_string());

Expand Down Expand Up @@ -1796,14 +1805,15 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
let (variant, value) = vs.into_iter().next().unwrap();
if let ExprKind::Path(QPath::Resolved(_, path)) = root.kind {
let Res::Local(hir_id) = path.res else { continue };
let local = self.hir_id_to_locals[&hir_id];
let local = *some_or!(self.hir_id_to_locals.get(&hir_id), continue);
let field_at = FieldAt {
func: self.func,
location,
local,
field: ts.tag_index,
};
let tags = some_or!(self.field_values.get(&field_at), continue);
let tags = self.field_values.get(&field_at);
let tags = some_or!(tags, continue);
if tags.len() != 1 {
continue;
}
Expand All @@ -1814,6 +1824,7 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
.filter(|bs| &bs.fields[0] == union_field_name)
.collect();
let bs = bss.pop().unwrap();
let value = self.assigned_value_to_string(&value);
let code = format!(
"{}.{} = {}::{}{}({});",
self.tcx
Expand Down Expand Up @@ -1848,6 +1859,7 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
if removed {
self.suggestions.add(tag_assign.span, "".to_string());
self.aggregate_spans.push(tag_assign.span);
self.aggregates_num += 1;
}
}
}
Expand Down Expand Up @@ -1877,26 +1889,60 @@ impl<'tcx> SuggestingVisitor<'_, 'tcx> {
None
}
}

fn assigned_value_to_string(&self, assigned_value: &AssignedValue<'tcx>) -> String {
let source_map = self.tcx.sess.source_map();
match assigned_value {
AssignedValue::Compound(name, fields) => {
let mut s = name.clone();
s.push_str(" { ");
for (i, (field, value)) in fields.iter().enumerate() {
if i > 0 {
s.push_str(", ");
}
s.push_str(field);
let value = self.assigned_value_to_string(value);
if field != &value {
write!(&mut s, ": {}", value).unwrap();
}
}
s.push_str(" }");
s
}
AssignedValue::Primitive(value) => {
if let ExprKind::Field(e, f) = value.kind {
let ty = self.typeck.expr_ty(e);
if let TyKind::Adt(adt_def, _) = ty.kind() {
if let Some(did) = adt_def.did().as_local() {
if self.unions.contains_key(&did) {
let e = source_map.span_to_snippet(e.span).unwrap();
return format!("{}.get_{}()", e, f.name);
}
}
}
}
source_map.span_to_snippet(value.span).unwrap()
}
}
}
}

fn make_aggregate<'tcx>(
ty: Option<LocalDefId>,
fields: &mut Vec<String>,
assigns: &HashMap<Vec<String>, &AssignBlockStmt<'tcx>>,
tcx: TyCtxt<'tcx>,
) -> Option<AssignedValue> {
) -> Option<AssignedValue<'tcx>> {
if let Some(bs) = assigns.get(fields) {
tcx.sess
.source_map()
.span_to_snippet(bs.rhs.span)
.ok()
.map(AssignedValue::Primitive)
Some(AssignedValue::Primitive(bs.rhs))
} else {
let ty = ty?;
let def_path = tcx.def_path(ty.to_def_id());
let mut name = "crate".to_string();
for data in def_path.data {
write!(name, "::{}", data).unwrap();
let data = format!("{}", data);
let escape = if data == "async" { "r#" } else { "" };
write!(name, "::{}{}", escape, data).unwrap();
}
let adt_def = tcx.adt_def(ty);
let (ItemKind::Struct(vd, _) | ItemKind::Union(vd, _)) = tcx.hir().expect_item(ty).kind
Expand Down Expand Up @@ -1935,37 +1981,9 @@ fn make_aggregate<'tcx>(
}
}

enum AssignedValue {
Compound(String, HashMap<String, AssignedValue>),
Primitive(String),
}

impl std::fmt::Debug for AssignedValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
AssignedValue::Compound(name, fields) => {
write!(f, "{} {{ ", name)?;
for (i, (field, value)) in fields.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", field)?;
let value = value.to_string();
if field != &value {
write!(f, ": {}", value)?;
}
}
write!(f, " }}")
}
AssignedValue::Primitive(value) => write!(f, "{}", value),
}
}
}

impl std::fmt::Display for AssignedValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
enum AssignedValue<'tcx> {
Compound(String, HashMap<String, AssignedValue<'tcx>>),
Primitive(&'tcx Expr<'tcx>),
}

#[derive(Debug, Default)]
Expand Down

0 comments on commit d43b286

Please sign in to comment.