Skip to content

Commit

Permalink
Merge pull request #871 from serpilliere/simp_mult
Browse files Browse the repository at this point in the history
Simple: add multiplication simplification
  • Loading branch information
commial authored Nov 12, 2018
2 parents ea9faf2 + 5e9ef3b commit b7e9a81
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
1 change: 1 addition & 0 deletions miasm2/expression/simplifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ExpressionSimplifier(object):
simplifications_common.simp_cst_propagation,
simplifications_common.simp_cond_op_int,
simplifications_common.simp_cond_factor,
simplifications_common.simp_add_multiple,
# CC op
simplifications_common.simp_cc_conds,
simplifications_common.simp_subwc_cf,
Expand Down
75 changes: 75 additions & 0 deletions miasm2/expression/simplifications_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,12 @@ def simp_cst_propagation(e_s, expr):
return -ExprOp(op_name, *new_args)
args = new_args

# -(a * b * int) => a * b * (-int)
if op_name == "-" and args[0].is_op('*') and args[0].args[-1].is_int():
args = args[0].args
return ExprOp('*', *(list(args[:-1]) + [ExprInt(-int(args[-1]), expr.size)]))


# A << int with A ExprCompose => move index
if (op_name == "<<" and args[0].is_compose() and
args[1].is_int() and int(args[1]) != 0):
Expand Down Expand Up @@ -1138,3 +1144,72 @@ def simp_slice_of_ext(expr_s, expr):
if arg.size != expr.size:
return expr
return arg

def simp_add_multiple(expr_s, expr):
# X + X => 2 * X
# X + X * int1 => X * (1 + int1)
# X * int1 + (- X) => X * (int1 - 1)
# X + (X << int1) => X * (1 + 2 ** int1)
# Correct even if addition overflow/underflow
if not expr.is_op('+'):
return expr

# Extract each argument and its counter
operands = {}
for i, arg in enumerate(expr.args):
if arg.is_op('*') and arg.args[1].is_int():
base_expr, factor = arg.args
operands[base_expr] = operands.get(base_expr, 0) + int(factor)
elif arg.is_op('<<') and arg.args[1].is_int():
base_expr, factor = arg.args
operands[base_expr] = operands.get(base_expr, 0) + 2 ** int(factor)
elif arg.is_op("-"):
arg = arg.args[0]
if arg.is_op('<<') and arg.args[1].is_int():
base_expr, factor = arg.args
operands[base_expr] = operands.get(base_expr, 0) - (2 ** int(factor))
else:
operands[arg] = operands.get(arg, 0) - 1
else:
operands[arg] = operands.get(arg, 0) + 1
out = []

# Best effort to factor common args:
# (a + b) * 3 + a + b => (a + b) * 4
# Does not factor:
# (a + b) * 3 + 2 * a + b => (a + b) * 4 + a
modified = True
while modified:
modified = False
for arg, count in operands.iteritems():
if not arg.is_op('+'):
continue
components = arg.args
if not all(component in operands for component in components):
continue
counters = set(operands[component] for component in components)
if len(counters) != 1:
continue
counter = counters.pop()
for component in components:
del operands[component]
operands[arg] += counter
modified = True
break

for arg, count in operands.iteritems():
if count == 0:
continue
if count == 1:
out.append(arg)
continue
out.append(arg * ExprInt(count, expr.size))

if len(out) == len(expr.args):
# No reductions
return expr
if not out:
return ExprInt(0, expr.size)
if len(out) == 1:
return out[0]
return ExprOp('+', *out)
20 changes: 19 additions & 1 deletion test/expression/simplifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def check(expr_in, expr_out):
i0 = ExprInt(0, 32)
i1 = ExprInt(1, 32)
i2 = ExprInt(2, 32)
i3 = ExprInt(3, 32)
im1 = ExprInt(-1, 32)
im2 = ExprInt(-2, 32)

icustom = ExprInt(0x12345678, 32)
cc = ExprCond(a, b, c)

Expand Down Expand Up @@ -242,7 +245,7 @@ def check(expr_in, expr_out):
(ExprOp('*', -a, -b, c, ExprInt(0x12, 32)),
ExprOp('*', a, b, c, ExprInt(0x12, 32))),
(ExprOp('*', -a, -b, -c, ExprInt(0x12, 32)),
- ExprOp('*', a, b, c, ExprInt(0x12, 32))),
ExprOp('*', a, b, c, ExprInt(-0x12, 32))),
(a | ExprInt(0xffffffff, 32),
ExprInt(0xffffffff, 32)),
(ExprCond(a, ExprInt(1, 32), ExprInt(2, 32)) * ExprInt(4, 32),
Expand Down Expand Up @@ -443,6 +446,21 @@ def check(expr_in, expr_out):
(ExprOp("signExt_16", ExprInt(0x8, 8)), ExprInt(0x8, 16)),
(ExprOp("signExt_16", ExprInt(-0x8, 8)), ExprInt(-0x8, 16)),

(- (i2*a), a * im2),
(a + a, a * i2),
(ExprOp('+', a, a), a * i2),
(ExprOp('+', a, a, a), a * i3),
((a<<i1) - a, a),
((a<<i1) - (a<<i2), a*im2),
((a<<i1) - a - a, i0),
((a<<i2) - (a<<i1) - (a<<i1), i0),
((a<<i2) - a*i3, a),
(((a+b) * i3) - (a + b), (a+b) * i2),
(((a+b) * i2) + a + b, (a+b) * i3),
(((a+b) * i3) - a - b, (a+b) * i2),
(((a+b) * i2) - a - b, a+b),
(((a+b) * i2) - i2 * a - i2 * b, i0),


]

Expand Down

0 comments on commit b7e9a81

Please sign in to comment.