diff --git a/miasm/expression/simplifications_common.py b/miasm/expression/simplifications_common.py index 9156ee671..d24f010ed 100644 --- a/miasm/expression/simplifications_common.py +++ b/miasm/expression/simplifications_common.py @@ -1753,19 +1753,39 @@ def simp_compose_and_mask(_, expr): if not arg2.is_int(): return expr int2 = int(arg2) - if (int2 + 1) & int2 != 0: - return expr - mask_size = int2.bit_length() + 7 // 8 + mask_size = (int2.bit_length() + 7) // 8 * 8 + if int2 == int(arg1.mask): + return arg1 out = [] + mask_needed = False for offset, arg in arg1.iter_args(): if offset == mask_size: - return ExprCompose(*out).zeroExtend(expr.size) - elif mask_size > offset and mask_size < offset+arg.size and arg.is_int(): - out.append(ExprSlice(arg, 0, mask_size-offset)) - return ExprCompose(*out).zeroExtend(expr.size) + break else: - out.append(arg) - return expr + if arg.is_int() and offset < mask_size < offset+arg.size: + arg = ExprSlice(arg, 0, mask_size-offset) + + arg_mask = (int(arg.mask) << offset) + if int2 & arg_mask != 0: + out.append(arg) + if int2 & arg_mask != arg_mask: + mask_needed = True + elif mask_size > offset + arg.size: + out.append(ExprInt(0, arg.size)) + + if mask_size <= offset + arg.size: + break + + if len(out) == 0: + return ExprInt(0, expr.size) + else: + size = sum(arg.size for arg in out) + if size != expr.size: + out.append(ExprInt(0, expr.size - size)) + result = ExprCompose(*out) + if mask_needed: + result = result & arg2 + return result def simp_bcdadd_cf(_, expr): """bcdadd(const, const) => decimal""" diff --git a/test/expression/simplifications.py b/test/expression/simplifications.py index 96ab8a59e..c469d1c2f 100644 --- a/test/expression/simplifications.py +++ b/test/expression/simplifications.py @@ -368,6 +368,19 @@ def check(expr_in, expr_out): ExprInt(0x1, 32), ExprInt(0x0, 32)) ), + (ExprCompose(a[:8],b[:8],c[:8],d[:8]) + & + ExprInt(0xA000B000, 32), + ExprCompose(ExprInt(0,8), b[:8], ExprInt(0,8), d[:8]) & + ExprInt(0xA000B000, 32) + ), + + (ExprCompose(a[:8],b[:8],c[:8],d[:8]) + & + ExprInt(0xFF00FF00, 32), + ExprCompose(ExprInt(0,8), b[:8], ExprInt(0,8), d[:8]) + ), + (ExprCompose(a[:16], b[:16])[8:32], ExprCompose(a[8:16], b[:16])), ((a >> ExprInt(16, 32))[:16],