diff --git a/sway-ir/src/optimize/memcpyopt.rs b/sway-ir/src/optimize/memcpyopt.rs index 265db81695b..dba48da622e 100644 --- a/sway-ir/src/optimize/memcpyopt.rs +++ b/sway-ir/src/optimize/memcpyopt.rs @@ -31,7 +31,7 @@ pub fn mem_copy_opt( let mut modified = false; modified |= local_copy_prop_prememcpy(context, analyses, function)?; modified |= load_store_to_memcopy(context, function)?; - modified |= local_copy_prop(context, analyses, function)?; + //modified |= local_copy_prop(context, analyses, function)?; Ok(modified) } @@ -258,6 +258,75 @@ fn local_copy_prop_prememcpy( Ok(true) } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +enum InstCopy { + Load(Value), + MemCopyBytes { + dst_val_ptr: Value, + src_val_ptr: Value, + byte_len: u64, + }, + MemCopyVal { + dst_val_ptr: Value, + src_val_ptr: Value, + }, +} + +impl InstCopy { + fn from_value(context: &Context, value: &Value) -> Self { + match value.get_instruction(context) { + Some(Instruction { + op: InstOp::Load(v), + .. + }) => InstCopy::Load(*v), + Some(Instruction { + op: + InstOp::MemCopyVal { + dst_val_ptr, + src_val_ptr, + }, + .. + }) => InstCopy::MemCopyVal { + dst_val_ptr: *dst_val_ptr, + src_val_ptr: *src_val_ptr, + }, + Some(Instruction { + op: + InstOp::MemCopyBytes { + dst_val_ptr, + src_val_ptr, + byte_len, + }, + .. + }) => InstCopy::MemCopyBytes { + dst_val_ptr: *dst_val_ptr, + src_val_ptr: *src_val_ptr, + byte_len: *byte_len, + }, + _ => todo!(), + } + } + + fn deconstruct(&self, context: &Context) -> (Value, Value, u64) { + match self { + InstCopy::MemCopyBytes { + dst_val_ptr, + src_val_ptr, + byte_len, + } => (*dst_val_ptr, *src_val_ptr, *byte_len), + InstCopy::MemCopyVal { + dst_val_ptr, + src_val_ptr, + } => ( + *dst_val_ptr, + *src_val_ptr, + memory_utils::pointee_size(context, *dst_val_ptr), + ), + _ => unreachable!("Only memcpy instructions handled"), + } + } +} + /// Copy propagation of `memcpy`s within a block. fn local_copy_prop( context: &mut Context, @@ -274,67 +343,39 @@ fn local_copy_prop( }; // Currently (as we scan a block) available `memcpy`s. - let mut available_copies: FxHashSet; + let mut available_copies: FxHashSet; // Map a symbol to the available `memcpy`s of which it's a source. - let mut src_to_copies: FxIndexMap>; + let mut src_to_copies: FxIndexMap>; // Map a symbol to the available `memcpy`s of which it's a destination. // (multiple `memcpy`s for the same destination may be available when // they are partial / field writes, and don't alias). - let mut dest_to_copies: FxIndexMap>; + let mut dest_to_copies: FxIndexMap>; // If a value (symbol) is found to be defined, remove it from our tracking. fn kill_defined_symbol( context: &Context, value: Value, len: u64, - available_copies: &mut FxHashSet, - src_to_copies: &mut FxIndexMap>, - dest_to_copies: &mut FxIndexMap>, + available_copies: &mut FxHashSet, + src_to_copies: &mut FxIndexMap>, + dest_to_copies: &mut FxIndexMap>, ) { match get_referred_symbols(context, value) { ReferredSymbols::Complete(rs) => { for sym in rs { if let Some(copies) = src_to_copies.get_mut(&sym) { for copy in &*copies { - let (_, src_ptr, copy_size) = deconstruct_memcpy(context, *copy); - if memory_utils::may_alias(context, value, len, src_ptr, copy_size) { - available_copies.remove(copy); - } + available_copies.remove(copy); } - copies.retain(|copy| available_copies.contains(copy)); } if let Some(copies) = dest_to_copies.get_mut(&sym) { - for copy in &*copies { - let (dest_ptr, copy_size) = match copy.get_instruction(context).unwrap() - { - Instruction { - op: - InstOp::MemCopyBytes { - dst_val_ptr, - src_val_ptr: _, - byte_len, - }, - .. - } => (*dst_val_ptr, *byte_len), - Instruction { - op: - InstOp::MemCopyVal { - dst_val_ptr, - src_val_ptr: _, - }, - .. - } => ( - *dst_val_ptr, - memory_utils::pointee_size(context, *dst_val_ptr), - ), - _ => panic!("Unexpected copy instruction"), - }; - if memory_utils::may_alias(context, value, len, dest_ptr, copy_size) { - available_copies.remove(copy); - } + for op in &*copies { + available_copies.remove(op); } - copies.retain(|copy| available_copies.contains(copy)); } + + src_to_copies.swap_remove(&sym); + dest_to_copies.swap_remove(&sym); } } ReferredSymbols::Incomplete(_) => { @@ -353,9 +394,9 @@ fn local_copy_prop( copy_inst: Value, dst_val_ptr: Value, src_val_ptr: Value, - available_copies: &mut FxHashSet, - src_to_copies: &mut FxIndexMap>, - dest_to_copies: &mut FxIndexMap>, + available_copies: &mut FxHashSet, + src_to_copies: &mut FxIndexMap>, + dest_to_copies: &mut FxIndexMap>, ) { if let (Some(dst_sym), Some(src_sym)) = ( get_gep_symbol(context, dst_val_ptr), @@ -364,19 +405,17 @@ fn local_copy_prop( if escaped_symbols.contains(&dst_sym) || escaped_symbols.contains(&src_sym) { return; } + + let inst_cpy = InstCopy::from_value(context, ©_inst); dest_to_copies .entry(dst_sym) - .and_modify(|set| { - set.insert(copy_inst); - }) - .or_insert([copy_inst].into_iter().collect()); + .or_default() + .insert(inst_cpy.clone()); src_to_copies .entry(src_sym) - .and_modify(|set| { - set.insert(copy_inst); - }) - .or_insert([copy_inst].into_iter().collect()); - available_copies.insert(copy_inst); + .or_default() + .insert(inst_cpy.clone()); + available_copies.insert(inst_cpy); } } @@ -423,7 +462,7 @@ fn local_copy_prop( escaped_symbols: &FxHashSet, inst: Value, src_val_ptr: Value, - dest_to_copies: &FxIndexMap>, + dest_to_copies: &FxIndexMap>, replacements: &mut FxHashMap, ) -> bool { // For every `memcpy` that src_val_ptr is a destination of, @@ -437,8 +476,8 @@ fn local_copy_prop( .iter() .flat_map(|set| set.iter()) { - let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) = - deconstruct_memcpy(context, *memcpy); + let (dst_ptr_memcpy, src_ptr_memcpy, copy_len) = memcpy.deconstruct(context); + // If the location where we're loading from exactly matches the destination of // the memcpy, just load from the source pointer of the memcpy. // TODO: In both the arms below, we check that the pointer type @@ -514,9 +553,9 @@ fn local_copy_prop( fn kill_escape_args( context: &Context, args: &Vec, - available_copies: &mut FxHashSet, - src_to_copies: &mut FxIndexMap>, - dest_to_copies: &mut FxIndexMap>, + available_copies: &mut FxHashSet, + src_to_copies: &mut FxIndexMap>, + dest_to_copies: &mut FxIndexMap>, ) { for arg in args { match get_referred_symbols(context, *arg) { @@ -589,14 +628,19 @@ fn local_copy_prop( op: InstOp::Load(src_val_ptr), .. } => { - process_load( + if process_load( context, escaped_symbols, inst, *src_val_ptr, &dest_to_copies, &mut replacements, - ); + ) { + let cpy = InstCopy::from_value(context, &inst); + src_to_copies.retain(|_, v| !v.contains(&cpy)); + dest_to_copies.retain(|_, v| !v.contains(&cpy)); + available_copies.remove(&cpy); + } } Instruction { op: InstOp::MemCopyBytes { .. } | InstOp::MemCopyVal { .. }, @@ -631,6 +675,11 @@ fn local_copy_prop( &mut src_to_copies, &mut dest_to_copies, ); + } else { + let cpy = InstCopy::from_value(context, &inst); + src_to_copies.retain(|_, v| !v.contains(&cpy)); + dest_to_copies.retain(|_, v| !v.contains(&cpy)); + available_copies.remove(&cpy); } } Instruction { @@ -686,8 +735,10 @@ fn local_copy_prop( // going to be used in, we copy all the instructions into a new vec // and just replace the contents of the basic block. let mut new_insts = vec![]; - for inst in block.instruction_iter(context) { - if let Some(replacement) = replacements.remove(&inst) { + let mut replace_map: FxHashMap = FxHashMap::default(); + + for old_value in block.instruction_iter(context) { + let new_value = if let Some(replacement) = replacements.remove(&old_value) { let replacement = match replacement { Replacement::OldGep(v) => v, Replacement::NewGep(ReplGep { @@ -722,31 +773,52 @@ fn local_copy_prop( v } }; - match inst.get_instruction_mut(context) { + match old_value.get_instruction(context) { Some(Instruction { - op: InstOp::Load(ref mut src_val_ptr), + op: InstOp::Load(_), .. - }) - | Some(Instruction { + }) => Value::new_instruction(context, block, InstOp::Load(replacement)), + Some(Instruction { op: InstOp::MemCopyBytes { - ref mut src_val_ptr, + dst_val_ptr, + byte_len, .. }, .. - }) - | Some(Instruction { - op: - InstOp::MemCopyVal { - ref mut src_val_ptr, - .. - }, + }) => Value::new_instruction( + context, + block, + InstOp::MemCopyBytes { + dst_val_ptr: *dst_val_ptr, + src_val_ptr: replacement, + byte_len: *byte_len, + }, + ), + Some(Instruction { + op: InstOp::MemCopyVal { dst_val_ptr, .. }, .. - }) => *src_val_ptr = replacement, + }) => Value::new_instruction( + context, + block, + InstOp::MemCopyVal { + dst_val_ptr: *dst_val_ptr, + src_val_ptr: replacement, + }, + ), _ => panic!("Unexpected instruction type"), } + } else { + old_value + }; + + // Replace old instructions by its corresponding new ones + if new_value != old_value { + replace_map.insert(old_value, new_value); } - new_insts.push(inst); + new_value.replace_instruction_values(context, &replace_map); + + new_insts.push(new_value); } // Replace the basic block contents with what we just built.