diff --git a/manticore/core/state.py b/manticore/core/state.py index 75bac0162..157e4a550 100644 --- a/manticore/core/state.py +++ b/manticore/core/state.py @@ -1,14 +1,14 @@ import copy import logging +import typing +from typing import List, Sequence -from typing import List, Tuple, Sequence - -from .smtlib import solver, Bool, issymbolic, BitVecConstant +from .plugin import StateDescriptor +from .smtlib import Bool, issymbolic, BitVecConstant from .smtlib.expression import Expression +from ..utils import config from ..utils.event import Eventful from ..utils.helpers import PickleSerializer -from ..utils import config -from .plugin import StateDescriptor consts = config.get_group("core") consts.add( @@ -363,7 +363,14 @@ def new_symbolic_value(self, nbits, label=None, taint=frozenset()): self._input_symbols.append(expr) return expr - def concretize(self, symbolic, policy, maxcount=7): + def concretize( + self, + symbolic: Expression, + policy: str, + maxcount: int = 7, + *, + additional_symbolics: typing.Optional[typing.List[Expression]] = None, + ): """This finds a set of solutions for symbolic using policy. 
This limits the number of solutions returned to `maxcount` to avoid @@ -378,9 +385,9 @@ def concretize(self, symbolic, policy, maxcount=7): if policy == "MINMAX": vals = self._solver.minmax(self._constraints, symbolic) elif policy == "MAX": - vals = (self._solver.max(self._constraints, symbolic),) + vals = [self._solver.max(self._constraints, symbolic)] elif policy == "MIN": - vals = (self._solver.min(self._constraints, symbolic),) + vals = [self._solver.min(self._constraints, symbolic)] elif policy == "SAMPLED": m, M = self._solver.minmax(self._constraints, symbolic) vals += [m, M] @@ -402,22 +409,34 @@ def concretize(self, symbolic, policy, maxcount=7): elif policy == "OPTIMISTIC": logger.info("Optimistic case when forking") if self._solver.can_be_true(self._constraints, symbolic): - vals = (True,) + vals = [True] else: # We assume the path constraint was feasible to begin with - vals = (False,) + vals = [False] elif policy == "PESSIMISTIC": logger.info("Pessimistic case when forking") if self._solver.can_be_true(self._constraints, symbolic == False): - vals = (False,) + vals = [False] else: # We assume the path constraint was feasible to begin with - vals = (True,) + vals = [True] else: assert policy == "ALL" - vals = self._solver.get_all_values( - self._constraints, symbolic, maxcnt=maxcount, silent=True - ) + if additional_symbolics is not None: + logger.debug( + "Additinal symbolics", additional_symbolics, "used with expression", symbolic + ) + val_1 = self._solver.get_all_values( + self._constraints, symbolic, maxcnt=maxcount, silent=True + ) + val_2 = self._solver.get_all_values( + self._constraints, additional_symbolics, maxcnt=maxcount, silent=True + ) + return list(zip(val_1, val_2)) + else: + vals = self._solver.get_all_values( + self._constraints, symbolic, maxcnt=maxcount, silent=True + ) return tuple(set(vals)) diff --git a/manticore/ethereum/manticore.py b/manticore/ethereum/manticore.py index eb5a0b0ae..0941cf975 100644 --- 
def fix_unsound_lazy(self, state):
    """
    Soundness check applied to a state when lazy mode is configured.

    Asks ``state`` whether the trivially true expression can hold, and hands
    that answer back unchanged.
    NOTE(review): presumably this amounts to "the accumulated path
    constraints are still satisfiable" -- confirm against
    ``State.can_be_true``.

    :param state: the state to check
    :return: whatever ``state.can_be_true(True)`` returns
    """
    verdict = state.can_be_true(True)
    return verdict
def enable_lazy_evaluation(self):
    """
    Enable lazy evaluation.

    While enabled, `_fork` forks on both branches recorded by the last JUMPI
    instead of concretizing the expression through `state.concretize`.

    :return: None
    """
    self._lazy_evaluation = True


def disable_lazy_evaluation(self):
    """
    Disable lazy evaluation.

    `_fork` then always obtains its solutions from `state.concretize`.

    :return: None
    """
    self._lazy_evaluation = False


@property
def lazy_evaluation(self) -> bool:
    """
    Whether lazy evaluation is currently enabled on this instance.
    """
    return self._lazy_evaluation


def _fork(self, state, expression, policy="ALL", setstate=None):
    """
    Fork state on expression concretizations.

    Using policy build a list of solutions for expression.
    Fork the state on each solution, configuring each child with setstate.
    For example if expression is a Bool it may have 2 solutions. True or False.

                             Parent
                        (expression = ??)

           Child1                         Child2
    (expression = True)             (expression = False)
       setstate(True)                   setstate(False)

    The optional setstate() function is supposed to set the concrete value
    in the child state.
    Parent state is removed from the busy list and the child states are added
    to the ready list.

    :param state: the parent state to fork; it is removed once the children are stored
    :param expression: the Expression to concretize
    :param policy: concretization policy forwarded to ``state.concretize``
    :param setstate: optional callable ``setstate(child_state, value)``
    :raises ManticoreError: if there is no solution to fork on
    """
    assert isinstance(expression, Expression), f"{type(expression)} is not an Expression"

    if setstate is None:

        def setstate(x, y):
            pass

    def _describe(value):
        # Plain ints (and bools) keep the hexadecimal rendering. Anything else
        # falls back to repr() so that a debug message can never abort a fork:
        # in lazy mode the true branch is whatever operand JUMPI stored, which
        # is not guaranteed to be an int.
        if isinstance(value, int):
            return f"0x{value:x}"
        return repr(value)

    vm = state.platform.current_vm

    # `jumpi_false_branch` is documented to be None unless a JUMPI was just
    # executed, so compare against that sentinel explicitly.
    if self._lazy_evaluation and vm and vm.jumpi_false_branch is not None:
        # Lazy mode: take both branches without asking the solver.
        solutions = [
            (vm.jumpi_false_branch, False),
            (vm.jumpi_true_branch, True),
        ]
    else:
        # Read the pending JUMPDEST-check flag once.
        need_check = vm.need_check_jumpdest() if vm else None
        if isinstance(need_check, (Expression, bool)):
            # NOTE(review): concretize() only returns (value, condition) pairs
            # for policy "ALL"; the unpacking below assumes pairs -- confirm
            # that no other policy reaches this branch.
            solutions = state.concretize(
                expression,
                policy,
                additional_symbolics=need_check,
            )
        else:
            solutions = [(s, False) for s in state.concretize(expression, policy)]

    if not solutions:
        raise ManticoreError("Forking on unfeasible constraint set")

    logger.debug(
        "Forking. Policy: %s. Values: %s",
        policy,
        ", ".join(_describe(sol) for sol, _ in solutions),
    )

    self._publish("will_fork_state", state, expression, solutions, policy)

    # Build and enqueue a state for each solution
    children = []
    for solution in solutions:
        with state as new_state:
            new_value, jump_cond = solution
            new_state.constrain(expression == new_value)

            # and set the PC of the new state to the concrete pc-dest
            # (or other register or memory address to concrete)
            setstate(new_state, new_value)
            # TODO: Ideally "jump_cond" should be in the VM and not the platform
            # However, platform.current_vm is not yet created
            # So not sure how to do it
            new_state.platform.last_ins_was_true_jumpi = jump_cond

            # enqueue new_state, assign new state id
            new_state_id = self._put_state(new_state)

            # maintain a list of children for logging purpose
            children.append(new_state_id)

    self._publish("did_fork_state", state, expression, solutions, policy, children)
    logger.debug("Forking current state %r into states %r", state.id, children)

    # The parent has been fully replaced by its children: drop it.
    with self._lock:
        self._busy_states.remove(state.id)
        self._remove(state.id)
        state._id = None
        self._lock.notify_all()
@property
def jumpi_false_branch(self) -> Optional[int]:
    """
    Return the JUMPI false branch (the fall-through program counter).

    The value is reset to None by the instruction handler before each
    instruction runs and is set again by JUMPI, so it is None whenever the
    last instruction was not a JUMPI.
    """
    return self._jumpi_false_branch


@property
def jumpi_true_branch(self) -> Optional[int]:
    """
    Return the JUMPI true branch (the destination operand given to JUMPI).

    None until a JUMPI has been executed.
    NOTE(review): the instruction handler resets only the false branch
    before each instruction, so this value may be left over from an earlier
    JUMPI -- test `jumpi_false_branch` first to know whether it is current.
    """
    return self._jumpi_true_branch


def need_check_jumpdest(self):
    """
    Return the pending "next instruction must be a JUMPDEST" flag.

    The flag is None right after the instruction handler resets it and False
    once `_check_jmpdest` has cleared it.
    NOTE(review): judging from how `_fork` type-checks the result, it can
    otherwise be a bool or a symbolic Expression -- confirm against
    `_set_check_jmpdest`.
    """
    return self._need_check_jumpdest
""" # If pc is already pointing to a JUMPDEST thre is no need to check. - pc = self.pc.value if isinstance(self.pc, Constant) else self.pc - if pc in self._valid_jumpdests: - self._need_check_jumpdest = False - return - - should_check_jumpdest = simplify(self._need_check_jumpdest) - if isinstance(should_check_jumpdest, Constant): - should_check_jumpdest = should_check_jumpdest.value - elif issymbolic(should_check_jumpdest): - self._publish("will_solve", self.constraints, should_check_jumpdest, "get_all_values") - should_check_jumpdest_solutions = SelectedSolver.instance().get_all_values( - self.constraints, should_check_jumpdest - ) - self._publish( - "did_solve", - self.constraints, - should_check_jumpdest, - "get_all_values", - should_check_jumpdest_solutions, - ) - if len(should_check_jumpdest_solutions) != 1: - raise EthereumError("Conditional not concretized at JMPDEST check") - should_check_jumpdest = should_check_jumpdest_solutions[0] - - # If it can be solved only to False just set it False. If it can be solved - # only to True, process it and also set it to False self._need_check_jumpdest = False - if should_check_jumpdest: + if self._world.last_ins_was_true_jumpi: + pc = self.pc.value if isinstance(self.pc, Constant) else self.pc if pc not in self._valid_jumpdests: self._throw() + self._world.last_ins_was_true_jumpi = False + return def _advance(self, result=None, exception=False): if self._checkpoint_data is None: @@ -2091,6 +2096,8 @@ def JUMPI(self, dest, cond): """Conditionally alter the program counter""" # TODO(feliam) If dest is Constant we do not need to 3 queries. 
@property
def last_ins_was_true_jumpi(self) -> bool:
    """
    True if the last instruction was a JUMPI and the true branch was taken.

    Initialised to False, assigned on every child state when forking, and
    read then cleared by the VM's JUMPDEST check. It is never None, hence
    the plain ``bool`` annotation.
    """
    return self._last_ins_was_true_jumpi


@last_ins_was_true_jumpi.setter
def last_ins_was_true_jumpi(self, value: bool) -> None:
    """
    Record whether the last instruction was a JUMPI whose true branch was taken.

    :param value: the new flag; stored as given, without conversion
    """
    self._last_ins_was_true_jumpi = value
--git a/tests/ethereum/test_general.py b/tests/ethereum/test_general.py index 72264c741..80cdb0024 100644 --- a/tests/ethereum/test_general.py +++ b/tests/ethereum/test_general.py @@ -819,7 +819,7 @@ def test_function_name_with_signature(self): self.mevm.make_symbolic_value(), signature="(uint256,uint256)", ) - for st in self.mevm.all_states: + for st in self.mevm.all_sound_states: z = st.solve_one(st.platform.transactions[1].return_data) break self.assertEqual(ABI.deserialize("(uint256)", z)[0], 2) @@ -1340,7 +1340,7 @@ def test_preconstraints(self): ) m.transaction(caller=creator_account, address=contract_account, data=data, value=value) - results = [state.platform.all_transactions[-1].result for state in m.all_states] + results = [state.platform.all_transactions[-1].result for state in m.all_sound_states] self.assertListEqual(sorted(results), ["STOP", "STOP"]) def test_plugins_enable(self): diff --git a/tests/ethereum/test_sha3.py b/tests/ethereum/test_sha3.py index 1791ac1e1..b91bf59e8 100644 --- a/tests/ethereum/test_sha3.py +++ b/tests/ethereum/test_sha3.py @@ -56,7 +56,7 @@ def test_example1(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -87,7 +87,7 @@ def test_example2(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -118,7 +118,7 @@ def test_example3(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -153,7 +153,7 @@ def test_example4(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -190,7 +190,7 @@ def test_example5(self): found = 0 for st in m.all_states: - if not 
m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -226,7 +226,7 @@ def test_example6(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -263,7 +263,7 @@ def test_example7(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -300,7 +300,7 @@ def test_example8(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -335,7 +335,7 @@ def test_essence1(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue m.generate_testcase(st) @@ -368,7 +368,7 @@ def test_essence2(self): contract.foo(x) found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue @@ -407,7 +407,7 @@ def test_essence3(self): contract.foo(x2) for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue @@ -459,7 +459,7 @@ def test_example_concrete_1(self): found = 0 for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue found += len(st.platform.logs) @@ -514,7 +514,7 @@ def test_essence3(self): contract.foo(x2) for st in m.all_states: - if not m.fix_unsound_symbolication(st): + if not m.is_sound(st): m.kill_state(st) continue