Skip to content

Commit

Permalink
Make Structure.fields a dictionary (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
Schamper committed Nov 26, 2023
1 parent ea06eaf commit dbd372a
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 46 deletions.
2 changes: 1 addition & 1 deletion dissect/cstruct/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def compile(self, structure: type[Structure]) -> type[Structure]:
return structure

try:
structure._read = self.compile_read(structure.fields, structure.__name__, structure.__align__)
structure._read = self.compile_read(structure.__fields__, structure.__name__, structure.__align__)
structure.__compiled__ = True
except Exception as e:
# Silently ignore, we didn't compile unfortunately
Expand Down
10 changes: 8 additions & 2 deletions dissect/cstruct/cstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,13 @@ def resolve(self, name: str) -> MetaType:
raise ResolveError(f"Recursion limit exceeded while resolving type {name}")

def _make_type(
self, name: str, bases: Iterator[object], size: int, *, alignment: int = None, attrs: dict[str, Any] = None
self,
name: str,
bases: Iterator[object],
size: Optional[int],
*,
alignment: int = None,
attrs: dict[str, Any] = None,
) -> type[BaseType]:
attrs = attrs or {}
attrs.update(
Expand Down Expand Up @@ -387,7 +393,7 @@ def _make_union(
def ctypes(structure: Structure) -> _ctypes.Structure:
"""Create ctypes structures from cstruct structures."""
fields = []
for field in structure.fields:
for field in structure.__fields__:
t = ctypes_type(field.type)
fields.append((field.name, t))

Expand Down
2 changes: 1 addition & 1 deletion dissect/cstruct/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _struct(self, tokens: TokenConsumer, register: bool = False) -> None:
if self.compiled and "nocompile" not in tokens.flags:
st = compiler.compile(st)
else:
st.fields.extend(fields)
st.__fields__.extend(fields)
st.commit()

# This is pretty dirty
Expand Down
6 changes: 5 additions & 1 deletion dissect/cstruct/types/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ class MetaType(type):
"""Base metaclass for cstruct type classes."""

cs: cstruct
size: int
"""The cstruct instance this type class belongs to."""
size: Optional[int]
"""The size of the type in bytes. Can be ``None`` for dynamic sized types."""
dynamic: bool
"""Whether or not the type is dynamically sized."""
alignment: int
"""The alignment of the type in bytes."""

def __call__(cls, *args, **kwargs) -> Union[MetaType, BaseType]:
"""Adds support for ``TypeClass(bytes | file-like object)`` parsing syntax."""
Expand Down
48 changes: 26 additions & 22 deletions dissect/cstruct/types/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,16 @@ class StructureMetaType(MetaType):

# TODO: resolve field types in _update_fields, remove resolves elsewhere?

fields: list[Field]
fields: dict[str, Field]
"""Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures."""
lookup: dict[str, Field]
"""Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields."""
__fields__: list[Field]
"""List of :class:`Field` objects for this structure. This is the structures' Single Source Of Truth."""

__anonymous__: bool
# Internal
__align__: bool
__lookup__: dict[str, Field]
__anonymous__: bool
__updating__ = False
__compiled__ = False

Expand All @@ -50,16 +54,16 @@ def __new__(metacls, cls, bases, classdict: dict[str, Any]) -> MetaType:

def __call__(cls, *args, **kwargs) -> Structure:
if (
cls.fields
and len(args) == len(cls.fields) == 1
cls.__fields__
and len(args) == len(cls.__fields__) == 1
and isinstance(args[0], bytes)
and issubclass(cls.fields[0].type, bytes)
and len(args[0]) == cls.fields[0].type.size
and issubclass(cls.__fields__[0].type, bytes)
and len(args[0]) == cls.__fields__[0].type.size
):
# Shortcut for single char/bytes type
return type.__call__(cls, *args, **kwargs)
elif not args and not kwargs:
obj = cls(**{field.name: field.type() for field in cls.fields})
obj = cls(**{field.name: field.type() for field in cls.__fields__})
object.__setattr__(obj, "_values", {})
object.__setattr__(obj, "_sizes", {})
return obj
Expand All @@ -80,11 +84,11 @@ def _update_fields(
raise ValueError(f"Duplicate field name: {field.name}")

if isinstance(field.type, StructureMetaType) and field.type.__anonymous__:
for anon_field in field.type.lookup.values():
for anon_field in field.type.fields.values():
attr = f"{field.name}.{anon_field.name}"
classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr))

lookup.update(field.type.lookup)
lookup.update(field.type.fields)
else:
lookup[field.name] = field

Expand All @@ -93,9 +97,9 @@ def _update_fields(
num_fields = len(lookup)
field_names = lookup.keys()
init_names = raw_lookup.keys()
classdict["fields"] = fields
classdict["lookup"] = lookup
classdict["__lookup__"] = raw_lookup
classdict["fields"] = lookup
classdict["lookup"] = raw_lookup
classdict["__fields__"] = fields
classdict["__bool__"] = _patch_attributes(_make__bool__(num_fields), field_names, 1)

if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType):
Expand Down Expand Up @@ -229,7 +233,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Structure:

result = {}
sizes = {}
for field in cls.fields:
for field in cls.__fields__:
offset = stream.tell()
field_type = cls.cs.resolve(field.type)

Expand Down Expand Up @@ -283,7 +287,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int:
struct_start = stream.tell()
num = 0

for field in cls.fields:
for field in cls.__fields__:
field_type = cls.cs.resolve(field.type)

bit_field_type = (
Expand Down Expand Up @@ -338,7 +342,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int:

def add_field(cls, name: str, type_: BaseType, bits: Optional[int] = None, offset: Optional[int] = None) -> None:
field = Field(name, type_, bits=bits, offset=offset)
cls.fields.append(field)
cls.__fields__.append(field)

if not cls.__updating__:
cls.commit()
Expand All @@ -353,7 +357,7 @@ def start_update(cls) -> ContextManager:
cls.__updating__ = False

def commit(cls) -> None:
classdict = cls._update_fields(cls.fields, cls.__align__)
classdict = cls._update_fields(cls.__fields__, cls.__align__)

for key, value in classdict.items():
setattr(cls, key, value)
Expand All @@ -378,7 +382,7 @@ def __repr__(self) -> str:
values = " ".join(
[
f"{k}={hex(getattr(self, k)) if issubclass(f.type, int) else repr(getattr(self, k))}"
for k, f in self.__class__.lookup.items()
for k, f in self.__class__.fields.items()
]
)
return f"<{self.__class__.__name__} {values}>"
Expand Down Expand Up @@ -425,7 +429,7 @@ def _read_fields(cls, stream: BinaryIO, context: dict[str, Any] = None) -> tuple
offset = 0
buf = io.BytesIO(stream.read(cls.size))

for field in cls.fields:
for field in cls.__fields__:
field_type = cls.cs.resolve(field.type)

start = 0
Expand Down Expand Up @@ -467,7 +471,7 @@ def _write(cls, stream: BinaryIO, data: Union) -> int:
expected_offset = offset + len(cls)

# Sort by largest field
fields = sorted(cls.fields, key=lambda e: len(e.type), reverse=True)
fields = sorted(cls.__fields__, key=lambda e: len(e.type), reverse=True)
anonymous_struct = False

# Try to write by largest field
Expand Down Expand Up @@ -516,7 +520,7 @@ def _rebuild(self, attr: str) -> None:
cur_buf = b"\x00" * self.__class__.size

buf = io.BytesIO(cur_buf)
field = self.__class__.lookup[attr]
field = self.__class__.fields[attr]
if field.offset:
buf.seek(field.offset)
field.type._write(buf, getattr(self, attr))
Expand All @@ -532,7 +536,7 @@ def _update(self) -> None:

def _proxify(self) -> UnionProxy:
def _proxy_structure(value: Structure) -> UnionProxy:
for field in value.__class__.fields:
for field in value.__class__.__fields__:
if issubclass(field.type, Structure):
nested_value = getattr(value, field.name)
proxy = UnionProxy(self, field.name, nested_value)
Expand Down
2 changes: 1 addition & 1 deletion dissect/cstruct/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _dumpstruct(
ci = 0
out = [f"struct {structure.__class__.__name__}:"]
foreground, background = None, None
for field in structure.__class__.fields:
for field in structure.__class__.__fields__:
if getattr(field.type, "anonymous", False):
continue

Expand Down
16 changes: 8 additions & 8 deletions tests/test_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_align_struct(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert cs.test.__align__
assert cs.test.alignment == 8
assert cs.test.size == 32
Expand All @@ -41,7 +41,7 @@ def test_align_struct(cs: cstruct, compiled: bool):
assert fh.tell() == 32

for name, value in obj._values.items():
assert cs.test.lookup[name].offset == value
assert cs.test.fields[name].offset == value

assert obj.dumps() == buf

Expand Down Expand Up @@ -87,7 +87,7 @@ def test_align_array(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert cs.test.__align__
assert cs.test.alignment == 8
assert cs.test.size == 64
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_align_struct_array(cs: cstruct, compiled: bool):
assert verify_compiled(cs.test, compiled)
assert verify_compiled(cs.array, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert cs.test.__align__
assert cs.test.alignment == 8
assert cs.test.size == 16
Expand Down Expand Up @@ -177,7 +177,7 @@ def test_align_dynamic(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert fields[0].offset == 0
assert fields[1].offset == 2
assert fields[2].offset is None
Expand Down Expand Up @@ -221,7 +221,7 @@ def test_align_nested_struct(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert fields[0].offset == 0x00
assert fields[1].offset == 0x08
assert fields[2].offset == 0x18
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_align_bitfield(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert fields[0].offset == 0x00
assert fields[1].offset is None
assert fields[2].offset == 0x08
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_align_pointer(cs: cstruct, compiled: bool):

assert verify_compiled(cs.test, compiled)

fields = cs.test.fields
fields = cs.test.__fields__
assert cs.test.__align__
assert cs.test.alignment == 8
assert cs.test.size == 24
Expand Down
4 changes: 2 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,8 @@ def test_multipart_type_name(cs: cstruct):
cs.load(cdef)

assert cs.TestEnum.type == cs.resolve("unsigned int")
assert cs.test.fields[0].type == cs.resolve("unsigned int")
assert cs.test.fields[1].type == cs.resolve("unsigned long long")
assert cs.test.__fields__[0].type == cs.resolve("unsigned int")
assert cs.test.__fields__[1].type == cs.resolve("unsigned long long")

with pytest.raises(ResolveError) as exc:
cdef = """
Expand Down
12 changes: 6 additions & 6 deletions tests/test_types_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def TestStruct(cs: cstruct) -> type[Structure]:
def test_structure(TestStruct: type[Structure]):
assert issubclass(TestStruct, Structure)
assert len(TestStruct.fields) == 2
assert TestStruct.lookup["a"].name == "a"
assert TestStruct.lookup["b"].name == "b"
assert TestStruct.fields["a"].name == "a"
assert TestStruct.fields["b"].name == "b"

assert TestStruct.size == 8
assert TestStruct.alignment == 4
Expand Down Expand Up @@ -222,8 +222,8 @@ def test_structure_definitions(cs: cstruct, compiled: bool):
assert cs.test.__name__ == "_test"
assert cs._test.__name__ == "_test"

assert "a" in cs.test.lookup
assert "b" in cs.test.lookup
assert "a" in cs.test.fields
assert "b" in cs.test.fields

with pytest.raises(ParserError):
cdef = """
Expand Down Expand Up @@ -539,5 +539,5 @@ def test_structure_definition_self(cs: cstruct):
"""
cs.load(cdef)

assert issubclass(cs.test.fields[1].type, Pointer)
assert cs.test.fields[1].type.type is cs.test
assert issubclass(cs.test.fields["b"].type, Pointer)
assert cs.test.fields["b"].type.type is cs.test
4 changes: 2 additions & 2 deletions tests/test_types_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def TestUnion(cs: cstruct) -> type[Union]:
def test_union(TestUnion: type[Union]):
assert issubclass(TestUnion, Union)
assert len(TestUnion.fields) == 2
assert TestUnion.lookup["a"].name == "a"
assert TestUnion.lookup["b"].name == "b"
assert TestUnion.fields["a"].name == "a"
assert TestUnion.fields["b"].name == "b"

assert TestUnion.size == 4
assert TestUnion.alignment == 4
Expand Down

0 comments on commit dbd372a

Please sign in to comment.