Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Slots on all Expression objects #2394

Merged
merged 13 commits into from
Mar 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 107 additions & 28 deletions manticore/core/smtlib/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import re
import copy
from typing import Union, Optional, Dict, List
from typing import Union, Optional, Dict, Tuple


class ExpressionException(SmtlibError):
Expand All @@ -15,10 +15,59 @@ class ExpressionException(SmtlibError):
pass


class Expression:
class XSlotted(type):
"""
Metaclass that will propagate slots on multi-inheritance classes
Every class should define __xslots__ (instead of __slots__)

class Base(object, metaclass=XSlotted, abstract=True):
pass

class A(Base, abstract=True):
__xslots__ = ('a',)
pass

class B(Base, abstract=True):
__xslots__ = ('b',)
pass

class C(A, B):
pass

# Normal case / baseline
class X(object):
__slots__ = ('a', 'b')

c = C()
c.a = 1
c.b = 2

x = X()
x.a = 1
x.b = 2

import sys
print (sys.getsizeof(c),sys.getsizeof(x)) #same value
"""

def __new__(cls, clsname, bases, attrs, abstract=False):
xslots = frozenset(attrs.get("__xslots__", ()))
# merge the xslots of all the bases with the one defined here
for base in bases:
xslots = xslots.union(getattr(base, "__xslots__", ()))
attrs["__xslots__"] = tuple(xslots)
if abstract:
attrs["__slots__"] = tuple()
else:
attrs["__slots__"] = attrs["__xslots__"]
attrs["__hash__"] = object.__hash__
return super().__new__(cls, clsname, bases, attrs)


class Expression(object, metaclass=XSlotted, abstract=True):
""" Abstract taintable Expression. """

__slots__ = ["_taint"]
__xslots__: Tuple[str, ...] = ("_taint",)

def __init__(self, taint: Union[tuple, frozenset] = ()):
if self.__class__ is Expression:
Expand Down Expand Up @@ -114,8 +163,8 @@ def taint_with(arg, *taints, value_bits=256, index_bits=256):

###############################################################################
# Booleans
class Bool(Expression):
__slots__: List[str] = []
class Bool(Expression, abstract=True):
"""Bool expressions represent symbolic value of truth"""

def __init__(self, *operands, **kwargs):
super().__init__(*operands, **kwargs)
Expand Down Expand Up @@ -169,7 +218,7 @@ def __bool__(self):


class BoolVariable(Bool):
__slots__ = ["_name"]
__xslots__: Tuple[str, ...] = ("_name",)

def __init__(self, name: str, *args, **kwargs):
assert " " not in name
Expand All @@ -195,11 +244,11 @@ def declaration(self):


class BoolConstant(Bool):
__slots__ = ["_value"]
__xslots__: Tuple[str, ...] = ("_value",)

def __init__(self, value: bool, *args, **kwargs):
self._value = value
super().__init__(*args, **kwargs)
self._value = value

def __bool__(self):
return self.value
Expand All @@ -209,8 +258,10 @@ def value(self):
return self._value


class BoolOperation(Bool):
__slots__ = ["_operands"]
class BoolOperation(Bool, abstract=True):
""" An operation that results in a Bool """

__xslots__: Tuple[str, ...] = ("_operands",)

def __init__(self, *operands, **kwargs):
self._operands = operands
Expand Down Expand Up @@ -250,10 +301,10 @@ def __init__(self, cond: "Bool", true: "Bool", false: "Bool", **kwargs):
super().__init__(cond, true, false, **kwargs)


class BitVec(Expression):
""" This adds a bitsize to the Expression class """
class BitVec(Expression, abstract=True):
""" BitVector expressions have a fixed bit size """

__slots__ = ["size"]
__xslots__: Tuple[str, ...] = ("size",)

def __init__(self, size, *operands, **kwargs):
super().__init__(*operands, **kwargs)
Expand Down Expand Up @@ -456,7 +507,7 @@ def Bool(self):


class BitVecVariable(BitVec):
__slots__ = ["_name"]
__xslots__: Tuple[str, ...] = ("_name",)

def __init__(self, size: int, name: str, *args, **kwargs):
assert " " not in name
Expand All @@ -482,7 +533,7 @@ def declaration(self):


class BitVecConstant(BitVec):
__slots__ = ["_value"]
__xslots__: Tuple[str, ...] = ("_value",)

def __init__(self, size: int, value: int, *args, **kwargs):
MASK = (1 << size) - 1
Expand Down Expand Up @@ -512,8 +563,10 @@ def signed_value(self):
return self._value


class BitVecOperation(BitVec):
__slots__ = ["_operands"]
class BitVecOperation(BitVec, abstract=True):
""" An operation that results in a BitVec """

__xslots__: Tuple[str, ...] = ("_operands",)

def __init__(self, size, *operands, **kwargs):
ehennenfent marked this conversation as resolved.
Show resolved Hide resolved
self._operands = operands
Expand Down Expand Up @@ -670,8 +723,15 @@ def __init__(self, a, b, *args, **kwargs):

###############################################################################
# Array BV32 -> BV8 or BV64 -> BV8
class Array(Expression):
__slots__ = ["_index_bits", "_index_max", "_value_bits"]
class Array(Expression, abstract=True):
"""An Array expression is an unmutable mapping from bitvector to bitvector

array.index_bits is the number of bits used for addressing a value
array.value_bits is the number of bits used in the values
array.index_max counts the valid indexes starting at 0. Accessing outside the bound is undefined
"""

__xslots__: Tuple[str, ...] = ("_index_bits", "_index_max", "_value_bits")

def __init__(
self, index_bits: int, index_max: Optional[int], value_bits: int, *operands, **kwargs
Expand Down Expand Up @@ -764,7 +824,7 @@ def store(self, index, value):

def write(self, offset, buf):
if not isinstance(buf, (Array, bytearray)):
raise TypeError("Array or bytearray expected got {:s}".format(type(buf)))
raise TypeError(f"Array or bytearray expected got {type(buf)}")
arr = self
for i, val in enumerate(buf):
arr = arr.store(offset + i, val)
Expand Down Expand Up @@ -820,18 +880,18 @@ def read_BE(self, address, size):
bytes = []
for offset in range(size):
bytes.append(self.get(address + offset, 0))
return BitVecConcat(size * self.value_bits, *bytes)
return BitVecConcat(size * self.value_bits, tuple(bytes))

def read_LE(self, address, size):
address = self.cast_index(address)
bytes = []
for offset in range(size):
bytes.append(self.get(address + offset, 0))
return BitVecConcat(size * self.value_bits, *reversed(bytes))
return BitVecConcat(size * self.value_bits, tuple(reversed(bytes)))

def write_BE(self, address, value, size):
address = self.cast_index(address)
value = BitVec(size * self.value_bits).cast(value)
value = BitVecConstant(size * self.value_bits, value=0).cast(value)
array = self
for offset in range(size):
array = array.store(
Expand All @@ -842,7 +902,7 @@ def write_BE(self, address, value, size):

def write_LE(self, address, value, size):
address = self.cast_index(address)
value = BitVec(size * self.value_bits).cast(value)
value = BitVecConstant(size * self.value_bits, value=0).cast(value)
array = self
for offset in reversed(range(size)):
array = array.store(
Expand Down Expand Up @@ -903,7 +963,7 @@ def __radd__(self, other):


class ArrayVariable(Array):
__slots__ = ["_name"]
__xslots__: Tuple[str, ...] = ("_name",)

def __init__(self, index_bits, index_max, value_bits, name, *args, **kwargs):
assert " " not in name
Expand All @@ -929,7 +989,9 @@ def declaration(self):


class ArrayOperation(Array):
__slots__ = ["_operands"]
"""An operation that result in an Array"""

__xslots__: Tuple[str, ...] = ("_operands",)

def __init__(self, array: Array, *operands, **kwargs):
self._operands = (array, *operands)
Expand Down Expand Up @@ -989,6 +1051,8 @@ def __setstate__(self, state):


class ArraySlice(ArrayOperation):
__xslots__: Tuple[str, ...] = ("_slice_offset", "_slice_size")

def __init__(
self, array: Union["Array", "ArrayProxy"], offset: int, size: int, *args, **kwargs
):
Expand Down Expand Up @@ -1033,6 +1097,15 @@ def store(self, index, value):


class ArrayProxy(Array):
__xslots__: Tuple[str, ...] = (
"constraints",
"_default",
"_concrete_cache",
"_written",
"_array",
"_name",
)

def __init__(self, array: Array, default: Optional[int] = None):
self._default = default
self._concrete_cache: Dict[int, int] = {}
Expand Down Expand Up @@ -1229,7 +1302,7 @@ def get(self, index, default=None):


class ArraySelect(BitVec):
__slots__ = ["_operands"]
__xslots__: Tuple[str, ...] = ("_operands",)

def __init__(self, array: "Array", index: "BitVec", *operands, **kwargs):
assert index.size == array.index_bits
Expand Down Expand Up @@ -1257,20 +1330,26 @@ def __repr__(self):


class BitVecSignExtend(BitVecOperation):
__xslots__: Tuple[str, ...] = ("extend",)

def __init__(self, operand: "BitVec", size_dest: int, *args, **kwargs):
assert size_dest >= operand.size
super().__init__(size_dest, operand, *args, **kwargs)
self.extend = size_dest - operand.size


class BitVecZeroExtend(BitVecOperation):
__xslots__: Tuple[str, ...] = ("extend",)

def __init__(self, size_dest: int, operand: "BitVec", *args, **kwargs):
assert size_dest >= operand.size
super().__init__(size_dest, operand, *args, **kwargs)
self.extend = size_dest - operand.size


class BitVecExtract(BitVecOperation):
__xslots__: Tuple[str, ...] = ("_begining", "_end")

def __init__(self, operand: "BitVec", offset: int, size: int, *args, **kwargs):
assert offset >= 0 and offset + size <= operand.size
super().__init__(size, operand, *args, **kwargs)
Expand All @@ -1291,7 +1370,7 @@ def end(self):


class BitVecConcat(BitVecOperation):
def __init__(self, size_dest: int, *operands, **kwargs):
def __init__(self, size_dest: int, operands: Tuple, **kwargs):
assert all(isinstance(x, BitVec) for x in operands)
assert size_dest == sum(x.size for x in operands)
super().__init__(size_dest, *operands, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion manticore/core/smtlib/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def cast(x):
return BitVecConstant(arg_size, x)
return x

return BitVecConcat(total_size, *list(map(cast, args)))
return BitVecConcat(total_size, tuple(map(cast, args)))
else:
return args[0]
else:
Expand Down
4 changes: 2 additions & 2 deletions manticore/core/smtlib/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def visit_BitVecConcat(self, expression, *operands):
if last_o is not None:
new_operands.append(last_o)
if changed:
return BitVecConcat(expression.size, *new_operands)
return BitVecConcat(expression.size, tuple(new_operands))

op = operands[0]
value = None
Expand Down Expand Up @@ -648,7 +648,7 @@ def visit_BitVecExtract(self, expression, *operands):
if size == 0:
assert expression.size == sum([x.size for x in new_operands])
return BitVecConcat(
expression.size, *reversed(new_operands), taint=expression.taint
expression.size, tuple(reversed(new_operands)), taint=expression.taint
)

if begining >= item.size:
Expand Down
4 changes: 2 additions & 2 deletions tests/ethereum/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from manticore.core.plugin import Plugin
from manticore.core.smtlib import ConstraintSet, operators
from manticore.core.smtlib import Z3Solver
from manticore.core.smtlib.expression import BitVec
from manticore.core.smtlib.expression import BitVec, BitVecVariable
from manticore.core.smtlib.visitors import to_constant
from manticore.core.state import TerminateState
from manticore.ethereum import (
Expand Down Expand Up @@ -1382,7 +1382,7 @@ def will_evm_execute_instruction_callback(self, state, i, *args, **kwargs):

class EthHelpersTest(unittest.TestCase):
def setUp(self):
self.bv = BitVec(256)
self.bv = BitVecVariable(256, name="BVV")

def test_concretizer(self):
policy = "SOME_NONSTANDARD_POLICY"
Expand Down
4 changes: 2 additions & 2 deletions tests/native/test_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from manticore.core.smtlib import Bool, BitVecConstant
from manticore.core.smtlib import Bool, BoolVariable, BitVecConstant
from manticore.native.cpu.register import Register


Expand Down Expand Up @@ -47,7 +47,7 @@ def test_bool_write_nonflag(self):

def test_Bool(self):
r = Register(32)
b = Bool()
b = BoolVariable(name="B")
r.write(b)
self.assertIs(r.read(), b)

Expand Down
Binary file removed tests/other/data/ErrRelated.pkl.gz
Binary file not shown.
Loading