Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trying o fix local_copy_prop optimization #6800

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 149 additions & 77 deletions sway-ir/src/optimize/memcpyopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn mem_copy_opt(
let mut modified = false;
modified |= local_copy_prop_prememcpy(context, analyses, function)?;
modified |= load_store_to_memcopy(context, function)?;
modified |= local_copy_prop(context, analyses, function)?;
//modified |= local_copy_prop(context, analyses, function)?;

Ok(modified)
}
Expand Down Expand Up @@ -258,6 +258,75 @@ fn local_copy_prop_prememcpy(
Ok(true)
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum InstCopy {
Load(Value),
MemCopyBytes {
dst_val_ptr: Value,
src_val_ptr: Value,
byte_len: u64,
},
MemCopyVal {
dst_val_ptr: Value,
src_val_ptr: Value,
},
}

impl InstCopy {
fn from_value(context: &Context, value: &Value) -> Self {
match value.get_instruction(context) {
Some(Instruction {
op: InstOp::Load(v),
..
}) => InstCopy::Load(*v),
Some(Instruction {
op:
InstOp::MemCopyVal {
dst_val_ptr,
src_val_ptr,
},
..
}) => InstCopy::MemCopyVal {
dst_val_ptr: *dst_val_ptr,
src_val_ptr: *src_val_ptr,
},
Some(Instruction {
op:
InstOp::MemCopyBytes {
dst_val_ptr,
src_val_ptr,
byte_len,
},
..
}) => InstCopy::MemCopyBytes {
dst_val_ptr: *dst_val_ptr,
src_val_ptr: *src_val_ptr,
byte_len: *byte_len,
},
_ => todo!(),
}
}

fn deconstruct(&self, context: &Context) -> (Value, Value, u64) {
match self {
InstCopy::MemCopyBytes {
dst_val_ptr,
src_val_ptr,
byte_len,
} => (*dst_val_ptr, *src_val_ptr, *byte_len),
InstCopy::MemCopyVal {
dst_val_ptr,
src_val_ptr,
} => (
*dst_val_ptr,
*src_val_ptr,
memory_utils::pointee_size(context, *dst_val_ptr),
),
_ => unreachable!("Only memcpy instructions handled"),
}
}
}

/// Copy propagation of `memcpy`s within a block.
fn local_copy_prop(
context: &mut Context,
Expand All @@ -274,67 +343,39 @@ fn local_copy_prop(
};

// Currently (as we scan a block) available `memcpy`s.
let mut available_copies: FxHashSet<Value>;
let mut available_copies: FxHashSet<InstCopy>;
// Map a symbol to the available `memcpy`s of which it's a source.
let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
let mut src_to_copies: FxIndexMap<Symbol, FxIndexSet<InstCopy>>;
// Map a symbol to the available `memcpy`s of which it's a destination.
// (multiple `memcpy`s for the same destination may be available when
// they are partial / field writes, and don't alias).
let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<Value>>;
let mut dest_to_copies: FxIndexMap<Symbol, FxIndexSet<InstCopy>>;

// If a value (symbol) is found to be defined, remove it from our tracking.
fn kill_defined_symbol(
context: &Context,
value: Value,
len: u64,
available_copies: &mut FxHashSet<Value>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
available_copies: &mut FxHashSet<InstCopy>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
) {
match get_referred_symbols(context, value) {
ReferredSymbols::Complete(rs) => {
for sym in rs {
if let Some(copies) = src_to_copies.get_mut(&sym) {
for copy in &*copies {
let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy);
if memory_utils::may_alias(context, value, len, src_ptr, copy_size) {
available_copies.remove(copy);
}
available_copies.remove(copy);
}
copies.retain(|copy| available_copies.contains(copy));
}
if let Some(copies) = dest_to_copies.get_mut(&sym) {
for copy in &*copies {
let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap()
{
Instruction {
op:
InstOp::MemCopyBytes {
dst_val_ptr,
src_val_ptr: _,
byte_len,
},
..
} => (*dst_val_ptr, *byte_len),
Instruction {
op:
InstOp::MemCopyVal {
dst_val_ptr,
src_val_ptr: _,
},
..
} => (
*dst_val_ptr,
memory_utils::pointee_size(context, *dst_val_ptr),
),
_ => panic!("Unexpected copy instruction"),
};
if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) {
available_copies.remove(copy);
}
for op in &*copies {
available_copies.remove(op);
}
copies.retain(|copy| available_copies.contains(copy));
}

src_to_copies.swap_remove(&sym);
dest_to_copies.swap_remove(&sym);
}
}
ReferredSymbols::Incomplete(_) => {
Expand All @@ -353,9 +394,9 @@ fn local_copy_prop(
copy_inst: Value,
dst_val_ptr: Value,
src_val_ptr: Value,
available_copies: &mut FxHashSet<Value>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
available_copies: &mut FxHashSet<InstCopy>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
) {
if let (Some(dst_sym), Some(src_sym)) = (
get_gep_symbol(context, dst_val_ptr),
Expand All @@ -364,19 +405,17 @@ fn local_copy_prop(
if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) {
return;
}

let inst_cpy = InstCopy::from_value(context, &copy_inst);
dest_to_copies
.entry(dst_sym)
.and_modify(|set| {
set.insert(copy_inst);
})
.or_insert([copy_inst].into_iter().collect());
.or_default()
.insert(inst_cpy.clone());
src_to_copies
.entry(src_sym)
.and_modify(|set| {
set.insert(copy_inst);
})
.or_insert([copy_inst].into_iter().collect());
available_copies.insert(copy_inst);
.or_default()
.insert(inst_cpy.clone());
available_copies.insert(inst_cpy);
}
}

Expand Down Expand Up @@ -423,7 +462,7 @@ fn local_copy_prop(
escaped_symbols: &FxHashSet<Symbol>,
inst: Value,
src_val_ptr: Value,
dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<Value>>,
dest_to_copies: &FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
replacements: &mut FxHashMap<Value, Replacement>,
) -> bool {
// For every `memcpy` that src_val_ptr is a destination of,
Expand All @@ -437,8 +476,8 @@ fn local_copy_prop(
.iter()
.flat_map(|set| set.iter())
{
let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) =
deconstruct_memcpy(context, *memcpy);
let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) = memcpy.deconstruct(context);

// If the location where we're loading from exactly matches the destination of
// the memcpy, just load from the source pointer of the memcpy.
// TODO: In both the arms below, we check that the pointer type
Expand Down Expand Up @@ -514,9 +553,9 @@ fn local_copy_prop(
fn kill_escape_args(
context: &Context,
args: &Vec<Value>,
available_copies: &mut FxHashSet<Value>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<Value>>,
available_copies: &mut FxHashSet<InstCopy>,
src_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
dest_to_copies: &mut FxIndexMap<Symbol, FxIndexSet<InstCopy>>,
) {
for arg in args {
match get_referred_symbols(context, *arg) {
Expand Down Expand Up @@ -589,14 +628,19 @@ fn local_copy_prop(
op: InstOp::Load(src_val_ptr),
..
} => {
process_load(
if process_load(
context,
escaped_symbols,
inst,
*src_val_ptr,
&dest_to_copies,
&mut replacements,
);
) {
let cpy = InstCopy::from_value(context, &inst);
src_to_copies.retain(|_, v| !v.contains(&cpy));
dest_to_copies.retain(|_, v| !v.contains(&cpy));
available_copies.remove(&cpy);
}
}
Instruction {
op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. },
Expand Down Expand Up @@ -631,6 +675,11 @@ fn local_copy_prop(
&mut src_to_copies,
&mut dest_to_copies,
);
} else {
let cpy = InstCopy::from_value(context, &inst);
src_to_copies.retain(|_, v| !v.contains(&cpy));
dest_to_copies.retain(|_, v| !v.contains(&cpy));
available_copies.remove(&cpy);
}
}
Instruction {
Expand Down Expand Up @@ -686,8 +735,10 @@ fn local_copy_prop(
// going to be used in, we copy all the instructions into a new vec
// and just replace the contents of the basic block.
let mut new_insts = vec![];
for inst in block.instruction_iter(context) {
if let Some(replacement) = replacements.remove(&inst) {
let mut replace_map: FxHashMap<Value, Value> = FxHashMap::default();

for old_value in block.instruction_iter(context) {
let new_value = if let Some(replacement) = replacements.remove(&old_value) {
let replacement = match replacement {
Replacement::OldGep(v) => v,
Replacement::NewGep(ReplGep {
Expand Down Expand Up @@ -722,31 +773,52 @@ fn local_copy_prop(
v
}
};
match inst.get_instruction_mut(context) {
match old_value.get_instruction(context) {
Some(Instruction {
op: InstOp::Load(ref mut src_val_ptr),
op: InstOp::Load(_),
..
})
| Some(Instruction {
}) => Value::new_instruction(context, block, InstOp::Load(replacement)),
Some(Instruction {
op:
InstOp::MemCopyBytes {
ref mut src_val_ptr,
dst_val_ptr,
byte_len,
..
},
..
})
| Some(Instruction {
op:
InstOp::MemCopyVal {
ref mut src_val_ptr,
..
},
}) => Value::new_instruction(
context,
block,
InstOp::MemCopyBytes {
dst_val_ptr: *dst_val_ptr,
src_val_ptr: replacement,
byte_len: *byte_len,
},
),
Some(Instruction {
op: InstOp::MemCopyVal { dst_val_ptr, .. },
..
}) => *src_val_ptr = replacement,
}) => Value::new_instruction(
context,
block,
InstOp::MemCopyVal {
dst_val_ptr: *dst_val_ptr,
src_val_ptr: replacement,
},
),
_ => panic!("Unexpected instruction type"),
}
} else {
old_value
};

// Replace old instructions by its corresponding new ones
if new_value != old_value {
replace_map.insert(old_value, new_value);
}
new_insts.push(inst);
new_value.replace_instruction_values(context, &replace_map);

new_insts.push(new_value);
}

// Replace the basic block contents with what we just built.
Expand Down
Loading