diff --git a/README.md b/README.md index d16349b..3956c1b 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,7 @@ assert a.dumps() == d The API to access enum members and their values is similar to that of the native Enum type in Python 3. Functionally, it's best comparable to the IntEnum type. ### Custom types -You can implement your own types by subclassing `BaseType` or `RawType`, and adding them to your cstruct instance with `addtype(name, type)` +You can implement your own types by subclassing `BaseType` or `RawType`, and adding them to your cstruct instance with `add_type(name, type)` ### Custom definition parsers Don't like the C-like definition syntax? Write your own syntax parser! diff --git a/dissect/cstruct/__init__.py b/dissect/cstruct/__init__.py index 2196375..3b795cc 100644 --- a/dissect/cstruct/__init__.py +++ b/dissect/cstruct/__init__.py @@ -1,5 +1,4 @@ from dissect.cstruct.bitbuffer import BitBuffer -from dissect.cstruct.compiler import Compiler from dissect.cstruct.cstruct import cstruct, ctypes, ctypes_type from dissect.cstruct.exceptions import ( Error, @@ -8,18 +7,25 @@ ResolveError, ) from dissect.cstruct.expression import Expression -from dissect.cstruct.types.base import Array, BaseType, RawType -from dissect.cstruct.types.bytesinteger import BytesInteger -from dissect.cstruct.types.chartype import CharType -from dissect.cstruct.types.enum import Enum, EnumInstance -from dissect.cstruct.types.flag import Flag, FlagInstance -from dissect.cstruct.types.instance import Instance -from dissect.cstruct.types.leb128 import LEB128 -from dissect.cstruct.types.packedtype import PackedType -from dissect.cstruct.types.pointer import Pointer, PointerInstance -from dissect.cstruct.types.structure import Field, Structure, Union -from dissect.cstruct.types.voidtype import VoidType -from dissect.cstruct.types.wchartype import WcharType +from dissect.cstruct.types import ( + LEB128, + Array, + BaseType, + Char, + CharArray, + Enum, + Field, + Flag, + Int, + MetaType, + Packed, + Pointer, + Structure, + Union, + Void, + Wchar, + WcharArray, +) from dissect.cstruct.utils import ( dumpstruct, hexdump, @@ -40,31 +46,28 @@ ) __all__ = [ - "Compiler", - "Array", - "Union", - "Field", - "Instance", + "cstruct", + "ctypes", + "ctypes_type", "LEB128", - "Structure", - "Expression", - "PackedType", - "Pointer", - "PointerInstance", - "VoidType", - "WcharType", - "RawType", + "Array", "BaseType", - "CharType", + "Char", + "CharArray", "Enum", - "EnumInstance", + "Expression", + "Field", "Flag", - "FlagInstance", - "BytesInteger", + "Int", + "MetaType", + "Packed", + "Pointer", + "Structure", + "Union", + "Void", + "Wchar", + "WcharArray", "BitBuffer", - "cstruct", - "ctypes", - "ctypes_type", "dumpstruct", "hexdump", "pack", diff --git a/dissect/cstruct/bitbuffer.py b/dissect/cstruct/bitbuffer.py index 5ab0644..d402fb1 100644 --- a/dissect/cstruct/bitbuffer.py +++ b/dissect/cstruct/bitbuffer.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, BinaryIO, Union +from typing import TYPE_CHECKING, BinaryIO if TYPE_CHECKING: - from dissect.cstruct.types import RawType + from dissect.cstruct.types import BaseType class BitBuffer: @@ -17,7 +17,7 @@ def __init__(self, stream: BinaryIO, endian: str): self._buffer = 0 self._remaining = 0 - def read(self, field_type: RawType, bits: Union[int, bytes]) -> int: + def read(self, field_type: BaseType, bits: int) -> int: if self._remaining == 0 or self._type != field_type: self._type = field_type self._remaining = field_type.size * 8 @@ -43,7 +43,7 @@ def read(self, field_type: RawType, bits: Union[int, bytes]) -> int: return v - def write(self, field_type: RawType, data: int, bits: int) -> None: + def write(self, field_type: BaseType, data: int, bits: int) -> None: if self._remaining == 0 or self._type != field_type: if self._type: self.flush() diff --git a/dissect/cstruct/compiler.py b/dissect/cstruct/compiler.py index 2ea6778..48d696b 100644 --- a/dissect/cstruct/compiler.py +++ b/dissect/cstruct/compiler.py @@ -1,413 +1,418 @@ +# Made in Japan + from __future__ import annotations -import keyword -import struct -from collections import OrderedDict -from textwrap import dedent -from typing import TYPE_CHECKING, List +import io +import logging +from enum import Enum +from textwrap import dedent, indent +from types import MethodType +from typing import TYPE_CHECKING, Iterator from dissect.cstruct.bitbuffer import BitBuffer -from dissect.cstruct.expression import Expression from dissect.cstruct.types import ( Array, - BytesInteger, - CharType, - Enum, - EnumInstance, - Field, + ArrayMetaType, + Char, + CharArray, Flag, - FlagInstance, - Instance, - PackedType, + Int, + MetaType, + Packed, Pointer, - PointerInstance, Structure, Union, - WcharType, + Void, + Wchar, + WcharArray, ) +from dissect.cstruct.types.packed import _struct if TYPE_CHECKING: - from dissect.cstruct import cstruct + from dissect.cstruct.cstruct import cstruct + from dissect.cstruct.types.structure import Field + +SUPPORTED_TYPES = ( + Array, + Char, + CharArray, + Enum, + Flag, + Int, + Packed, + Pointer, + Structure, + Void, + Wchar, + WcharArray, +) + +log = logging.getLogger(__name__) + +python_compile = compile + + +def compile(structure: type[Structure]) -> type[Structure]: + return Compiler(structure.cs).compile(structure) class Compiler: - """Compiler for cstruct structures. Creates somewhat optimized parsing code.""" - - TYPES = ( - Structure, - Pointer, - Enum, - Flag, - Array, - PackedType, - CharType, - WcharType, - BytesInteger, - ) - - COMPILE_TEMPLATE = """ -class {name}(Structure): - def __init__(self, cstruct, structure, source=None): - self.structure = structure - self.source = source - super().__init__(cstruct, structure.name, structure.fields, anonymous=structure.anonymous) - - def _read(self, stream, context=None): - r = OrderedDict() - sizes = {{}} - bitreader = BitBuffer(stream, self.cstruct.endian) - -{read_code} - - return Instance(self, r, sizes) - - def add_field(self, name, type_, offset=None): - raise NotImplementedError("Can't add fields to a compiled structure") - - def __repr__(self): - return '' -""" - - def __init__(self, cstruct: cstruct): - self.cstruct = cstruct - - def compile(self, structure: Structure) -> Structure: - if isinstance(structure, Union) or structure.align: - return structure + def __init__(self, cs: cstruct): + self.cs = cs - structure_name = structure.name - if keyword.iskeyword(structure_name): - structure_name += "_" + def compile(self, structure: type[Structure]) -> type[Structure]: + if issubclass(structure, Union): + return structure try: - # Generate struct class based on provided structure type - source = self.gen_struct_class(structure_name, structure) - except TypeError: - return structure + structure._read = self.compile_read(structure.__fields__, structure.__name__, structure.__align__) + structure.__compiled__ = True + except Exception as e: + # Silently ignore, we didn't compile unfortunately + log.debug("Failed to compile %s", structure, exc_info=e) + + return structure - # Create code object that can be executed later on - code_object = compile( - source, - f"", - "exec", - ) - - env = { - "OrderedDict": OrderedDict, - "Structure": Structure, - "Instance": Instance, - "Expression": Expression, - "EnumInstance": EnumInstance, - "FlagInstance": FlagInstance, - "PointerInstance": PointerInstance, - "BytesInteger": BytesInteger, - "BitBuffer": BitBuffer, - "struct": struct, - "range": range, - } - - exec(code_object, env) - klass = env[structure_name] - klass.__name__ = structure.name - return klass(self.cstruct, structure, source) - - def gen_struct_class(self, name: str, structure: Structure) -> str: - blocks = [] - classes = [] - cur_block = [] - read_size = 0 + def compile_read(self, fields: list[Field], name: str | None = None, align: bool = False) -> MethodType: + return _ReadSourceGenerator(self.cs, fields, name, align).generate() + + +class _ReadSourceGenerator: + def __init__(self, cs: cstruct, fields: list[Field], name: str | None = None, align: bool = False): + self.cs = cs + self.fields = fields + self.name = name + self.align = align + + self.field_map: dict[str, Field] = {} + self._token_id = 0 + + def _map_field(self, field: Field) -> str: + token = f"_{self._token_id}" + self.field_map[token] = field + self._token_id += 1 + return token + + def generate(self) -> MethodType: + source = self.generate_source() + symbols = {token: field.type for token, field in self.field_map.items()} + + code = python_compile(source, f"", "exec") + exec(code, {"BitBuffer": BitBuffer, "_struct": _struct, **symbols}, d := {}) + obj = d.popitem()[1] + obj.__source__ = source + + return classmethod(obj) + + def generate_source(self) -> str: + preamble = """ + r = {} + s = {} + """ + + if any(field.bits for field in self.fields): + preamble += "bit_reader = BitBuffer(stream, cls.cs.endian)\n" + + read_code = "\n".join(self._generate_fields()) + + outro = """ + obj = type.__call__(cls, **r) + obj._sizes = s + obj._values = r + + return obj + """ + + code = indent(dedent(preamble).lstrip() + read_code + dedent(outro), " ") + + template = f"def _read(cls, stream, context=None):\n{code}" + return template + + def _generate_fields(self) -> Iterator[str]: + current_offset = 0 + current_block: list[Field] = [] prev_was_bits = False + prev_bits_type = None + bits_remaining = 0 + bits_rollover = False - for field in structure.fields: - field_type = self.cstruct.resolve(field.type) + def flush() -> Iterator[str]: + if current_block: + if self.align and current_block[0].offset is None: + yield f"stream.seek(-stream.tell() & ({current_block[0].alignment} - 1), {io.SEEK_CUR})" - if not isinstance(field_type, self.TYPES): - raise TypeError(f"Unsupported type for compiler: {field_type}") + yield from self._generate_packed(current_block) + current_block[:] = [] - if isinstance(field_type, Structure) or ( - isinstance(field_type, Array) and isinstance(field_type.type, (Structure, Array)) - ): - blocks.append(self.gen_read_block(read_size, cur_block)) - - struct_read = "s = stream.tell()\n" - if isinstance(field_type, Array): - num = field_type.count - - if isinstance(num, Expression): - num = f"max(0, Expression(self.cstruct, '{num.expression}').evaluate(r))" - - struct_read += dedent( - f""" - r['{field.name}'] = [] - for _ in range({num}): - r['{field.name}'].append(self.lookup['{field.name}'].type.type._read(stream, context=r)) - sizes['{field.name}'] = stream.tell() - s - """ - ) - elif isinstance(field_type, Structure) and field_type.anonymous: - struct_read += dedent( - f""" - v = self.lookup["{field.name}"].type._read(stream, context=r) - r.update(v._values) - sizes.update(v._sizes) - """ - ) - else: - struct_read += dedent( - f""" - r['{field.name}'] = self.lookup['{field.name}'].type._read(stream, context=r) - sizes['{field.name}'] = stream.tell() - s - """ - ) - - blocks.append(struct_read) - read_size = 0 - cur_block = [] - continue + def align_to_field(field: Field) -> Iterator[str]: + nonlocal current_offset - if field.bits: - blocks.append(self.gen_read_block(read_size, cur_block)) - if isinstance(field_type, Enum): - bitfield_read = dedent( - f""" - r['{field.name}'] = self.cstruct.{field.type.name}( - bitreader.read(self.cstruct.{field.type.type.name}, {field.bits}) - ) - """ - ) - else: - bitfield_read = f"r['{field.name}'] = bitreader.read(self.cstruct.{field.type.name}, {field.bits})" - blocks.append(bitfield_read) + if field.offset is not None and field.offset != current_offset: + # If a field has a set offset and it's not the same as the current tracked offset, seek to it + yield f"stream.seek({field.offset})" + current_offset = field.offset - read_size = 0 - cur_block = [] - prev_was_bits = True - continue + if self.align and field.offset is None: + yield f"stream.seek(-stream.tell() & ({field.alignment} - 1), {io.SEEK_CUR})" + + for field in self.fields: + field_type = self.cs.resolve(field.type) - if prev_was_bits: - blocks.append("bitreader.reset()") + if not issubclass(field_type, SUPPORTED_TYPES): + raise TypeError(f"Unsupported type for compiler: {field_type}") + + if prev_was_bits and not field.bits: + yield "bit_reader.reset()" prev_was_bits = False + bits_remaining = 0 try: - count = len(field_type) - read_size += count - cur_block.append(field) + size = len(field_type) + is_dynamic = False except TypeError: - if cur_block: - blocks.append(self.gen_read_block(read_size, cur_block)) - - blocks.append(self.gen_dynamic_block(field)) - read_size = 0 - cur_block = [] - - if len(cur_block): - blocks.append(self.gen_read_block(read_size, cur_block)) + size = None + is_dynamic = True + + # Sub structure + if issubclass(field_type, Structure): + yield from flush() + yield from align_to_field(field) + yield from self._generate_structure(field) + + # Array of structures and multi-dimensional arrays + elif issubclass(field_type, (Array, CharArray, WcharArray)) and ( + issubclass(field_type.type, Structure) or isinstance(field_type.type, ArrayMetaType) or is_dynamic + ): + yield from flush() + yield from align_to_field(field) + yield from self._generate_array(field) - read_code = "\n\n".join(blocks) - read_code = "\n".join([" " * 2 + line for line in read_code.split("\n")]) + # Bit fields + elif field.bits: + if not prev_was_bits: + prev_bits_type = field.type + prev_was_bits = True - classes.append(self.COMPILE_TEMPLATE.format(name=name, read_code=read_code)) - return "\n\n".join(classes) + if bits_remaining == 0 or prev_bits_type != field.type: + bits_remaining = (size * 8) - field.bits + bits_rollover = True - def gen_read_block(self, size: int, block: List[str]) -> str: - template = dedent( - f""" - buf = stream.read({size}) - if len(buf) != {size}: raise EOFError() - data = struct.unpack(self.cstruct.endian + '{{}}', buf) - {{}} - """ - ) + yield from flush() + yield from align_to_field(field) + yield from self._generate_bits(field) - read_code = [] - fmt = [] + # Everything else - basic and composite types (and arrays of them) + else: + current_block.append(field) + + if current_offset is not None and size is not None: + if not field.bits or (field.bits and bits_rollover): + current_offset += size + bits_rollover = False + + yield from flush() + + if self.align: + yield f"stream.seek(-stream.tell() & (cls.alignment - 1), {io.SEEK_CUR})" + + def _generate_structure(self, field: Field) -> Iterator[str]: + template = f""" + _s = stream.tell() + r["{field.name}"] = {self._map_field(field)}._read(stream, context=r) + s["{field.name}"] = stream.tell() - _s + """ + + yield dedent(template) + + def _generate_array(self, field: Field) -> Iterator[str]: + template = f""" + _s = stream.tell() + r["{field.name}"] = {self._map_field(field)}._read(stream, context=r) + s["{field.name}"] = stream.tell() - _s + """ + + yield dedent(template) + + def _generate_bits(self, field: Field) -> Iterator[str]: + lookup = self._map_field(field) + read_type = "_t" + field_type = field.type + if issubclass(field_type, (Enum, Flag)): + read_type += ".type" + field_type = field_type.type - cur_type = None - cur_count = 0 + if issubclass(field_type, Char): + field_type = field_type.cs.uint8 + lookup = "cls.cs.uint8" - buf_offset = 0 - data_offset = 0 + template = f""" + _t = {lookup} + r["{field.name}"] = type.__call__(_t, bit_reader.read({read_type}, {field.bits})) + """ - for field in block: - field_type = self.cstruct.resolve(field.type) - read_type = field_type + yield dedent(template) - count = 1 - data_count = 1 + def _generate_packed(self, fields: list[Field]) -> Iterator[str]: + info = list(_generate_struct_info(self.cs, fields, self.align)) + reads = [] - if isinstance(read_type, (Enum, Flag)): - read_type = read_type.type - elif isinstance(read_type, Pointer): - read_type = self.cstruct.pointer + size = 0 + slice_index = 0 + for field, count, _ in info: + if field is None: + # Padding + size += count + continue - if isinstance(field_type, Array): - count = read_type.count - data_count = count - read_type = read_type.type + field_type = self.cs.resolve(field.type) + read_type = _get_read_type(self.cs, field_type) - if isinstance(read_type, (Enum, Flag)): - read_type = read_type.type - elif isinstance(read_type, Pointer): - read_type = self.cstruct.pointer + if issubclass(field_type, (Array, CharArray, WcharArray)): + count = field_type.num_entries + read_type = _get_read_type(self.cs, field_type.type) - if isinstance(read_type, (CharType, WcharType, BytesInteger)): - read_slice = f"{buf_offset}:{buf_offset + + (count * read_type.size)}" + if issubclass(read_type, (Char, Wchar, Int)): + count *= read_type.size + getter = f"buf[{size}:{size + count}]" else: - read_slice = f"{data_offset}:{data_offset + count}" - elif isinstance(read_type, CharType): - read_slice = f"{buf_offset}:{buf_offset + 1}" - elif isinstance(read_type, (WcharType, BytesInteger)): - read_slice = f"{buf_offset}:{buf_offset + read_type.size}" + getter = f"data[{slice_index}:{slice_index + count}]" + slice_index += count + elif issubclass(read_type, (Char, Wchar, Int)): + getter = f"buf[{size}:{size + read_type.size}]" else: - read_slice = str(data_offset) + getter = f"data[{slice_index}]" + slice_index += 1 - if not cur_type: - if isinstance(read_type, PackedType): - cur_type = read_type.packchar + if issubclass(read_type, (Wchar, Int)): + # Types that parse bytes further down to their own type + parser_template = "{type}({getter})" + else: + # All other types can be simply intialized + parser_template = "type.__call__({type}, {getter})" + + # Create the final reading code + if issubclass(field_type, Array): + reads.append(f"_t = {self._map_field(field)}") + reads.append("_et = _t.type") + + if issubclass(field_type.type, Int): + reads.append(f"_b = {getter}") + item_parser = parser_template.format(type="_et", getter=f"_b[i:i + {field_type.type.size}]") + list_comp = f"[{item_parser} for i in range(0, {count}, {field_type.type.size})]" + elif issubclass(field_type.type, Pointer): + item_parser = "_et.__new__(_et, e, stream, r)" + list_comp = f"[{item_parser} for e in {getter}]" else: - cur_type = "x" + item_parser = parser_template.format(type="_et", getter="e") + list_comp = f"[{item_parser} for e in {getter}]" + + parser = f"type.__call__(_t, {list_comp})" + elif issubclass(field_type, CharArray): + parser = f"type.__call__({self._map_field(field)}, {getter})" + elif issubclass(field_type, Pointer): + reads.append(f"_pt = {self._map_field(field)}") + parser = f"_pt.__new__(_pt, {getter}, stream, r)" + else: + parser = parser_template.format(type=self._map_field(field), getter=getter) - if isinstance(read_type, (PackedType, CharType, WcharType, BytesInteger, Enum, Flag)): - char_count = count + reads.append(f'r["{field.name}"] = {parser}') + reads.append(f's["{field.name}"] = {field_type.size}') + reads.append("") # Generates a newline in the resulting code - if isinstance(read_type, (CharType, WcharType, BytesInteger)): - data_count = 0 - pack_char = "x" - char_count *= read_type.size - else: - pack_char = read_type.packchar - - if cur_type != pack_char: - fmt.append(f"{cur_count}{cur_type}") - cur_count = 0 - - cur_count += char_count - cur_type = pack_char - - if isinstance(read_type, BytesInteger): - getter = "BytesInteger.parse(buf[{slice}], {size}, {count}, {signed}, self.cstruct.endian){data_slice}" - - getter = getter.format( - slice=read_slice, - size=read_type.size, - count=count, - signed=read_type.signed, - data_slice="[0]" if count == 1 else "", - ) - elif isinstance(read_type, (CharType, WcharType)): - getter = f"buf[{read_slice}]" - - if isinstance(read_type, WcharType): - getter += ".decode('utf-16-le' if self.cstruct.endian == '<' else 'utf-16-be')" - else: - getter = f"data[{read_slice}]" + size += field_type.size - if isinstance(field_type, (Enum, Flag)): - enum_type = field_type.__class__.__name__ - getter = f"{enum_type}Instance(self.cstruct.{field_type.name}, {getter})" - elif isinstance(field_type, Array) and isinstance(field_type.type, (Enum, Flag)): - enum_type = field_type.type.__class__.__name__ - getter = f"[{enum_type}Instance(self.cstruct.{field_type.type.name}, d) for d in {getter}]" - elif isinstance(field_type, Pointer): - getter = f"PointerInstance(self.cstruct.{field_type.type.name}, stream, {getter}, r)" - elif isinstance(field_type, Array) and isinstance(field_type.type, Pointer): - getter = f"[PointerInstance(self.cstruct.{field_type.type.type.name}, stream, d, r) for d in {getter}]" - elif isinstance(field_type, Array) and isinstance(read_type, PackedType): - getter = f"list({getter})" + fmt = _optimize_struct_fmt(info) + if fmt == "x" or (len(fmt) == 2 and fmt[1] == "x"): + unpack = "" + else: + unpack = f'data = _struct(cls.cs.endian, "{fmt}").unpack(buf)\n' - read_code.append(f"r['{field.name}'] = {getter}") - read_code.append(f"sizes['{field.name}'] = {count * read_type.size}") + template = f""" + buf = stream.read({size}) + if len(buf) != {size}: raise EOFError() + {unpack} + """ - data_offset += data_count - buf_offset += count * read_type.size + yield dedent(template) + "\n".join(reads) - if cur_count: - fmt.append(f"{cur_count}{cur_type}") - return template.format("".join(fmt), "\n".join(read_code)) +def _generate_struct_info(cs: cstruct, fields: list[Field], align: bool = False) -> Iterator[tuple[Field, int, str]]: + if not fields: + return - def gen_dynamic_block(self, field: Field) -> str: - if not isinstance(field.type, Array): - raise TypeError(f"Only Array can be dynamic, got {field.type!r}") + current_offset = fields[0].offset + imaginary_offset = 0 + for field in fields: + # We moved -- probably due to alignment + if field.offset is not None and (drift := field.offset - current_offset) > 0: + yield None, drift, "x" + current_offset += drift - field_type = self.cstruct.resolve(field.type.type) - reader = None + if align and field.offset is None and (drift := -imaginary_offset & (field.alignment - 1)) > 0: + # Assume we started at a correctly aligned boundary + yield None, drift, "x" + imaginary_offset += drift - if isinstance(field_type, (Enum, Flag)): - field_type = field_type.type + count = 1 + read_type = _get_read_type(cs, field.type) + + # Drop voids + if issubclass(read_type, Void): + continue + + # Array of more complex types are handled elsewhere + if issubclass(read_type, (Array, CharArray, WcharArray)): + count = read_type.num_entries + read_type = _get_read_type(cs, read_type.type) + + # Take the pack char for Packed + if issubclass(read_type, Packed): + yield field, count, read_type.packchar + + # Other types are byte based + # We don't actually unpack anything here but slice directly out of the buffer + elif issubclass(read_type, (Char, Wchar, Int)): + yield field, count * read_type.size, "x" + + size = count * read_type.size + imaginary_offset += size + if current_offset is not None: + current_offset += size + + +def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]) -> str: + chars = [] + + current_count = 0 + current_char = None + + for _, count, char in info: + if current_char is None: + current_count = count + current_char = char + continue + + if char != current_char: + if current_count: + chars.append((current_count, current_char)) + current_count = count + current_char = char + else: + current_count += count + + if current_char is not None and current_count: + chars.append((current_count, current_char)) + + return "".join(f"{count if count > 1 else ''}{char}" for count, char in chars) + + +def _get_read_type(cs: cstruct, type_: MetaType | str) -> MetaType: + type_ = cs.resolve(type_) + + if issubclass(type_, (Enum, Flag)): + type_ = type_.type + + if issubclass(type_, Pointer): + type_ = cs.pointer - if not field.type.count: # Null terminated - if isinstance(field_type, PackedType): - reader = dedent( - f""" - t = [] - while True: - d = stream.read({field_type.size}) - if len(d) != {field_type.size}: raise EOFError() - v = struct.unpack(self.cstruct.endian + '{field_type.packchar}', d)[0] - if v == 0: break - t.append(v) - """ - ) - - elif isinstance(field_type, (CharType, WcharType)): - null = "\\x00" * field_type.size - reader = dedent( - f""" - t = [] - while True: - c = stream.read({field_type.size}) - if len(c) != {field_type.size}: raise EOFError() - if c == b'{null}': break - t.append(c) - t = b''.join(t)""" # It's important there's no newline here because of the optional decode - ) - - if isinstance(field_type, WcharType): - reader += ".decode('utf-16-le' if self.cstruct.endian == '<' else 'utf-16-be')" - elif isinstance(field_type, BytesInteger): - reader = dedent( - f""" - t = [] - while True: - d = stream.read({field_type.size}) - if len(d) != {field_type.size}: raise EOFError() - v = BytesInteger.parse(d, {field_type.size}, 1, {field_type.signed}, self.cstruct.endian) - if v == 0: break - t.append(v) - """ - ) - - if isinstance(field_type, (Enum, Flag)): - enum_type = field_type.__class__.__name__ - reader += f"\nt = [{enum_type}Instance(self.cstruct.{field_type.name}, d) for d in t]" - - if not reader: - raise TypeError(f"Couldn't compile a reader for array {field!r}, {field_type!r}.") - - return f"s = stream.tell()\n{reader}\nr['{field.name}'] = t\nsizes['{field.name}'] = stream.tell() - s" - - expr_read = dedent( - f""" - dynsize = max(0, Expression(self.cstruct, "{field.type.count.expression}").evaluate(r)) - buf = stream.read(dynsize * {field_type.size}) - if len(buf) != dynsize * {field_type.size}: raise EOFError() - r['{field.name}'] = {{reader}} - sizes['{field.name}'] = dynsize * {field_type.size} - """ - ) - - if isinstance(field_type, PackedType): - reader = f"list(struct.unpack(self.cstruct.endian + f'{{dynsize}}{field_type.packchar}', buf))" - elif isinstance(field_type, (CharType, WcharType)): - reader = "buf" - if isinstance(field_type, WcharType): - reader += ".decode('utf-16-le' if self.cstruct.endian == '<' else 'utf-16-be')" - elif isinstance(field_type, BytesInteger): - reader = f"BytesInteger.parse(buf, {field_type.size}, dynsize, {field_type.signed}, self.cstruct.endian)" - - if isinstance(field_type, (Enum, Flag)): - enum_type = field_type.__class__.__name__ - reader += f"[{enum_type}Instance(self.cstruct.{field_type.name}, d) for d in {reader}]" - - return expr_read.format(reader=reader, size=None) + return cs.resolve(type_) diff --git a/dissect/cstruct/cstruct.py b/dissect/cstruct/cstruct.py index cf2aaff..81369ab 100644 --- a/dissect/cstruct/cstruct.py +++ b/dissect/cstruct/cstruct.py @@ -1,22 +1,30 @@ -from __future__ import print_function +from __future__ import annotations import ctypes as _ctypes +import struct import sys -from typing import Any, BinaryIO, Optional +import types +from typing import Any, BinaryIO, Iterator from dissect.cstruct.exceptions import ResolveError +from dissect.cstruct.expression import Expression from dissect.cstruct.parser import CStyleParser, TokenParser from dissect.cstruct.types import ( LEB128, - Array, + ArrayMetaType, BaseType, - BytesInteger, - CharType, - PackedType, + Char, + Enum, + Field, + Flag, + Int, + MetaType, + Packed, Pointer, Structure, - VoidType, - WcharType, + Union, + Void, + Wchar, ) @@ -31,7 +39,7 @@ class cstruct: DEF_CSTYLE = 1 DEF_LEGACY = 2 - def __init__(self, endian: str = "<", pointer: Optional[str] = None): + def __init__(self, endian: str = "<", pointer: str | None = None): self.endian = endian self.consts = {} @@ -39,31 +47,31 @@ def __init__(self, endian: str = "<", pointer: Optional[str] = None): # fmt: off self.typedefs = { # Internal types - "int8": PackedType(self, "int8", 1, "b"), - "uint8": PackedType(self, "uint8", 1, "B"), - "int16": PackedType(self, "int16", 2, "h"), - "uint16": PackedType(self, "uint16", 2, "H"), - "int32": PackedType(self, "int32", 4, "i"), - "uint32": PackedType(self, "uint32", 4, "I"), - "int64": PackedType(self, "int64", 8, "q"), - "uint64": PackedType(self, "uint64", 8, "Q"), - "float16": PackedType(self, "float16", 2, "e"), - "float": PackedType(self, "float", 4, "f"), - "double": PackedType(self, "double", 8, "d"), - "char": CharType(self), - "wchar": WcharType(self), - - "int24": BytesInteger(self, "int24", 3, True, alignment=4), - "uint24": BytesInteger(self, "uint24", 3, False, alignment=4), - "int48": BytesInteger(self, "int48", 6, True, alignment=8), - "uint48": BytesInteger(self, "uint48", 6, False, alignment=8), - "int128": BytesInteger(self, "int128", 16, True, alignment=16), - "uint128": BytesInteger(self, "uint128", 16, False, alignment=16), - - "uleb128": LEB128(self, 'uleb128', None, False), - "ileb128": LEB128(self, 'ileb128', None, True), - - "void": VoidType(), + "int8": self._make_packed_type("int8", "b", int), + "uint8": self._make_packed_type("uint8", "B", int), + "int16": self._make_packed_type("int16", "h", int), + "uint16": self._make_packed_type("uint16", "H", int), + "int32": self._make_packed_type("int32", "i", int), + "uint32": self._make_packed_type("uint32", "I", int), + "int64": self._make_packed_type("int64", "q", int), + "uint64": self._make_packed_type("uint64", "Q", int), + "float16": self._make_packed_type("float16", "e", float), + "float": self._make_packed_type("float", "f", float), + "double": self._make_packed_type("double", "d", float), + "char": self._make_type("char", (Char,), 1), + "wchar": self._make_type("wchar", (Wchar,), 2), + + "int24": self._make_int_type("int24", 3, True, alignment=4), + "uint24": self._make_int_type("uint24", 3, False, alignment=4), + "int48": self._make_int_type("int48", 6, True, alignment=8), + "uint48": self._make_int_type("int48", 6, False, alignment=8), + "int128": self._make_int_type("int128", 16, True, alignment=16), + "uint128": self._make_int_type("uint128", 16, False, alignment=16), + + "uleb128": self._make_type("uleb128", (LEB128,), None, attrs={"signed": False}), + "ileb128": self._make_type("ileb128", (LEB128,), None, attrs={"signed": True}), + + "void": self._make_type("void", (Void,), 0), # Common C types not covered by internal types "signed char": "int8", @@ -121,6 +129,13 @@ def __init__(self, endian: str = "<", pointer: Optional[str] = None): "__int64": "int64", "__int128": "int128", + "unsigned __int8": "uint8", + "unsigned __int16": "uint16", + "unsigned __int32": "uint32", + "unsigned __int64": "uint64", + "unsigned __int128": "uint128", + "unsigned __int128": "uint128", + "wchar_t": "wchar", # GNU C types @@ -135,7 +150,6 @@ def __init__(self, endian: str = "<", pointer: Optional[str] = None): "uint32_t": "uint32", "uint64_t": "uint64", "uint128_t": "uint128", - "unsigned __int128": "uint128", # IDA types "_BYTE": "uint8", @@ -167,29 +181,30 @@ def __init__(self, endian: str = "<", pointer: Optional[str] = None): def __getattr__(self, attr: str) -> Any: try: - return self.resolve(self.typedefs[attr]) + return self.consts[attr] except KeyError: pass try: - return self.consts[attr] + return self.resolve(self.typedefs[attr]) except KeyError: pass raise AttributeError(f"Invalid attribute: {attr}") def _next_anonymous(self) -> str: - name = f"anonymous_{self._anonymous_count}" + name = f"__anonymous_{self._anonymous_count}__" self._anonymous_count += 1 return name - def addtype(self, name: str, type_: BaseType, replace: bool = False) -> None: + def add_type(self, name: str, type_: MetaType | str, replace: bool = False) -> None: """Add a type or type reference. + Only use this method when creating type aliases or adding already bound types. + Args: name: Name of the type to be added. - type_: The type to be added. Can be a str reference to another type - or a compatible type class. + type_: The type to be added. Can be a str reference to another type or a compatible type class. Raises: ValueError: If the type already exists. @@ -199,7 +214,26 @@ def addtype(self, name: str, type_: BaseType, replace: bool = False) -> None: self.typedefs[name] = type_ - def load(self, definition: str, deftype: int = None, **kwargs) -> "cstruct": + addtype = add_type + + def add_custom_type( + self, name: str, type_: MetaType, size: int | None = None, alignment: int | None = None, **kwargs + ) -> None: + """Add a custom type. + + Use this method to add custom types to this cstruct instance. This is largely a convenience method for + the internal :func:`_make_type` method, which binds a class to this cstruct instance. + + Args: + name: Name of the type to be added. + type_: The type to be added. + size: The size of the type. + alignment: The alignment of the type. + **kwargs: Additional attributes to add to the type. + """ + self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs)) + + def load(self, definition: str, deftype: int = None, **kwargs) -> cstruct: """Parse structures from the given definitions using the given definition type. Definitions can be parsed using different parsers. Currently, there's @@ -250,7 +284,7 @@ def read(self, name: str, stream: BinaryIO) -> Any: """ return self.resolve(name).read(stream) - def resolve(self, name: str) -> BaseType: + def resolve(self, name: str) -> MetaType: """Resolve a type name to get the actual type object. Types can be referenced using different names. When we want @@ -280,11 +314,103 @@ def resolve(self, name: str) -> BaseType: raise ResolveError(f"Recursion limit exceeded while resolving type {name}") + def _make_type( + self, + name: str, + bases: Iterator[object], + size: int | None, + *, + alignment: int | None = None, + attrs: dict[str, Any] = None, + ) -> type[BaseType]: + """Create a new type class bound to this cstruct instance. + + All types are created using this method. This method automatically binds the type to this cstruct instance. + """ + attrs = attrs or {} + attrs.update( + { + "cs": self, + "size": size, + "dynamic": size is None, + "alignment": alignment or size, + } + ) + return types.new_class(name, bases, {}, lambda ns: ns.update(attrs)) + + def _make_array(self, type_: MetaType, num_entries: int | Expression | None) -> ArrayMetaType: + null_terminated = num_entries is None + dynamic = isinstance(num_entries, Expression) or type_.dynamic + size = None if (null_terminated or dynamic) else (num_entries * type_.size) + name = f"{type_.__name__}[]" if null_terminated else f"{type_.__name__}[{num_entries}]" + + bases = (type_.ArrayType,) + + attrs = { + "type": type_, + "num_entries": num_entries, + "null_terminated": null_terminated, + } + + return self._make_type(name, bases, size, alignment=type_.alignment, attrs=attrs) + + def _make_int_type(self, name: str, size: int, signed: bool, *, alignment: int = None) -> type[Int]: + return self._make_type(name, (Int,), size, alignment=alignment, attrs={"signed": signed}) + + def _make_packed_type(self, name: str, packchar: str, base: type, *, alignment: int = None) -> type[Packed]: + return self._make_type( + name, + (base, Packed), + struct.calcsize(packchar), + alignment=alignment, + attrs={"packchar": packchar}, + ) + + def _make_enum(self, name: str, type_: MetaType, values: dict[str, int]) -> type[Enum]: + return Enum(self, name, type_, values) + + def _make_flag(self, name: str, type_: MetaType, values: dict[str, int]) -> type[Flag]: + return Flag(self, name, type_, values) + + def _make_pointer(self, target: MetaType) -> type[Pointer]: + return self._make_type( + f"{target.__name__}*", + (Pointer,), + self.pointer.size, + alignment=self.pointer.alignment, + attrs={"type": target}, + ) + + def _make_struct( + self, + name: str, + fields: list[Field], + *, + align: bool = False, + anonymous: bool = False, + base: type[Structure] = Structure, + ) -> type[Structure]: + return self._make_type( + name, + (base,), + None, + attrs={ + "fields": fields, + "__align__": align, + "__anonymous__": anonymous, + }, + ) + + def _make_union( + self, name: str, fields: list[Field], *, align: bool = False, anonymous: bool = False + ) -> type[Structure]: + return self._make_struct(name, fields, align=align, anonymous=anonymous, base=Union) + def ctypes(structure: Structure) -> _ctypes.Structure: """Create ctypes structures from cstruct structures.""" fields = [] - for field in structure.fields: + for field in structure.__fields__: t = ctypes_type(field.type) fields.append((field.name, t)) @@ -292,25 +418,38 @@ def ctypes(structure: Structure) -> _ctypes.Structure: return tt -def ctypes_type(type_: BaseType) -> Any: +def ctypes_type(type_: MetaType) -> Any: mapping = { - "I": _ctypes.c_ulong, - "i": _ctypes.c_long, "b": _ctypes.c_int8, + "B": _ctypes.c_uint8, + "h": _ctypes.c_int16, + "H": _ctypes.c_uint16, + "i": _ctypes.c_int32, + "I": _ctypes.c_uint32, + "q": _ctypes.c_int64, + "Q": _ctypes.c_uint64, + "f": _ctypes.c_float, + "d": _ctypes.c_double, } - if isinstance(type_, PackedType): + if issubclass(type_, Packed) and type_.packchar in mapping: return mapping[type_.packchar] - if isinstance(type_, CharType): + if issubclass(type_, Char): return _ctypes.c_char - if isinstance(type_, Array): + if issubclass(type_, Wchar): + return _ctypes.c_wchar + + if isinstance(type_, ArrayMetaType): subtype = ctypes_type(type_.type) - return subtype * type_.count + return subtype * type_.num_entries - if isinstance(type_, Pointer): + if issubclass(type_, Pointer): subtype = ctypes_type(type_.type) return _ctypes.POINTER(subtype) + if issubclass(type_, Structure): + return ctypes(type_) + raise NotImplementedError(f"Type not implemented: {type_.__class__.__name__}") diff --git a/dissect/cstruct/expression.py b/dissect/cstruct/expression.py index 4727f26..774e376 100644 --- a/dissect/cstruct/expression.py +++ b/dissect/cstruct/expression.py @@ -1,7 +1,7 @@ from __future__ import annotations import string -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError @@ -41,8 +41,8 @@ def operator(self, token: str) -> bool: def match( self, - func: Optional[Callable[[str], bool]] = None, - expected: Optional[str | set[str]] = None, + func: Callable[[str], bool] | None = None, + expected: str | None = None, consume: bool = True, append: bool = True, ) -> bool: @@ -154,7 +154,7 @@ def tokenize(self) -> list[str]: class Expression: """Expression parser for calculations in definitions.""" - operators = { + binary_operators = { "|": lambda a, b: a | b, "^": lambda a, b: a ^ b, "&": lambda a, b: a & b, @@ -165,6 +165,9 @@ class Expression: "*": lambda a, b: a * b, "/": lambda a, b: a // b, "%": lambda a, b: a % b, + } + + unary_operators = { "u": lambda a: -a, "~": lambda a: ~a, } @@ -206,26 +209,26 @@ def evaluate_exp(self) -> None: raise ExpressionParserError("Invalid expression: not enough operands") right = self.queue.pop(-1) - if operator in ("u", "~"): - res = self.operators[operator](right) + if operator in self.unary_operators: + res = self.unary_operators[operator](right) else: if len(self.queue) < 1: raise ExpressionParserError("Invalid expression: not enough operands") left = self.queue.pop(-1) - res = self.operators[operator](left, right) + res = self.binary_operators[operator](left, right) self.queue.append(res) def is_number(self, token: str) -> bool: return token.isnumeric() or (len(token) > 2 and token[0] == "0" and token[1] in ("x", "X", "b", "B", "o", "O")) - def evaluate(self, context: Optional[dict[str, int]] = None) -> int: + def evaluate(self, context: dict[str, int] | None = None) -> int: """Evaluates an expression using a Shunting-Yard implementation.""" self.stack = [] self.queue = [] - operators = set(self.operators.keys()) + operators = set(self.binary_operators.keys()) | set(self.unary_operators.keys()) context = context or {} tmp_expression = self.tokens @@ -249,9 +252,7 @@ def evaluate(self, context: Optional[dict[str, int]] = None) -> int: self.queue.append(int(context[current_token])) elif current_token in self.cstruct.consts: self.queue.append(int(self.cstruct.consts[current_token])) - elif current_token == "u": - self.stack.append(current_token) - elif current_token == "~": + elif current_token in self.unary_operators: self.stack.append(current_token) elif current_token == "sizeof": if len(tmp_expression) < i + 3 or (tmp_expression[i + 1] != "(" or tmp_expression[i + 3] != ")"): diff --git a/dissect/cstruct/parser.py b/dissect/cstruct/parser.py index a19deae..3685444 100644 --- a/dissect/cstruct/parser.py +++ b/dissect/cstruct/parser.py @@ -2,21 +2,16 @@ import ast import re -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING -from dissect.cstruct.compiler import Compiler -from dissect.cstruct.exceptions import ParserError -from dissect.cstruct.expression import Expression -from dissect.cstruct.types import ( - Array, - BaseType, - Enum, - Field, - Flag, - Pointer, - Structure, - Union, +from dissect.cstruct import compiler +from dissect.cstruct.exceptions import ( + ExpressionParserError, + ExpressionTokenizerError, + ParserError, ) +from dissect.cstruct.expression import Expression +from dissect.cstruct.types import ArrayMetaType, Field, MetaType if TYPE_CHECKING: from dissect.cstruct import cstruct @@ -51,7 +46,7 @@ class TokenParser(Parser): def __init__(self, cs: cstruct, compiled: bool = True, align: bool = False): super().__init__(cs) - self.compiler = Compiler(self.cstruct) if compiled else None + self.compiled = compiled self.align = align self.TOK = self._tokencollection() @@ -68,7 +63,7 @@ def _tokencollection() -> TokenCollection: "ENUM", ) TOK.add(r"(?<=})\s*(?P(?:[a-zA-Z0-9_]+\s*,\s*)+[a-zA-Z0-9_]+)\s*(?=;)", "DEFS") - TOK.add(r"(?P\*?[a-zA-Z0-9_]+)(?:\s*:\s*(?P\d+))?(?:\[(?P[^;\n]*)\])?\s*(?=;)", "NAME") + TOK.add(r"(?P\**?\s*[a-zA-Z0-9_]+)(?:\s*:\s*(?P\d+))?(?:\[(?P[^;\n]*)\])?\s*(?=;)", "NAME") TOK.add(r"[a-zA-Z_][a-zA-Z0-9_]*", "IDENTIFIER") TOK.add(r"[{}]", "BLOCK") TOK.add(r"\$(?P[^\s]+) = (?P{[^}]+})\w*[\r\n]+", "LOOKUP") @@ -95,10 +90,11 @@ def _constant(self, tokens: TokenConsumer) -> None: except (ValueError, SyntaxError): pass - try: - value = Expression(self.cstruct, value).evaluate() - except Exception: - pass + if isinstance(value, str): + try: + value = Expression(self.cstruct, value).evaluate() + except (ExpressionParserError, ExpressionTokenizerError): + pass self.cstruct.consts[match["name"]] = value @@ -118,7 +114,7 @@ def _enum(self, tokens: TokenConsumer) -> None: values = {} for line in d["values"].splitlines(): for v in line.split(","): - key, sep, val = v.partition("=") + key, _, val = v.partition("=") key = key.strip() val = val.strip() if not key: @@ -139,17 +135,13 @@ def _enum(self, tokens: TokenConsumer) -> None: if not d["type"]: d["type"] = "uint32" - enumcls = Enum - if enumtype == "flag": - enumcls = Flag - - enum = enumcls(self.cstruct, d["name"], self.cstruct.resolve(d["type"]), values) + factory = self.cstruct._make_flag if enumtype == "flag" else self.cstruct._make_enum - if not enum.name: - for name, value in enum.values.items(): - self.cstruct.consts[name] = enum(value) + enum = factory(d["name"] or "", self.cstruct.resolve(d["type"]), values) + if not enum.__name__: + self.cstruct.consts.update(enum.__members__) else: - self.cstruct.addtype(enum.name, enum) + self.cstruct.add_type(enum.__name__, enum) tokens.eol() @@ -170,17 +162,33 @@ def _typedef(self, tokens: TokenConsumer) -> None: type_, name, bits = self._parse_field_type(type_, name) if bits is not None: raise ParserError(f"line {self._lineno(tokens.previous)}: typedefs cannot have bitfields") - self.cstruct.addtype(name, type_) + self.cstruct.add_type(name, type_) def _struct(self, tokens: TokenConsumer, register: bool = False) -> None: stype = tokens.consume() + factory = self.cstruct._make_union if stype.value.startswith("union") else self.cstruct._make_struct + + st = None names = [] + registered = False + if tokens.next == self.TOK.IDENTIFIER: ident = tokens.consume() - names.append(ident.value) + if register: + # Pre-register an empty struct for self-referencing + # We update this instance later with the fields + st = factory(ident.value, [], align=self.align) + if self.compiled and "nocompile" not in tokens.flags: + st = compiler.compile(st) + self.cstruct.add_type(ident.value, st) + registered = True + else: + names.append(ident.value) if tokens.next == self.TOK.NAME: + # As part of a struct field + # struct type_name field_name; if not len(names): raise ParserError(f"line {self._lineno(tokens.next)}: unexpected anonymous struct") return self.cstruct.resolve(names[0]) @@ -198,30 +206,32 @@ def _struct(self, tokens: TokenConsumer, register: bool = False) -> None: field = self._parse_field(tokens) fields.append(field) + # All names from here on are from typedef's # Parsing names consumes the EOL token names.extend(self._names(tokens)) name = names[0] if names else None - if stype.value.startswith("union"): - class_ = Union - else: - class_ = Structure - is_anonymous = False - if not name: - is_anonymous = True - name = self.cstruct._next_anonymous() + if st is None: + is_anonymous = False + if not name: + is_anonymous = True + name = self.cstruct._next_anonymous() - st = class_(self.cstruct, name, fields, align=self.align, anonymous=is_anonymous) - if self.compiler and "nocompile" not in tokens.flags: - st = self.compiler.compile(st) + st = factory(name, fields, align=self.align, anonymous=is_anonymous) + if self.compiled and "nocompile" not in tokens.flags: + st = compiler.compile(st) + else: + st.__fields__.extend(fields) + st.commit() # This is pretty dirty if register: - if not names: + if not names and not registered: raise ParserError(f"line {self._lineno(stype)}: struct has no name") for name in names: - self.cstruct.addtype(name, st) + self.cstruct.add_type(name, st) + tokens.reset_flags() return st @@ -242,7 +252,7 @@ def _parse_field(self, tokens: TokenConsumer) -> Field: elif tokens.next == self.TOK.STRUCT: type_ = self._struct(tokens) if tokens.next != self.TOK.NAME: - return Field(type_.name, type_) + return Field(type_.__name__, type_) if tokens.next != self.TOK.NAME: raise ParserError(f"line {self._lineno(tokens.next)}: expected name") @@ -253,7 +263,7 @@ def _parse_field(self, tokens: TokenConsumer) -> Field: tokens.eol() return Field(name.strip(), type_, bits) - def _parse_field_type(self, type_: BaseType, name: str) -> tuple[BaseType, str, Optional[int]]: + def _parse_field_type(self, type_: MetaType, name: str) -> tuple[MetaType, str, int | None]: pattern = self.TOK.patterns[self.TOK.NAME] # Dirty trick because the regex expects a ; but we don't want it to be part of the value d = pattern.match(name + ";").groupdict() @@ -261,9 +271,9 @@ def _parse_field_type(self, type_: BaseType, name: str) -> tuple[BaseType, str, name = d["name"] count_expression = d["count"] - if name.startswith("*"): + while name.startswith("*"): name = name[1:] - type_ = Pointer(self.cstruct, type_) + type_ = self.cstruct._make_pointer(type_) if count_expression is not None: # Poor mans multi-dimensional array by abusing the eager regex match of count @@ -282,14 +292,14 @@ def _parse_field_type(self, type_: BaseType, name: str) -> tuple[BaseType, str, except Exception: pass - if isinstance(type_, Array) and count is None: + if isinstance(type_, ArrayMetaType) and count is None: raise ParserError("Depth required for multi-dimensional array") - type_ = Array(self.cstruct, type_, count) + type_ = self.cstruct._make_array(type_, count) - return type_, name, int(d["bits"]) if d["bits"] else None + return type_, name.strip(), int(d["bits"]) if d["bits"] else None - def _names(self, tokens: TokenConsumer) -> List[str]: + def _names(self, tokens: TokenConsumer) -> list[str]: names = [] while True: if tokens.next == self.TOK.EOL: @@ -301,7 +311,7 @@ def _names(self, tokens: TokenConsumer) -> List[str]: ntoken = tokens.consume() if ntoken == self.TOK.NAME: - names.append(ntoken.value) + names.append(ntoken.value.strip()) elif ntoken == self.TOK.DEFS: for name in ntoken.value.strip().split(","): names.append(name.strip()) @@ -410,9 +420,9 @@ def _enums(self, data: str) -> None: values = {} for line in d["values"].split("\n"): - line, sep, comment = line.partition("//") + line, _, _ = line.partition("//") for v in line.split(","): - key, sep, val = v.partition("=") + key, _, val = v.partition("=") key = key.strip() val = val.strip() if not key: @@ -433,15 +443,14 @@ def _enums(self, data: str) -> None: if not d["type"]: d["type"] = "uint32" - enumcls = Enum + factory = self.cstruct._make_enum if enumtype == "flag": - enumcls = Flag + factory = self.cstruct._make_flag - enum = enumcls(self.cstruct, d["name"], self.cstruct.resolve(d["type"]), values) - self.cstruct.addtype(enum.name, enum) + enum = factory(d["name"], self.cstruct.resolve(d["type"]), values) + self.cstruct.add_type(enum.__name__, enum) def _structs(self, data: str) -> None: - compiler = Compiler(self.cstruct) r = re.finditer( r"(#(?P(?:compile))\s+)?" r"((?Ptypedef)\s+)?" @@ -464,7 +473,7 @@ def _structs(self, data: str) -> None: if d["type"] == "struct": data = self._parse_fields(d["fields"][1:-1].strip()) - st = Structure(self.cstruct, name, data) + st = self.cstruct._make_struct(name, data) if d["flags"] == "compile" or self.compiled: st = compiler.compile(st) elif d["typedef"] == "typedef": @@ -473,12 +482,12 @@ def _structs(self, data: str) -> None: continue if d["name"]: - self.cstruct.addtype(d["name"], st) + self.cstruct.add_type(d["name"], st) if d["defs"]: for td in d["defs"].strip().split(","): td = td.strip() - self.cstruct.addtype(td, st) + self.cstruct.add_type(td, st) def _parse_fields(self, data: str) -> None: fields = re.finditer( @@ -508,18 +517,18 @@ def _parse_fields(self, data: str) -> None: except Exception: pass - type_ = Array(self.cstruct, type_, count) + type_ = self.cstruct._make_array(type_, count) if d["name"].startswith("*"): d["name"] = d["name"][1:] - type_ = Pointer(self.cstruct, type_) + type_ = self.cstruct._make_pointer(type_) field = Field(d["name"], type_, int(d["bits"]) if d["bits"] else None) result.append(field) return result - def _lookups(self, data: str, consts: Dict[str, int]) -> None: + def _lookups(self, data: str, consts: dict[str, int]) -> None: r = re.finditer(r"\$(?P[^\s]+) = ({[^}]+})\w*\n", data) for t in r: @@ -556,9 +565,9 @@ def __repr__(self): class TokenCollection: def __init__(self): - self.tokens: List[Token] = [] - self.lookup: Dict[str, str] = {} - self.patterns: Dict[str, re.Pattern] = {} + self.tokens: list[Token] = [] + self.lookup: dict[str, str] = {} + self.patterns: dict[str, re.Pattern] = {} def __getattr__(self, attr: str): try: @@ -578,12 +587,12 @@ def add(self, regex: str, name: str) -> None: class TokenConsumer: - def __init__(self, tokens: List[Token]): + def __init__(self, tokens: list[Token]): self.tokens = tokens self.flags = [] self.previous = None - def __contains__(self, token) -> bool: + def __contains__(self, token: Token) -> bool: return token in self.tokens def __len__(self) -> int: diff --git a/dissect/cstruct/types/__init__.py b/dissect/cstruct/types/__init__.py index 773f7c9..cf58fa1 100644 --- a/dissect/cstruct/types/__init__.py +++ b/dissect/cstruct/types/__init__.py @@ -1,34 +1,32 @@ -from dissect.cstruct.types.base import Array, BaseType, RawType -from dissect.cstruct.types.bytesinteger import BytesInteger -from dissect.cstruct.types.chartype import CharType -from dissect.cstruct.types.enum import Enum, EnumInstance -from dissect.cstruct.types.flag import Flag, FlagInstance -from dissect.cstruct.types.instance import Instance +from dissect.cstruct.types.base import Array, ArrayMetaType, BaseType, MetaType +from dissect.cstruct.types.char import Char, CharArray +from dissect.cstruct.types.enum import Enum +from dissect.cstruct.types.flag import Flag +from dissect.cstruct.types.int import Int from dissect.cstruct.types.leb128 import LEB128 -from dissect.cstruct.types.packedtype import PackedType -from dissect.cstruct.types.pointer import Pointer, PointerInstance +from dissect.cstruct.types.packed import Packed +from dissect.cstruct.types.pointer import Pointer from dissect.cstruct.types.structure import Field, Structure, Union -from dissect.cstruct.types.voidtype import VoidType -from dissect.cstruct.types.wchartype import WcharType +from dissect.cstruct.types.void import Void +from dissect.cstruct.types.wchar import Wchar, WcharArray __all__ = [ "Array", + "ArrayMetaType", "BaseType", - "BytesInteger", - "CharType", + "Char", + "CharArray", "Enum", - "EnumInstance", "Field", "Flag", - "FlagInstance", - "Instance", + "Int", "LEB128", - "PackedType", + "MetaType", + "Packed", "Pointer", - "PointerInstance", - "RawType", "Structure", "Union", - "VoidType", - "WcharType", + "Void", + "Wchar", + "WcharArray", ] diff --git a/dissect/cstruct/types/base.py b/dissect/cstruct/types/base.py index 9ad51bf..5ea6fdb 100644 --- a/dissect/cstruct/types/base.py +++ b/dissect/cstruct/types/base.py @@ -1,209 +1,279 @@ from __future__ import annotations +import functools from io import BytesIO -from typing import TYPE_CHECKING, Any, BinaryIO, List +from typing import TYPE_CHECKING, Any, BinaryIO, Callable from dissect.cstruct.exceptions import ArraySizeError from dissect.cstruct.expression import Expression if TYPE_CHECKING: - from dissect.cstruct import cstruct + from dissect.cstruct.cstruct import cstruct -class BaseType: - """Base class for cstruct type classes.""" +EOF = -0xE0F # Negative counts are illegal anyway, so abuse that for our EOF sentinel - def __init__(self, cstruct: cstruct): - self.cstruct = cstruct - def __getitem__(self, count: int) -> Array: - return Array(self.cstruct, self, count) +class MetaType(type): + """Base metaclass for cstruct type classes.""" - def __call__(self, *args, **kwargs) -> Any: - if len(args) > 0: - return self.read(*args, **kwargs) + cs: cstruct + """The cstruct instance this type class belongs to.""" + size: int | None + """The size of the type in bytes. Can be ``None`` for dynamic sized types.""" + dynamic: bool + """Whether or not the type is dynamically sized.""" + alignment: int | None + """The alignment of the type in bytes. A value of ``None`` will be treated as 1-byte aligned.""" - result = self.default() - if kwargs: - for k, v in kwargs.items(): - setattr(result, k, v) + # This must be the actual type, but since Array is a subclass of BaseType, we correct this at the bottom of the file + ArrayType: type[Array] = "Array" + """The array type for this type class.""" - return result + def __call__(cls, *args, **kwargs) -> MetaType | BaseType: + """Adds support for ``TypeClass(bytes | file-like object)`` parsing syntax.""" + # TODO: add support for Type(cs) API to create new bounded type classes, similar to the old API? + if len(args) == 1 and not isinstance(args[0], cls): + stream = args[0] - def reads(self, data: bytes) -> Any: - """Parse the given data according to the type that implements this class. + if hasattr(stream, "read"): + return cls._read(stream) - Args: - data: Byte string to parse. + if issubclass(cls, bytes) and isinstance(stream, bytes) and len(stream) == cls.size: + # Shortcut for char/bytes type + return type.__call__(cls, *args, **kwargs) - Returns: - The parsed value of this type. - """ + if isinstance(stream, (bytes, memoryview, bytearray)): + return cls.reads(stream) + + return type.__call__(cls, *args, **kwargs) + + def __getitem__(cls, num_entries: int | Expression | None) -> ArrayMetaType: + """Create a new array with the given number of entries.""" + return cls.cs._make_array(cls, num_entries) + + def __len__(cls) -> int: + """Return the byte size of the type.""" + if cls.size is None: + raise TypeError("Dynamic size") - return self._read(BytesIO(data)) + return cls.size - def dumps(self, data: Any) -> bytes: - """Dump the given data according to the type that implements this class. + def default(cls) -> BaseType: + """Return the default value of this type.""" + return cls() + + def reads(cls, data: bytes) -> BaseType: + """Parse the given data from a bytes-like object. Args: - data: Data to dump. + data: Bytes-like object to parse. Returns: - The resulting bytes. - - Raises: - ArraySizeError: Raised when ``len(data)`` does not match the size of a statically sized array field. + The parsed value of this type. """ - out = BytesIO() - self._write(out, data) - return out.getvalue() + return cls._read(BytesIO(data)) - def read(self, obj: BinaryIO, *args, **kwargs) -> Any: - """Parse the given data according to the type that implements this class. + def read(cls, obj: BinaryIO | bytes) -> BaseType: + """Parse the given data. Args: - obj: Data to parse. Can be a (byte) string or a file-like object. + obj: Data to parse. Can be a bytes-like object or a file-like object. Returns: The parsed value of this type. """ if isinstance(obj, (bytes, memoryview, bytearray)): - return self.reads(obj) + return cls.reads(obj) - return self._read(obj) + return cls._read(obj) - def write(self, stream: BinaryIO, data: Any) -> int: - """Write the given data to a writable file-like object according to the - type that implements this class. + def write(cls, stream: BinaryIO, value: Any) -> int: + """Write a value to a writable file-like object. Args: - stream: Writable file-like object to write to. - data: Data to write. + stream: File-like objects that supports writing. + value: Value to write. Returns: The amount of bytes written. + """ + return cls._write(stream, value) + + def dumps(cls, value: Any) -> bytes: + """Dump a value to a byte string. + + Args: + value: Value to dump. - Raises: - ArraySizeError: Raised when ``len(data)`` does not match the size of a statically sized array field. + Returns: + The raw bytes of this type. """ - return self._write(stream, data) + out = BytesIO() + cls._write(out, value) + return out.getvalue() - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> Any: - raise NotImplementedError() + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType: + """Internal function for reading value. - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> List[Any]: - return [self._read(stream, context) for _ in range(count)] + Must be implemented per type. - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[Any]: + Args: + stream: The stream to read from. + context: Optional reading context. + """ raise NotImplementedError() - def _write(self, stream: BinaryIO, data: Any) -> int: - raise NotImplementedError() + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[BaseType]: + """Internal function for reading array values. - def _write_array(self, stream: BinaryIO, data: Any) -> int: - num = 0 - for i in data: - num += self._write(stream, i) + Allows type implementations to do optimized reading for their type. - return num + Args: + stream: The stream to read from. + count: The amount of values to read. + context: Optional reading context. + """ + if count == EOF: + result = [] + while True: + try: + result.append(cls._read(stream, context)) + except EOFError: + break + return result + + return [cls._read(stream, context) for _ in range(count)] + + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[BaseType]: + """Internal function for reading null-terminated data. + + "Null" is type specific, so must be implemented per type. - def _write_0(self, stream: BinaryIO, data: Any) -> int: + Args: + stream: The stream to read from. + context: Optional reading context. + """ raise NotImplementedError() - def default(self) -> Any: - """Return a default value of this type.""" + def _write(cls, stream: BinaryIO, data: Any) -> int: raise NotImplementedError() - def default_array(self, count: int) -> List[Any]: - """Return a default array of this type.""" - return [self.default() for _ in range(count)] + def _write_array(cls, stream: BinaryIO, array: list[BaseType]) -> int: + """Internal function for writing arrays. + Allows type implementations to do optimized writing for their type. -class Array(BaseType): - """Implements a fixed or dynamically sized array type. + Args: + stream: The stream to read from. + array: The array to write. + """ + return sum(cls._write(stream, entry) for entry in array) - Example: - When using the default C-style parser, the following syntax is supported: + def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: + """Internal function for writing null-terminated arrays. - x[3] -> 3 -> static length. - x[] -> None -> null-terminated. - x[expr] -> expr -> dynamic length. + Allows type implementations to do optimized writing for their type. + + Args: + stream: The stream to read from. + array: The array to write. + """ + return cls._write_array(stream, array + [cls()]) + + +class _overload: + """Descriptor to use on the ``write`` and ``dumps`` methods on cstruct types. + + Allows for calling these methods on both the type and instance. + + Example: + >>> int32.dumps(123) + b'\\x7b\\x00\\x00\\x00' + >>> int32(123).dumps() + b'\\x7b\\x00\\x00\\x00' """ - def __init__(self, cstruct: cstruct, type_: BaseType, count: int): - self.type = type_ - self.count = count - self.null_terminated = self.count is None - self.dynamic = isinstance(self.count, Expression) - self.alignment = type_.alignment - super().__init__(cstruct) + def __init__(self, func: Callable[[Any], Any]) -> None: + self.func = func + + def __get__(self, instance: BaseType | None, owner: MetaType) -> Callable[[Any], bytes]: + if instance is None: + return functools.partial(self.func, owner) + else: + return functools.partial(self.func, instance.__class__, value=instance) - def __repr__(self) -> str: - if self.null_terminated: - return f"{self.type}[]" - return f"{self.type}[{self.count}]" +class BaseType(metaclass=MetaType): + """Base class for cstruct type classes.""" + + dumps = _overload(MetaType.dumps) + write = _overload(MetaType.write) def __len__(self) -> int: - if self.dynamic or self.null_terminated: + """Return the byte size of the type.""" + if self.__class__.size is None: raise TypeError("Dynamic size") - return len(self.type) * self.count - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[Any]: - if self.null_terminated: - return self.type._read_0(stream, context) + return self.__class__.size - if self.dynamic: - count = self.count.evaluate(context) - else: - count = self.count - return self.type._read_array(stream, max(0, count), context) +class ArrayMetaType(MetaType): + """Base metaclass for array-like types.""" - def _write(self, stream: BinaryIO, data: List[Any]) -> int: - if self.null_terminated: - return self.type._write_0(stream, data) + type: MetaType + num_entries: int | Expression | None + null_terminated: bool - if not self.dynamic and self.count != (actual_size := len(data)): - raise ArraySizeError(f"Expected static array size {self.count}, got {actual_size} instead.") + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array: + if cls.null_terminated: + return cls.type._read_0(stream, context) - return self.type._write_array(stream, data) + if isinstance(cls.num_entries, int): + num = max(0, cls.num_entries) + elif cls.num_entries is None: + num = EOF + elif isinstance(cls.num_entries, Expression): + try: + num = max(0, cls.num_entries.evaluate(context)) + except Exception: + if cls.num_entries.expression != "EOF": + raise + num = EOF - def default(self) -> List[Any]: - count = 0 if self.dynamic or self.null_terminated else self.count - return self.type.default_array(count) + return cls.type._read_array(stream, num, context) + def default(cls) -> BaseType: + return type.__call__( + cls, [cls.type.default() for _ in range(0 if cls.dynamic or cls.null_terminated else cls.num_entries)] + ) -class RawType(BaseType): - """Base class for raw types that have a name and size.""" - def __init__(self, cstruct: cstruct, name: str = None, size: int = 0, alignment: int = None): - self.name = name - self.size = size - self.alignment = alignment or size - super().__init__(cstruct) +class Array(list, BaseType, metaclass=ArrayMetaType): + """Implements a fixed or dynamically sized array type. - def __len__(self) -> int: - return self.size + Example: + When using the default C-style parser, the following syntax is supported: - def __repr__(self) -> str: - if self.name: - return self.name + x[3] -> 3 -> static length. + x[] -> None -> null-terminated. + x[expr] -> expr -> dynamic length. + """ - return BaseType.__repr__(self) + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Array: + return cls(ArrayMetaType._read(cls, stream, context)) - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> Any: - raise NotImplementedError() + @classmethod + def _write(cls, stream: BinaryIO, data: list[Any]) -> int: + if cls.null_terminated: + return cls.type._write_0(stream, data) - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[Any]: - raise NotImplementedError() + if not cls.dynamic and cls.num_entries != (actual_size := len(data)): + raise ArraySizeError(f"Expected static array size {cls.num_entries}, got {actual_size} instead.") - def _write(self, stream: BinaryIO, data: Any) -> int: - raise NotImplementedError() + return cls.type._write_array(stream, data) - def _write_0(self, stream: BinaryIO, data: List[Any]) -> int: - raise NotImplementedError() - def default(self) -> Any: - raise NotImplementedError() +# As mentioned in the BaseType class, we correctly set the type here +MetaType.ArrayType = Array diff --git a/dissect/cstruct/types/bytesinteger.py b/dissect/cstruct/types/bytesinteger.py deleted file mode 100644 index a413e6c..0000000 --- a/dissect/cstruct/types/bytesinteger.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, BinaryIO, List - -from dissect.cstruct.types import RawType - -if TYPE_CHECKING: - from dissect.cstruct import cstruct - - -class BytesInteger(RawType): - """Implements an integer type that can span an arbitrary amount of bytes.""" - - def __init__(self, cstruct: cstruct, name: str, size: int, signed: bool, alignment: int = None): - self.signed = signed - super().__init__(cstruct, name, size, alignment) - - @staticmethod - def parse(buf: BinaryIO, size: int, count: int, signed: bool, endian: str) -> List[int]: - nums = [] - - for c in range(count): - num = 0 - data = buf[c * size : (c + 1) * size] - if endian == "<": # little-endian (LE) - data = b"".join(data[i : i + 1] for i in reversed(range(len(data)))) - - ints = list(data) - for i in ints: - num = (num << 8) | i - - if signed and (num & (1 << (size * 8 - 1))): - bias = 1 << (size * 8 - 1) - num -= bias * 2 - - nums.append(num) - - return nums - - @staticmethod - def pack(data: List[int], size: int, endian: str, signed: bool) -> bytes: - buf = [] - - bits = size * 8 - unsigned_min = 0 - unsigned_max = (2**bits) - 1 - signed_min = -(2 ** (bits - 1)) - signed_max = (2 ** (bits - 1)) - 1 - - for i in data: - if signed and (i < signed_min or i > signed_max): - raise OverflowError(f"{i} exceeds bounds for signed {bits} bits BytesInteger") - elif not signed and (i < unsigned_min or i > unsigned_max): - raise OverflowError(f"{i} exceeds bounds for unsigned {bits} bits BytesInteger") - - num = int(i) - if num < 0: - num += 1 << (size * 8) - - d = [b"\x00"] * size - i = size - 1 - - while i >= 0: - b = num & 255 - d[i] = bytes((b,)) - num >>= 8 - i -= 1 - - if endian == "<": - d = b"".join(d[i : i + 1][0] for i in reversed(range(len(d)))) - else: - d = b"".join(d) - - buf.append(d) - - return b"".join(buf) - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> int: - return self.parse(stream.read(self.size * 1), self.size, 1, self.signed, self.cstruct.endian)[0] - - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> List[int]: - return self.parse( - stream.read(self.size * count), - self.size, - count, - self.signed, - self.cstruct.endian, - ) - - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[int]: - result = [] - - while True: - v = self._read(stream, context) - if v == 0: - break - - result.append(v) - - return result - - def _write(self, stream: BinaryIO, data: int) -> int: - return stream.write(self.pack([data], self.size, self.cstruct.endian, self.signed)) - - def _write_array(self, stream: BinaryIO, data: List[int]) -> int: - return stream.write(self.pack(data, self.size, self.cstruct.endian, self.signed)) - - def _write_0(self, stream: BinaryIO, data: List[int]) -> int: - return self._write_array(stream, data + [0]) - - def default(self) -> int: - return 0 - - def default_array(self, count: int) -> List[int]: - return [0] * count diff --git a/dissect/cstruct/types/char.py b/dissect/cstruct/types/char.py new file mode 100644 index 0000000..cec72c1 --- /dev/null +++ b/dissect/cstruct/types/char.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any, BinaryIO + +from dissect.cstruct.types.base import EOF, ArrayMetaType, BaseType + + +class CharArray(bytes, BaseType, metaclass=ArrayMetaType): + """Character array type for reading and writing byte strings.""" + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> CharArray: + return type.__call__(cls, ArrayMetaType._read(cls, stream, context)) + + @classmethod + def _write(cls, stream: BinaryIO, data: bytes) -> int: + if isinstance(data, list) and data and isinstance(data[0], int): + data = bytes(data) + + elif isinstance(data, str): + data = data.encode("latin-1") + + if cls.null_terminated: + return stream.write(data + b"\x00") + return stream.write(data) + + @classmethod + def default(cls) -> CharArray: + return type.__call__(cls, b"\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries)) + + +class Char(bytes, BaseType): + """Character type for reading and writing bytes.""" + + ArrayType = CharArray + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char: + return cls._read_array(stream, 1, context) + + @classmethod + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Char: + if count == 0: + return type.__call__(cls, b"") + + data = stream.read(-1 if count == EOF else count) + if count != EOF and len(data) != count: + raise EOFError(f"Read {len(data)} bytes, but expected {count}") + + return type.__call__(cls, data) + + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char: + buf = [] + while True: + byte = stream.read(1) + if byte == b"": + raise EOFError("Read 0 bytes, but expected 1") + + if byte == b"\x00": + break + + buf.append(byte) + + return type.__call__(cls, b"".join(buf)) + + @classmethod + def _write(cls, stream: BinaryIO, data: bytes | int | str) -> int: + if isinstance(data, int): + data = chr(data) + + if isinstance(data, str): + data = data.encode("latin-1") + + return stream.write(data) + + @classmethod + def default(cls) -> Char: + return type.__call__(cls, b"\x00") diff --git a/dissect/cstruct/types/chartype.py b/dissect/cstruct/types/chartype.py deleted file mode 100644 index 5c71b7d..0000000 --- a/dissect/cstruct/types/chartype.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, BinaryIO - -from dissect.cstruct.types import RawType - -if TYPE_CHECKING: - from dissect.cstruct import cstruct - - -class CharType(RawType): - """Implements a character type that can properly handle strings.""" - - def __init__(self, cstruct: cstruct): - super().__init__(cstruct, "char", size=1, alignment=1) - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> bytes: - return stream.read(1) - - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> bytes: - if count == 0: - return b"" - - return stream.read(count) - - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> bytes: - byte_array = [] - while True: - bytes_stream = stream.read(1) - if bytes_stream == b"": - raise EOFError() - - if bytes_stream == b"\x00": - break - - byte_array.append(bytes_stream) - - return b"".join(byte_array) - - def _write(self, stream: BinaryIO, data: bytes) -> int: - if isinstance(data, int): - data = chr(data) - - if isinstance(data, str): - data = data.encode("latin-1") - - return stream.write(data) - - def _write_array(self, stream: BinaryIO, data: bytes) -> int: - return self._write(stream, data) - - def _write_0(self, stream: BinaryIO, data: bytes) -> int: - return self._write(stream, data + b"\x00") - - def default(self) -> bytes: - return b"\x00" - - def default_array(self, count: int) -> bytes: - return b"\x00" * count diff --git a/dissect/cstruct/types/enum.py b/dissect/cstruct/types/enum.py index f274458..fcb2f09 100644 --- a/dissect/cstruct/types/enum.py +++ b/dissect/cstruct/types/enum.py @@ -1,124 +1,183 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, BinaryIO, Dict, List, Union +import sys +from enum import EnumMeta, IntEnum, IntFlag +from typing import TYPE_CHECKING, Any, BinaryIO -from dissect.cstruct.types import BaseType, RawType +from dissect.cstruct.types.base import Array, BaseType, MetaType if TYPE_CHECKING: - from dissect.cstruct import cstruct + from dissect.cstruct.cstruct import cstruct -class Enum(RawType): - """Implements an Enum type. +PY_311 = sys.version_info >= (3, 11, 0) - Enums can be made using any type. The API for accessing enums and their - values is very similar to Python 3 native enums. - Example: - When using the default C-style parser, the following syntax is supported: +class EnumMetaType(EnumMeta, MetaType): + type: MetaType - enum [: ] { - - }; + def __call__( + cls, + value: cstruct | int | BinaryIO | bytes = None, + name: str | None = None, + type_: MetaType | None = None, + *args, + **kwargs, + ) -> EnumMetaType: + if name is None: + if value is None: + value = cls.type() - For example, an enum that has A=1, B=5 and C=6 could be written like so: + if not isinstance(value, int): + # value is a parsable value + value = cls.type(value) - enum Test : uint16 { - A, B=5, C - }; - """ + return super().__call__(value) - def __init__(self, cstruct: cstruct, name: str, type_: BaseType, values: Dict[str, int]): - self.type = type_ - self.values = values - self.reverse = {} + cs = value + if not issubclass(type_, int): + raise TypeError("Enum can only be created from int type") - for k, v in values.items(): - self.reverse[v] = k + enum_cls = super().__call__(name, *args, **kwargs) + enum_cls.cs = cs + enum_cls.type = type_ + enum_cls.size = type_.size + enum_cls.dynamic = type_.dynamic + enum_cls.alignment = type_.alignment - super().__init__(cstruct, name, len(self.type), self.type.alignment) + _fix_alias_members(enum_cls) - def __call__(self, value: Union[int, BinaryIO]) -> EnumInstance: - if isinstance(value, int): - return EnumInstance(self, value) - return super().__call__(value) + return enum_cls - def __getitem__(self, attr: str) -> EnumInstance: - return self(self.values[attr]) + def __getitem__(cls, name: str | int) -> Enum | Array: + if isinstance(name, str): + return super().__getitem__(name) + return MetaType.__getitem__(cls, name) - def __getattr__(self, attr: str) -> EnumInstance: - try: - return self(self.values[attr]) - except KeyError: - raise AttributeError(attr) + __len__ = MetaType.__len__ - def __contains__(self, attr: str) -> bool: - return attr in self.values + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Enum: + return cls(cls.type._read(stream, context)) - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> EnumInstance: - v = self.type._read(stream, context) - return self(v) + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Enum]: + return list(map(cls, cls.type._read_array(stream, count, context))) - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> List[EnumInstance]: - return list(map(self, self.type._read_array(stream, count, context))) + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Enum]: + return list(map(cls, cls.type._read_0(stream, context))) - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[EnumInstance]: - return list(map(self, self.type._read_0(stream, context))) + def _write(cls, stream: BinaryIO, data: Enum) -> int: + return cls.type._write(stream, data.value) - def _write(self, stream: BinaryIO, data: Union[int, EnumInstance]) -> int: - data = data.value if isinstance(data, EnumInstance) else data - return self.type._write(stream, data) + def _write_array(cls, stream: BinaryIO, array: list[Enum]) -> int: + data = [entry.value if isinstance(entry, Enum) else entry for entry in array] + return cls.type._write_array(stream, data) - def _write_array(self, stream: BinaryIO, data: List[Union[int, EnumInstance]]) -> int: - data = [d.value if isinstance(d, EnumInstance) else d for d in data] - return self.type._write_array(stream, data) + def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: + data = [entry.value if isinstance(entry, Enum) else entry for entry in array] + return cls._write_array(stream, data + [cls.type()]) - def _write_0(self, stream: BinaryIO, data: List[Union[int, EnumInstance]]) -> int: - data = [d.value if isinstance(d, EnumInstance) else d for d in data] - return self.type._write_0(stream, data) - def default(self) -> EnumInstance: - return self(0) +def _fix_alias_members(cls: type[Enum]) -> None: + # Emulate aenum NoAlias behaviour + # https://github.com/ethanfurman/aenum/blob/master/aenum/doc/aenum.rst + if len(cls._member_names_) == len(cls._member_map_): + return - def default_array(self, count: int) -> List[EnumInstance]: - return [self.default() for _ in range(count)] + for name, member in cls._member_map_.items(): + if name != member.name: + new_member = int.__new__(cls, member.value) + new_member._name_ = name + new_member._value_ = member.value + type.__setattr__(cls, name, new_member) + cls._member_names_.append(name) + cls._member_map_[name] = new_member + cls._value2member_map_[member.value] = new_member -class EnumInstance: - """Implements a value instance of an Enum""" - def __init__(self, enum: Enum, value: int): - self.enum = enum - self.value = value +class Enum(BaseType, IntEnum, metaclass=EnumMetaType): + """Enum type supercharged with cstruct functionality. - def __eq__(self, value: Union[int, EnumInstance]) -> bool: - if isinstance(value, EnumInstance) and value.enum is not self.enum: - return False + Enums are (mostly) compatible with the Python 3 standard library ``IntEnum`` with some notable differences: + - Duplicate members are their own unique member instead of being an alias + - Non-existing values are allowed and handled similarly to ``IntFlag``: ```` + - Enum members are only considered equal if the enum class is the same - if hasattr(value, "value"): - value = value.value + Enums can be made using any integer type. - return self.value == value + Example: + When using the default C-style parser, the following syntax is supported: - def __ne__(self, value: Union[int, EnumInstance]) -> bool: - return self.__eq__(value) is False + enum [: ] { + + }; - def __hash__(self) -> int: - return hash((self.enum, self.value)) + For example, an enum that has A=1, B=5 and C=6 could be written like so: + + enum Test : uint16 { + A, B=5, C + }; + """ - def __str__(self) -> str: - return f"{self.enum.name}.{self.name}" + if PY_311: + + def __repr__(self) -> str: + # Use the IntFlag repr as a base since it handles unknown values the way we want it + # I.e. instead of + result = IntFlag.__repr__(self) + if not self.__class__.__name__: + # Deal with anonymous enums by stripping off the first bit + # I.e. <.RED: 1> -> + result = f"<{result[2:]}" + return result + + def __str__(self) -> str: + # We differentiate with standard Python enums in that we use a more descriptive str representation + # Standard Python enums just use the integer value as str, we use EnumName.ValueName + # In case of anonymous enums, we just use the ValueName + # In case of unknown members, we use the integer value (in combination with the EnumName if there is one) + base = f"{self.__class__.__name__}." if self.__class__.__name__ else "" + value = self.name if self.name is not None else str(self.value) + return f"{base}{value}" + + else: + + def __repr__(self) -> str: + name = self.__class__.__name__ + if self._name_ is not None: + if name: + name += "." + name += self._name_ + return f"<{name}: {self._value_!r}>" + + def __str__(self) -> str: + base = f"{self.__class__.__name__}." if self.__class__.__name__ else "" + value = self._name_ if self._name_ is not None else str(self._value_) + return f"{base}{value}" + + def __eq__(self, other: int | Enum) -> bool: + if isinstance(other, Enum) and other.__class__ is not self.__class__: + return False - def __int__(self) -> int: - return self.value + # Python <= 3.10 compatibility + if isinstance(other, Enum): + other = other.value - def __repr__(self) -> str: - name = f"{self.enum.name}.{self.name}" if self.enum.name else self.name - return f"<{name}: {self.value}>" + return self.value == other - @property - def name(self) -> str: - if self.value not in self.enum.reverse: - return f"{self.enum.name}_{self.value}" + def __ne__(self, value: int | Enum) -> bool: + return not self.__eq__(value) - return self.enum.reverse[self.value] + def __hash__(self) -> int: + return hash((self.__class__, self.name, self.value)) + + @classmethod + def _missing_(cls, value: int) -> Enum: + # Emulate FlagBoundary.KEEP for enum (allow values other than the defined members) + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + new_member = int.__new__(cls, value) + new_member._name_ = None + new_member._value_ = value + pseudo_member = cls._value2member_map_.setdefault(value, new_member) + return pseudo_member diff --git a/dissect/cstruct/types/flag.py b/dissect/cstruct/types/flag.py index 1c4617b..58bae27 100644 --- a/dissect/cstruct/types/flag.py +++ b/dissect/cstruct/types/flag.py @@ -1,15 +1,18 @@ from __future__ import annotations -from typing import BinaryIO, List, Tuple, Union +from enum import IntFlag -from dissect.cstruct.types import Enum, EnumInstance +from dissect.cstruct.types.base import BaseType +from dissect.cstruct.types.enum import PY_311, EnumMetaType -class Flag(Enum): - """Implements a Flag type. +class Flag(BaseType, IntFlag, metaclass=EnumMetaType): + """Flag type supercharged with cstruct functionality. - Flags can be made using any type. The API for accessing flags and their - values is very similar to Python 3 native flags. + Flags are (mostly) compatible with the Python 3 standard library ``IntFlag`` with some notable differences: + - Flag members are only considered equal if the flag class is the same + + Flags can be made using any integer type. Example: When using the default C-style parser, the following syntax is supported: @@ -25,82 +28,45 @@ class Flag(Enum): }; """ - def __call__(self, value: Union[int, BinaryIO]) -> FlagInstance: - if isinstance(value, int): - return FlagInstance(self, value) - - return super().__call__(value) - - -class FlagInstance(EnumInstance): - """Implements a value instance of a Flag""" - - def __bool__(self): - return bool(self.value) - - __nonzero__ = __bool__ - - def __or__(self, other: Union[int, FlagInstance]) -> FlagInstance: - if hasattr(other, "value"): - other = other.value - - return self.__class__(self.enum, self.value | other) - - def __and__(self, other: Union[int, FlagInstance]) -> FlagInstance: - if hasattr(other, "value"): - other = other.value - - return self.__class__(self.enum, self.value & other) - - def __xor__(self, other: Union[int, FlagInstance]) -> FlagInstance: - if hasattr(other, "value"): - other = other.value - - return self.__class__(self.enum, self.value ^ other) - - __ror__ = __or__ - __rand__ = __and__ - __rxor__ = __xor__ - - def __invert__(self) -> FlagInstance: - return self.__class__(self.enum, ~self.value) - - def __str__(self) -> str: - if self.name is not None: - return f"{self.enum.name}.{self.name}" - - members, _ = self.decompose() - members_str = "|".join([str(name or value) for name, value in members]) - return f"{self.enum.name}.{members_str}" - def __repr__(self) -> str: - base_name = f"{self.enum.name}." if self.enum.name else "" - - if self.name is not None: - return f"<{base_name}{self.name}: {self.value}>" - - members, _ = self.decompose() - members_str = "|".join([str(name or value) for name, value in members]) - return f"<{base_name}{members_str}: {self.value}>" - - @property - def name(self) -> str: - return self.enum.reverse.get(self.value, None) - - def decompose(self) -> Tuple[List[str], int]: - members = [] - not_covered = self.value - - for name, value in self.enum.values.items(): - if value and ((value & self.value) == value): - members.append((name, value)) - not_covered &= ~value + result = super().__repr__() + if not self.__class__.__name__: + # Deal with anonymous flags by stripping off the first bit + # I.e. <.RED: 1> -> + result = f"<{result[2:]}" + return result + + if PY_311: + + def __str__(self) -> str: + # We differentiate with standard Python flags in that we use a more descriptive str representation + # Standard Python flags just use the integer value as str, we use FlagName.ValueName + # In case of anonymous flags, we just use the ValueName + base = f"{self.__class__.__name__}." if self.__class__.__name__ else "" + return f"{base}{self.name}" + + else: + + def __str__(self) -> str: + result = IntFlag.__str__(self) + if not self.__class__.__name__: + # Deal with anonymous flags + # I.e. .RED -> RED + result = result[1:] + return result + + def __eq__(self, other: int | Flag) -> bool: + if isinstance(other, Flag) and other.__class__ is not self.__class__: + return False + + # Python <= 3.10 compatibility + if isinstance(other, Flag): + other = other.value - if not members: - members.append((None, self.value)) + return self.value == other - members.sort(key=lambda m: m[0], reverse=True) - if len(members) > 1 and members[0][1] == self.value: - members.pop(0) + def __ne__(self, value: int | Flag) -> bool: + return not self.__eq__(value) - return members, not_covered + def __hash__(self) -> int: + return hash((self.__class__, self.name, self.value)) diff --git a/dissect/cstruct/types/instance.py b/dissect/cstruct/types/instance.py deleted file mode 100644 index 0d144f4..0000000 --- a/dissect/cstruct/types/instance.py +++ /dev/null @@ -1,68 +0,0 @@ -from io import BytesIO -from typing import Any, BinaryIO, Dict - -from dissect.cstruct.types import BaseType - - -class Instance: - """Holds parsed structure data.""" - - __slots__ = ("_type", "_values", "_sizes") - - def __init__(self, type_: BaseType, values: Dict[str, Any], sizes: Dict[str, int] = None): - # Done in this manner to check if the attr is in the lookup - object.__setattr__(self, "_type", type_) - object.__setattr__(self, "_values", values) - object.__setattr__(self, "_sizes", sizes) - - def __getattr__(self, attr: str) -> Any: - try: - return self._values[attr] - except KeyError: - raise AttributeError(f"Invalid attribute: {attr}") - - def __setattr__(self, attr: str, value: Any) -> None: - if attr not in self._type.lookup: - raise AttributeError(f"Invalid attribute: {attr}") - - self._values[attr] = value - - def __getitem__(self, item: str) -> Any: - return self._values[item] - - def __contains__(self, attr: str) -> bool: - return attr in self._values - - def __repr__(self) -> str: - values = ", ".join([f"{k}={hex(v) if isinstance(v, int) else repr(v)}" for k, v in self._values.items()]) - return f"<{self._type.name} {values}>" - - def __len__(self) -> int: - return len(self.dumps()) - - def __bytes__(self) -> bytes: - return self.dumps() - - def _size(self, field: str) -> int: - return self._sizes[field] - - def write(self, stream: BinaryIO) -> int: - """Write this structure to a writable file-like object. - - Args: - fh: File-like objects that supports writing. - - Returns: - The amount of bytes written. - """ - return self._type.write(stream, self) - - def dumps(self) -> bytes: - """Dump this structure to a byte string. - - Returns: - The raw bytes of this structure. - """ - s = BytesIO() - self.write(s) - return s.getvalue() diff --git a/dissect/cstruct/types/int.py b/dissect/cstruct/types/int.py new file mode 100644 index 0000000..b1bc29c --- /dev/null +++ b/dissect/cstruct/types/int.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from typing import Any, BinaryIO + +from dissect.cstruct.types.base import BaseType +from dissect.cstruct.utils import ENDIANNESS_MAP + + +class Int(int, BaseType): + """Integer type that can span an arbitrary amount of bytes.""" + + signed: bool + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int: + data = stream.read(cls.size) + + if len(data) != cls.size: + raise EOFError(f"Read {len(data)} bytes, but expected {cls.size}") + + return cls.from_bytes(data, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed) + + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Int: + result = [] + + while True: + if (value := cls._read(stream, context)) == 0: + break + + result.append(value) + + return result + + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> int: + return stream.write(data.to_bytes(cls.size, ENDIANNESS_MAP[cls.cs.endian], signed=cls.signed)) diff --git a/dissect/cstruct/types/leb128.py b/dissect/cstruct/types/leb128.py index 64d992b..9f0a398 100644 --- a/dissect/cstruct/types/leb128.py +++ b/dissect/cstruct/types/leb128.py @@ -1,14 +1,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, BinaryIO +from typing import Any, BinaryIO -from dissect.cstruct.types.base import RawType +from dissect.cstruct.types.base import BaseType -if TYPE_CHECKING: - from dissect.cstruct import cstruct - -class LEB128(RawType): +class LEB128(int, BaseType): """Variable-length code compression to store an arbitrarily large integer in a small number of bytes. See https://en.wikipedia.org/wiki/LEB128 for more information and an explanation of the algorithm. @@ -16,11 +13,8 @@ class LEB128(RawType): signed: bool - def __init__(self, cstruct: cstruct, name: str, size: int, signed: bool, alignment: int = 1): - self.signed = signed - super().__init__(cstruct, name, size, alignment) - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: result = 0 shift = 0 while True: @@ -34,26 +28,28 @@ def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: if (b & 0x80) == 0: break - if self.signed: + if cls.signed: if b & 0x40 != 0: result |= ~0 << shift - return result + return cls.__new__(cls, result) - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> LEB128: result = [] while True: - if (value := self._read(stream, context)) == 0: + if (value := cls._read(stream, context)) == 0: break result.append(value) return result - def _write(self, stream: BinaryIO, data: int) -> int: + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> int: # only write negative numbers when in signed mode - if data < 0 and not self.signed: + if data < 0 and not cls.signed: raise ValueError("Attempt to encode a negative integer using unsigned LEB128 encoding") result = bytearray() @@ -64,8 +60,8 @@ def _write(self, stream: BinaryIO, data: int) -> int: # function works similar for signed- and unsigned integers, except for the check when to stop # the encoding process. - if (self.signed and (data == 0 and byte & 0x40 == 0) or (data == -1 and byte & 0x40 != 0)) or ( - not self.signed and data == 0 + if (cls.signed and (data == 0 and byte & 0x40 == 0) or (data == -1 and byte & 0x40 != 0)) or ( + not cls.signed and data == 0 ): result.append(byte) break @@ -75,12 +71,3 @@ def _write(self, stream: BinaryIO, data: int) -> int: stream.write(result) return len(result) - - def _write_0(self, stream: BinaryIO, data: list[int]) -> int: - return self._write_array(stream, data + [0]) - - def default(self) -> int: - return 0 - - def default_array(self, count: int) -> list[int]: - return [0] * count diff --git a/dissect/cstruct/types/packed.py b/dissect/cstruct/types/packed.py new file mode 100644 index 0000000..ec42c23 --- /dev/null +++ b/dissect/cstruct/types/packed.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from functools import lru_cache +from struct import Struct +from typing import Any, BinaryIO + +from dissect.cstruct.types.base import EOF, BaseType + + +@lru_cache(1024) +def _struct(endian: str, packchar: str) -> Struct: + return Struct(f"{endian}{packchar}") + + +class Packed(BaseType): + """Packed type for Python struct (un)packing.""" + + packchar: str + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed: + return cls._read_array(stream, 1, context)[0] + + @classmethod + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> list[Packed]: + if count == EOF: + data = stream.read() + length = len(data) + count = length // cls.size + else: + length = cls.size * count + data = stream.read(length) + + fmt = _struct(cls.cs.endian, f"{count}{cls.packchar}") + + if len(data) != length: + raise EOFError(f"Read {len(data)} bytes, but expected {length}") + + return [cls.__new__(cls, value) for value in fmt.unpack(data)] + + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Packed: + result = [] + + fmt = _struct(cls.cs.endian, cls.packchar) + while True: + data = stream.read(cls.size) + + if len(data) != cls.size: + raise EOFError(f"Read {len(data)} bytes, but expected {cls.size}") + + if (value := fmt.unpack(data)[0]) == 0: + break + + result.append(cls.__new__(cls, value)) + + return result + + @classmethod + def _write(cls, stream: BinaryIO, data: Packed) -> int: + return stream.write(_struct(cls.cs.endian, cls.packchar).pack(data)) + + @classmethod + def _write_array(cls, stream: BinaryIO, data: list[Packed]) -> int: + return stream.write(_struct(cls.cs.endian, f"{len(data)}{cls.packchar}").pack(*data)) diff --git a/dissect/cstruct/types/packedtype.py b/dissect/cstruct/types/packedtype.py deleted file mode 100644 index fc42a95..0000000 --- a/dissect/cstruct/types/packedtype.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import struct -from typing import TYPE_CHECKING, Any, BinaryIO, List - -from dissect.cstruct.types import RawType - -if TYPE_CHECKING: - from dissect.cstruct import cstruct - - -class PackedType(RawType): - """Implements a packed type that uses Python struct packing characters.""" - - def __init__(self, cstruct: cstruct, name: str, size: int, packchar: str, alignment: int = None): - super().__init__(cstruct, name, size, alignment) - self.packchar = packchar - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> int: - return self._read_array(stream, 1, context)[0] - - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> List[int]: - length = self.size * count - data = stream.read(length) - fmt = self.cstruct.endian + str(count) + self.packchar - - if len(data) != length: - raise EOFError(f"Read {len(data)} bytes, but expected {length}") - - return list(struct.unpack(fmt, data)) - - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> List[int]: - byte_array = [] - while True: - bytes_stream = stream.read(self.size) - unpacked_struct = struct.unpack(self.cstruct.endian + self.packchar, bytes_stream)[0] - - if unpacked_struct == 0: - break - - byte_array.append(unpacked_struct) - - return byte_array - - def _write(self, stream: BinaryIO, data: int) -> int: - return self._write_array(stream, [data]) - - def _write_array(self, stream: BinaryIO, data: List[int]) -> int: - fmt = self.cstruct.endian + str(len(data)) + self.packchar - return stream.write(struct.pack(fmt, *data)) - - def _write_0(self, stream: BinaryIO, data: List[int]) -> int: - return self._write_array(stream, data + [0]) - - def default(self) -> int: - return 0 - - def default_array(self, count: int) -> List[int]: - return [0] * count diff --git a/dissect/cstruct/types/pointer.py b/dissect/cstruct/types/pointer.py index cc9312e..f79d86d 100644 --- a/dissect/cstruct/types/pointer.py +++ b/dissect/cstruct/types/pointer.py @@ -1,52 +1,30 @@ from __future__ import annotations -import operator -from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Dict, Union +from typing import Any, BinaryIO from dissect.cstruct.exceptions import NullPointerDereference -from dissect.cstruct.types import BaseType, CharType, RawType +from dissect.cstruct.types.base import BaseType, MetaType +from dissect.cstruct.types.char import Char +from dissect.cstruct.types.void import Void -if TYPE_CHECKING: - from dissect.cstruct import cstruct +class Pointer(int, BaseType): + """Pointer to some other type.""" -class Pointer(RawType): - """Implements a pointer to some other type.""" + type: MetaType + _stream: BinaryIO + _context: dict[str, Any] + _value: BaseType - def __init__(self, cstruct: cstruct, target: BaseType): - self.cstruct = cstruct - self.type = target - super().__init__(cstruct, "pointer", self.cstruct.pointer.size, self.cstruct.pointer.alignment) + def __new__(cls, value: int, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer: + obj = super().__new__(cls, value) + obj._stream = stream + obj._context = context + obj._value = None + return obj def __repr__(self) -> str: - return f"" - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> PointerInstance: - addr = self.cstruct.pointer(stream) - return PointerInstance(self.type, stream, addr, context) - - def _write(self, stream: BinaryIO, data: Union[int, PointerInstance]): - if isinstance(data, PointerInstance): - data = data._addr - - if not isinstance(data, int): - raise TypeError("Invalid pointer data") - - return self.cstruct.pointer._write(stream, data) - - -class PointerInstance: - """Like the Instance class, but for structures referenced by a pointer.""" - - def __init__(self, type_: BaseType, stream: BinaryIO, addr: int, ctx: Dict[str, Any]): - self._stream = stream - self._type = type_ - self._addr = addr - self._ctx = ctx - self._value = None - - def __repr__(self) -> str: - return f"" + return f"<{self.type.__name__}* @ {self:#x}>" def __str__(self) -> str: return str(self.dereference()) @@ -54,72 +32,62 @@ def __str__(self) -> str: def __getattr__(self, attr: str) -> Any: return getattr(self.dereference(), attr) - def __int__(self) -> int: - return self._addr - - def __nonzero__(self) -> bool: - return self._addr != 0 - - def __addr_math(self, other: Union[int, PointerInstance], op: Callable[[int, int], int]) -> PointerInstance: - if isinstance(other, PointerInstance): - other = other._addr - - return PointerInstance(self._type, self._stream, op(self._addr, other), self._ctx) - - def __add__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__add__) + def __add__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__add__(self, other), self._stream, self._context) - def __sub__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__sub__) + def __sub__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__sub__(self, other), self._stream, self._context) - def __mul__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__mul__) + def __mul__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__mul__(self, other), self._stream, self._context) - def __floordiv__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__floordiv__) + def __floordiv__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__floordiv__(self, other), self._stream, self._context) - def __mod__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__mod__) + def __mod__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__mod__(self, other), self._stream, self._context) - def __pow__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__pow__) + def __pow__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__pow__(self, other), self._stream, self._context) - def __lshift__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__lshift__) + def __lshift__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__lshift__(self, other), self._stream, self._context) - def __rshift__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__rshift__) + def __rshift__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__rshift__(self, other), self._stream, self._context) - def __and__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__and__) + def __and__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__and__(self, other), self._stream, self._context) - def __xor__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__xor__) + def __xor__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__xor__(self, other), self._stream, self._context) - def __or__(self, other: Union[int, PointerInstance]) -> PointerInstance: - return self.__addr_math(other, operator.__or__) + def __or__(self, other: int) -> Pointer: + return type.__call__(self.__class__, int.__or__(self, other), self._stream, self._context) - def __eq__(self, other: Union[int, PointerInstance]) -> bool: - if isinstance(other, PointerInstance): - other = other._addr + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Pointer: + return cls.__new__(cls, cls.cs.pointer._read(stream, context), stream, context) - return self._addr == other + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> int: + return cls.cs.pointer._write(stream, data) def dereference(self) -> Any: - if self._addr == 0: + if self == 0: raise NullPointerDereference() - if self._value is None: + if self._value is None and not issubclass(self.type, Void): # Read current position of file read/write pointer position = self._stream.tell() # Reposition the file read/write pointer - self._stream.seek(self._addr) + self._stream.seek(self) - if isinstance(self._type, CharType): + if issubclass(self.type, Char): # this makes the assumption that a char pointer is a null-terminated string - value = self._type._read_0(self._stream, self._ctx) + value = self.type._read_0(self._stream, self._context) else: - value = self._type._read(self._stream, self._ctx) + value = self.type._read(self._stream, self._context) self._stream.seek(position) self._value = value diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index ef8927b..3fb1fe6 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -1,69 +1,144 @@ from __future__ import annotations import io -from collections import OrderedDict -from typing import TYPE_CHECKING, Any, BinaryIO, List +from contextlib import contextmanager +from functools import lru_cache +from operator import attrgetter +from textwrap import dedent +from types import FunctionType +from typing import Any, BinaryIO, Callable, ContextManager from dissect.cstruct.bitbuffer import BitBuffer -from dissect.cstruct.types import BaseType, Enum, Instance - -if TYPE_CHECKING: - from dissect.cstruct import cstruct +from dissect.cstruct.types.base import BaseType, MetaType +from dissect.cstruct.types.enum import EnumMetaType +from dissect.cstruct.types.pointer import Pointer class Field: - """Holds a structure field.""" + """Structure field.""" - def __init__(self, name: str, type_: BaseType, bits: int = None, offset: int = None): + def __init__(self, name: str, type_: MetaType, bits: int = None, offset: int = None): self.name = name self.type = type_ self.bits = bits self.offset = offset - self.alignment = type_.alignment + self.alignment = type_.alignment or 1 - def __repr__(self): + def __repr__(self) -> str: bits_str = f" : {self.bits}" if self.bits else "" - return f"" - - -class Structure(BaseType): - """Type class for structures.""" - - def __init__( - self, cstruct: cstruct, name: str, fields: List[Field] = None, align: bool = False, anonymous: bool = False - ): - super().__init__(cstruct) - self.name = name - self.size = None - self.alignment = None - - self.lookup = OrderedDict() - self.fields = fields - - self.align = align - self.anonymous = anonymous - self.dynamic = False + return f"" + + +class StructureMetaType(MetaType): + """Base metaclass for cstruct structure type classes.""" + + # TODO: resolve field types in _update_fields, remove resolves elsewhere? + + fields: dict[str, Field] + """Mapping of field names to :class:`Field` objects, including "folded" fields from anonymous structures.""" + lookup: dict[str, Field] + """Mapping of "raw" field names to :class:`Field` objects. E.g. holds the anonymous struct and not its fields.""" + __fields__: list[Field] + """List of :class:`Field` objects for this structure. This is the structures' Single Source Of Truth.""" + + # Internal + __align__: bool + __anonymous__: bool + __updating__ = False + __compiled__ = False + + def __new__(metacls, name: str, bases: tuple[type, ...], classdict: dict[str, Any]) -> MetaType: + if (fields := classdict.pop("fields", None)) is not None: + metacls._update_fields(metacls, fields, align=classdict.get("align", False), classdict=classdict) + + return super().__new__(metacls, name, bases, classdict) + + def __call__(cls, *args, **kwargs) -> Structure: + if ( + cls.__fields__ + and len(args) == len(cls.__fields__) == 1 + and isinstance(args[0], bytes) + and issubclass(cls.__fields__[0].type, bytes) + and len(args[0]) == cls.__fields__[0].type.size + ): + # Shortcut for single char/bytes type + return type.__call__(cls, *args, **kwargs) + elif not args and not kwargs: + obj = cls(**{field.name: field.type.default() for field in cls.__fields__}) + object.__setattr__(obj, "_values", {}) + object.__setattr__(obj, "_sizes", {}) + return obj + + return super().__call__(*args, **kwargs) + + def _update_fields(cls, fields: list[Field], align: bool = False, classdict: dict[str, Any] | None = None) -> None: + classdict = classdict or {} + + lookup = {} + raw_lookup = {} + init_names = [] + field_names = [] + for field in fields: + if field.name in lookup and field.name != "_": + raise ValueError(f"Duplicate field name: {field.name}") + + if isinstance(field.type, StructureMetaType) and field.type.__anonymous__: + for anon_field in field.type.fields.values(): + attr = f"{field.name}.{anon_field.name}" + classdict[anon_field.name] = property(attrgetter(attr), attrsetter(attr)) + + lookup.update(field.type.fields) + else: + lookup[field.name] = field + + raw_lookup[field.name] = field + + num_fields = len(lookup) + field_names = lookup.keys() + init_names = raw_lookup.keys() + classdict["fields"] = lookup + classdict["lookup"] = raw_lookup + classdict["__fields__"] = fields + classdict["__bool__"] = _patch_attributes(_make__bool__(num_fields), field_names, 1) + + if issubclass(cls, UnionMetaType) or isinstance(cls, UnionMetaType): + classdict["__init__"] = _patch_setattr_args_and_attributes( + _make_setattr__init__(len(init_names)), init_names + ) + # Not a great way to do this but it works for now + classdict["__eq__"] = Union.__eq__ + else: + classdict["__init__"] = _patch_args_and_attributes(_make__init__(len(init_names)), init_names) + classdict["__eq__"] = _patch_attributes(_make__eq__(num_fields), field_names, 1) - for field in self.fields: - self.lookup[field.name] = field - if isinstance(field.type, Structure) and field.type.anonymous: - self.lookup.update(field.type.lookup) + # If we're calling this as a class method or a function on the metaclass + if issubclass(cls, type): + size, alignment = cls._calculate_size_and_offsets(cls, fields, align) + else: + size, alignment = cls._calculate_size_and_offsets(fields, align) - self._calc_size_and_offsets() + if cls.__compiled__: + # If the previous class was compiled try to compile this too + from dissect.cstruct import compiler - def __len__(self) -> int: - if self.dynamic: - raise TypeError("Dynamic size") + try: + classdict["_read"] = compiler.Compiler(cls.cs).compile_read(fields, cls.__name__, align=cls.__align__) + classdict["__compiled__"] = True + except Exception: + # Revert _read to the slower loop based method + classdict["_read"] = classmethod(StructureMetaType._read) + classdict["__compiled__"] = False - if self.size is None: - self._calc_size_and_offsets() + # TODO: compile _write + # TODO: generate cached_property for lazy reading - return self.size + classdict["size"] = size + classdict["alignment"] = alignment + classdict["dynamic"] = size is None - def __repr__(self) -> str: - return f"" + return classdict - def _calc_size_and_offsets(self) -> None: + def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]: """Iterate all fields in this structure to calculate the field offsets and total structure size. If a structure has a dynamic field, further field offsets will be set to None and self.dynamic @@ -81,22 +156,22 @@ def _calc_size_and_offsets(self) -> None: # How many bits we have left in the current bit field bits_remaining = 0 - for field in self.fields: + for field in fields: if field.offset is not None: # If a field already has an offset, it's leading offset = field.offset - if self.align and offset is not None: + if align and offset is not None: # Round to next alignment offset += -offset & (field.alignment - 1) # The alignment of this struct is equal to its largest members' alignment - alignment = max(alignment, field.type.alignment) + alignment = max(alignment, field.alignment) if field.bits: field_type = field.type - if isinstance(field_type, Enum): + if isinstance(field_type, EnumMetaType): field_type = field_type.type # Bit fields have special logic @@ -138,43 +213,41 @@ def _calc_size_and_offsets(self) -> None: except TypeError: # This field is dynamic offset = None - self.dynamic = True continue offset += field_len - if self.align and offset is not None: + if align and offset is not None: # Add "tail padding" if we need to align # This bit magic rounds up to the next alignment boundary # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5 offset += -offset & (alignment - 1) # The structure size is whatever the currently calculated offset is - self.size = offset - self.alignment = alignment + return offset, alignment - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> Instance: - bit_buffer = BitBuffer(stream, self.cstruct.endian) + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Structure: + bit_buffer = BitBuffer(stream, cls.cs.endian) struct_start = stream.tell() - result = OrderedDict() + result = {} sizes = {} - for field in self.fields: + for field in cls.__fields__: offset = stream.tell() - field_type = self.cstruct.resolve(field.type) + field_type = cls.cs.resolve(field.type) - if field.offset and offset != struct_start + field.offset: + if field.offset is not None and offset != struct_start + field.offset: # Field is at a specific offset, either alligned or added that way offset = struct_start + field.offset stream.seek(offset) - if self.align and field.offset is None: + if cls.__align__ and field.offset is None: # Previous field was dynamically sized and we need to align offset += -offset & (field.alignment - 1) stream.seek(offset) if field.bits: - if isinstance(field_type, Enum): + if isinstance(field_type, EnumMetaType): value = field_type(bit_buffer.read(field_type.type, field.bits)) else: value = bit_buffer.read(field_type, field.bits) @@ -187,27 +260,38 @@ def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> Instance: value = field_type._read(stream, result) - if isinstance(field_type, Structure) and field_type.anonymous: - sizes.update(value._sizes) - result.update(value._values) - else: - if field.name: - sizes[field.name] = stream.tell() - offset - result[field.name] = value + if field.name: + sizes[field.name] = stream.tell() - offset + result[field.name] = value - if self.align: + if cls.__align__: # Align the stream - stream.seek(-stream.tell() & (self.alignment - 1), io.SEEK_CUR) + stream.seek(-stream.tell() & (cls.alignment - 1), io.SEEK_CUR) + + obj = cls(**result) + obj._sizes = sizes + obj._values = result + return obj + + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> list[Structure]: + result = [] + + while obj := cls._read(stream, context): + result.append(obj) - return Instance(self, result, sizes) + return result - def _write(self, stream: BinaryIO, data: Instance) -> int: - bit_buffer = BitBuffer(stream, self.cstruct.endian) + def _write(cls, stream: BinaryIO, data: Structure) -> int: + bit_buffer = BitBuffer(stream, cls.cs.endian) struct_start = stream.tell() num = 0 - for field in self.fields: - bit_field_type = (field.type.type if isinstance(field.type, Enum) else field.type) if field.bits else None + for field in cls.__fields__: + field_type = cls.cs.resolve(field.type) + + bit_field_type = ( + (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None + ) # Current field is not a bit field, but previous was # Or, moved to a bit field of another type, e.g. uint16 f1 : 8, uint32 f2 : 8; if (not field.bits and bit_buffer._type is not None) or ( @@ -218,14 +302,14 @@ def _write(self, stream: BinaryIO, data: Instance) -> int: offset = stream.tell() - if field.offset and offset < struct_start + field.offset: + if field.offset is not None and offset < struct_start + field.offset: # Field is at a specific offset, either alligned or added that way stream.write(b"\x00" * (struct_start + field.offset - offset)) offset = struct_start + field.offset - if self.align and field.offset is None: + if cls.__align__ and field.offset is None: is_bitbuffer_boundary = bit_buffer._type and ( - bit_buffer._remaining == 0 or bit_buffer._type != field.type + bit_buffer._remaining == 0 or bit_buffer._type != field_type ) if not bit_buffer._type or is_bitbuffer_boundary: # Previous field was dynamically sized and we need to align @@ -235,142 +319,360 @@ def _write(self, stream: BinaryIO, data: Instance) -> int: value = getattr(data, field.name, None) if value is None: - value = field.type.default() + value = field_type() if field.bits: - if isinstance(field.type, Enum): - bit_buffer.write(field.type.type, value.value, field.bits) + if isinstance(field_type, EnumMetaType): + bit_buffer.write(field_type.type, value.value, field.bits) else: - bit_buffer.write(field.type, value, field.bits) + bit_buffer.write(field_type, value, field.bits) else: - if isinstance(field.type, Structure) and field.type.anonymous: - field.type._write(stream, data) - else: - field.type._write(stream, value) + field_type._write(stream, value) num += stream.tell() - offset if bit_buffer._type is not None: bit_buffer.flush() - if self.align: + if cls.__align__: # Align the stream - stream.write(b"\x00" * (-stream.tell() & (self.alignment - 1))) + stream.write(b"\x00" * (-stream.tell() & (cls.alignment - 1))) return num - def add_field(self, name: str, type_: BaseType, bits: int = None, offset: int = None) -> None: - """Add a field to this structure. - - Args: - name: The field name. - type_: The field type. - bits: The bit of the field. - offset: The field offset. - """ + def add_field(cls, name: str, type_: BaseType, bits: int | None = None, offset: int | None = None) -> None: field = Field(name, type_, bits=bits, offset=offset) - self.fields.append(field) - self.lookup[name] = field - if isinstance(field.type, Structure) and field.type.anonymous: - self.lookup.update(field.type.lookup) - self.size = None + cls.__fields__.append(field) - def default(self) -> Instance: - """Create and return an empty Instance from this structure. + if not cls.__updating__: + cls.commit() - Returns: - An empty Instance from this structure. - """ - result = OrderedDict() - for field in self.fields: - if isinstance(field.type, Structure) and field.type.anonymous: - result.update(field.type.default()._values) - else: - result[field.name] = field.type.default() + @contextmanager + def start_update(cls) -> ContextManager: + try: + cls.__updating__ = True + yield + finally: + cls.commit() + cls.__updating__ = False - return Instance(self, result) + def commit(cls) -> None: + classdict = cls._update_fields(cls.__fields__, cls.__align__) - def show(self, indent: int = 0) -> None: - """Pretty print this structure.""" - if indent == 0: - print(f"struct {self.name}") + for key, value in classdict.items(): + setattr(cls, key, value) - for field in self.fields: - if field.offset is None: - offset = "0x??" - else: - offset = f"0x{field.offset:02x}" - print(f"{' ' * indent}+{offset} {field.name} {field.type}") +class Structure(BaseType, metaclass=StructureMetaType): + """Base class for cstruct structure type classes.""" + + _values: dict[str, Any] + _sizes: dict[str, int] - if isinstance(field.type, Structure): - field.type.show(indent + 1) + def __len__(self) -> int: + return len(self.dumps()) + def __bytes__(self) -> bytes: + return self.dumps() -class Union(Structure): - """Type class for unions""" + def __getitem__(self, item: str) -> Any: + return getattr(self, item) def __repr__(self) -> str: - return f"" + values = [ + f"{k}={hex(self[k]) if (issubclass(f.type, int) and not issubclass(f.type, Pointer)) else repr(self[k])}" + for k, f in self.__class__.fields.items() + ] + return f"<{self.__class__.__name__} {' '.join(values)}>" + - def _calc_size_and_offsets(self) -> None: +class UnionMetaType(StructureMetaType): + """Base metaclass for cstruct union type classes.""" + + def __call__(cls, *args, **kwargs) -> Union: + obj = super().__call__(*args, **kwargs) + if kwargs: + # Calling with kwargs means we are initializing with values + # Proxify all values + obj._proxify() + return obj + + def _calculate_size_and_offsets(cls, fields: list[Field], align: bool = False) -> tuple[int | None, int]: size = 0 alignment = 0 - for field in self.fields: - if field.alignment is None: - # If a field already has an alignment, it's leading - field.alignment = field.type.alignment + for field in fields: + if size is not None: + try: + size = max(len(field.type), size) + except TypeError: + size = None - size = max(len(field.type), size) alignment = max(field.alignment, alignment) - if self.align and size is not None: + if align and size is not None: # Add "tail padding" if we need to align # This bit magic rounds up to the next alignment boundary # E.g. offset = 3; alignment = 8; -offset & (alignment - 1) = 5 size += -size & (alignment - 1) - self.size = size - self.alignment = alignment + return size, alignment - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> Instance: - buf = io.BytesIO(memoryview(stream.read(len(self)))) - result = OrderedDict() + def _read_fields(cls, stream: BinaryIO, context: dict[str, Any] = None) -> tuple[dict[str, Any], dict[str, int]]: + result = {} sizes = {} - for field in self.fields: - start = 0 - buf.seek(0) - field_type = self.cstruct.resolve(field.type) + if cls.size is None: + offset = stream.tell() + buf = stream + else: + offset = 0 + buf = io.BytesIO(stream.read(cls.size)) + + for field in cls.__fields__: + field_type = cls.cs.resolve(field.type) - if field.offset: - buf.seek(field.offset) + start = 0 + if field.offset is not None: start = field.offset - v = field_type._read(buf, result) + buf.seek(offset + start) + value = field_type._read(buf, result) - if isinstance(field_type, Structure) and field_type.anonymous: - sizes.update(v._sizes) - result.update(v._values) - else: - sizes[field.name] = buf.tell() - start - result[field.name] = v + sizes[field.name] = buf.tell() - start + result[field.name] = value - return Instance(self, result, sizes) + return result, sizes - def _write(self, stream: BinaryIO, data: Instance) -> Instance: + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Union: + if cls.size is None: + start = stream.tell() + result, sizes = cls._read_fields(stream, context) + size = stream.tell() - start + stream.seek(start) + buf = stream.read(size) + else: + result = {} + sizes = {} + buf = stream.read(cls.size) + + obj: Union = cls(**result) + object.__setattr__(obj, "_values", result) + object.__setattr__(obj, "_sizes", sizes) + object.__setattr__(obj, "_buf", buf) + + if cls.size is not None: + obj._update() + + return obj + + def _write(cls, stream: BinaryIO, data: Union) -> int: offset = stream.tell() + expected_offset = offset + len(cls) - # Find the largest field - field = max(self.fields, key=lambda e: len(e.type)) + # Sort by largest field + fields = sorted(cls.__fields__, key=lambda e: len(e.type), reverse=True) + anonymous_struct = False - # Write the value to the stream using the largest field type - if isinstance(field.type, Structure) and field.type.anonymous: - field.type._write(stream, data) - else: - field.type._write(stream, getattr(data, field.name)) + # Try to write by largest field + for field in fields: + if isinstance(field.type, StructureMetaType) and field.type.__anonymous__: + # Prefer to write regular fields initially + anonymous_struct = field.type + continue + + # Skip empty values + if (value := getattr(data, field.name)) is None: + continue + + # We have a value, write it + field.type._write(stream, value) + break + + # If we haven't written anything yet and we initially skipped an anonymous struct, write it now + if stream.tell() == offset and anonymous_struct: + anonymous_struct._write(stream, data) + + # If we haven't filled the union size yet, pad it + if remaining := expected_offset - stream.tell(): + stream.write(b"\x00" * remaining) return stream.tell() - offset - def show(self, indent: int = 0) -> None: - raise NotImplementedError() + +class Union(Structure, metaclass=UnionMetaType): + """Base class for cstruct union type classes.""" + + _buf: bytes + + def __eq__(self, other: Any) -> bool: + return self.__class__ is other.__class__ and bytes(self) == bytes(other) + + def __setattr__(self, attr: str, value: Any) -> None: + if self.__class__.dynamic: + raise NotImplementedError("Modifying a dynamic union is not yet supported") + + super().__setattr__(attr, value) + self._rebuild(attr) + + def _rebuild(self, attr: str) -> None: + if (cur_buf := getattr(self, "_buf", None)) is None: + cur_buf = b"\x00" * self.__class__.size + + buf = io.BytesIO(cur_buf) + field = self.__class__.fields[attr] + if field.offset: + buf.seek(field.offset) + field.type._write(buf, getattr(self, attr)) + + object.__setattr__(self, "_buf", buf.getvalue()) + self._update() + + def _update(self) -> None: + result, sizes = self.__class__._read_fields(io.BytesIO(self._buf)) + self.__dict__.update(result) + object.__setattr__(self, "_values", result) + object.__setattr__(self, "_sizes", sizes) + + def _proxify(self) -> None: + def _proxy_structure(value: Structure) -> None: + for field in value.__class__.__fields__: + if issubclass(field.type, Structure): + nested_value = getattr(value, field.name) + proxy = UnionProxy(self, field.name, nested_value) + object.__setattr__(value, field.name, proxy) + _proxy_structure(nested_value) + + _proxy_structure(self) + + +class UnionProxy: + __union__: Union + __attr__: str + __target__: Structure + + def __init__(self, union: Union, attr: str, target: Structure): + object.__setattr__(self, "__union__", union) + object.__setattr__(self, "__attr__", attr) + object.__setattr__(self, "__target__", target) + + def __len__(self) -> int: + return len(self.__target__.dumps()) + + def __bytes__(self) -> bytes: + return self.__target__.dumps() + + def __getitem__(self, item: str) -> Any: + return getattr(self.__target__, item) + + def __repr__(self) -> str: + return repr(self.__target__) + + def __getattr__(self, attr: str) -> Any: + return getattr(self.__target__, attr) + + def __setattr__(self, attr: str, value: Any) -> None: + setattr(self.__target__, attr, value) + self.__union__._rebuild(self.__attr__) + + +def attrsetter(path: str) -> Callable[[Any], Any]: + path, _, attr = path.rpartition(".") + path = path.split(".") + + def _func(obj: Any, value: Any) -> Any: + for name in path: + obj = getattr(obj, name) + setattr(obj, attr, value) + + return _func + + +def _codegen(func: FunctionType) -> FunctionType: + # Inspired by https://github.com/dabeaz/dataklasses + @lru_cache + def make_func_code(num_fields: int) -> FunctionType: + names = [f"_{n}" for n in range(num_fields)] + exec(func(names), {}, d := {}) + return d.popitem()[1] + + return make_func_code + + +def _patch_args_and_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: + return type(func)( + func.__code__.replace( + co_names=(*func.__code__.co_names[:start], *fields), + co_varnames=("self", *fields), + ), + func.__globals__, + argdefs=func.__defaults__, + ) + + +def _patch_setattr_args_and_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: + return type(func)( + func.__code__.replace( + co_consts=(None, *fields), + co_varnames=("self", *fields), + ), + func.__globals__, + argdefs=func.__defaults__, + ) + + +def _patch_attributes(func: FunctionType, fields: list[str], start: int = 0) -> FunctionType: + return type(func)( + func.__code__.replace(co_names=(*func.__code__.co_names[:start], *fields)), + func.__globals__, + ) + + +@_codegen +def _make__init__(fields: list[str]) -> str: + field_args = ", ".join(f"{field} = None" for field in fields) + field_init = "\n".join(f" self.{name} = {name}" for name in fields) + + code = f"def __init__(self{', ' + field_args if field_args else ''}):\n" + return code + (field_init or " pass") + + +@_codegen +def _make_setattr__init__(fields: list[str]) -> str: + field_args = ", ".join(f"{field} = None" for field in fields) + field_init = "\n".join(f" object.__setattr__(self, {name!r}, {name})" for name in fields) + + code = f"def __init__(self{', ' + field_args if field_args else ''}):\n" + return code + (field_init or " pass") + + +@_codegen +def _make__eq__(fields: list[str]) -> str: + self_vals = ",".join(f"self.{name}" for name in fields) + other_vals = ",".join(f"other.{name}" for name in fields) + + if self_vals: + self_vals += "," + if other_vals: + other_vals += "," + + # In the future this could be a looser check, e.g. an __eq__ on the classes, which compares the fields + code = f""" + def __eq__(self, other): + if self.__class__ is other.__class__: + return ({self_vals}) == ({other_vals}) + return False + """ + + return dedent(code) + + +@_codegen +def _make__bool__(fields: list[str]) -> str: + vals = ", ".join(f"self.{name}" for name in fields) + + code = f""" + def __bool__(self): + return any([{vals}]) + """ + + return dedent(code) diff --git a/dissect/cstruct/types/void.py b/dissect/cstruct/types/void.py new file mode 100644 index 0000000..09d5d8b --- /dev/null +++ b/dissect/cstruct/types/void.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Any, BinaryIO + +from dissect.cstruct.types.base import BaseType + + +class Void(BaseType): + """Void type.""" + + def __bool__(self) -> bool: + return False + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Void: + return cls.__new__(cls) + + @classmethod + def _write(cls, stream: BinaryIO, data: Void) -> int: + return 0 diff --git a/dissect/cstruct/types/voidtype.py b/dissect/cstruct/types/voidtype.py deleted file mode 100644 index 89f6688..0000000 --- a/dissect/cstruct/types/voidtype.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import Any, BinaryIO - -from dissect.cstruct.types import RawType - - -class VoidType(RawType): - """Implements a void type.""" - - def __init__(self): - super().__init__(None, "void") - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> None: - return None diff --git a/dissect/cstruct/types/wchar.py b/dissect/cstruct/types/wchar.py new file mode 100644 index 0000000..8799b8b --- /dev/null +++ b/dissect/cstruct/types/wchar.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import sys +from typing import Any, BinaryIO + +from dissect.cstruct.types.base import EOF, ArrayMetaType, BaseType + + +class WcharArray(str, BaseType, metaclass=ArrayMetaType): + """Wide-character array type for reading and writing UTF-16 strings.""" + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> WcharArray: + return type.__call__(cls, ArrayMetaType._read(cls, stream, context)) + + @classmethod + def _write(cls, stream: BinaryIO, data: str) -> int: + if cls.null_terminated: + data += "\x00" + return stream.write(data.encode(Wchar.__encoding_map__[cls.cs.endian])) + + @classmethod + def default(cls) -> WcharArray: + return type.__call__(cls, "\x00" * (0 if cls.dynamic or cls.null_terminated else cls.num_entries)) + + +class Wchar(str, BaseType): + """Wide-character type for reading and writing UTF-16 characters.""" + + ArrayType = WcharArray + + __encoding_map__ = { + "@": f"utf-16-{sys.byteorder[0]}e", + "=": f"utf-16-{sys.byteorder[0]}e", + "<": "utf-16-le", + ">": "utf-16-be", + "!": "utf-16-be", + } + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: + return cls._read_array(stream, 1, context) + + @classmethod + def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Wchar: + if count == 0: + return type.__call__(cls, "") + + if count != EOF: + count *= 2 + + data = stream.read(-1 if count == EOF else count) + if count != EOF and len(data) != count: + raise EOFError(f"Read {len(data)} bytes, but expected {count}") + + return type.__call__(cls, data.decode(cls.__encoding_map__[cls.cs.endian])) + + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: + buf = [] + while True: + point = stream.read(2) + if (bytes_read := len(point)) != 2: + raise EOFError(f"Read {bytes_read} bytes, but expected 2") + + if point == b"\x00\x00": + break + + buf.append(point) + + return type.__call__(cls, b"".join(buf).decode(cls.__encoding_map__[cls.cs.endian])) + + @classmethod + def _write(cls, stream: BinaryIO, data: str) -> int: + return stream.write(data.encode(cls.__encoding_map__[cls.cs.endian])) + + @classmethod + def default(cls) -> Wchar: + return type.__call__(cls, "\x00") diff --git a/dissect/cstruct/types/wchartype.py b/dissect/cstruct/types/wchartype.py deleted file mode 100644 index 47e8230..0000000 --- a/dissect/cstruct/types/wchartype.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import Any, BinaryIO - -from dissect.cstruct.types import RawType - - -class WcharType(RawType): - """Implements a wide-character type.""" - - def __init__(self, cstruct): - super().__init__(cstruct, "wchar", size=2, alignment=2) - - @property - def encoding(self) -> str: - if self.cstruct.endian == "<": # little-endian (LE) - return "utf-16-le" - elif self.cstruct.endian == ">": # big-endian (BE) - return "utf-16-be" - - def _read(self, stream: BinaryIO, context: dict[str, Any] = None) -> str: - return stream.read(2).decode(self.encoding) - - def _read_array(self, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> str: - if count == 0: - return "" - - data = stream.read(2 * count) - return data.decode(self.encoding) - - def _read_0(self, stream: BinaryIO, context: dict[str, Any] = None) -> str: - byte_string = b"" - while True: - bytes_stream = stream.read(2) - - if len(bytes_stream) != 2: - raise EOFError() - - if bytes_stream == b"\x00\x00": - break - - byte_string += bytes_stream - - return byte_string.decode(self.encoding) - - def _write(self, stream: BinaryIO, data: str) -> int: - return stream.write(data.encode(self.encoding)) - - def _write_array(self, stream: BinaryIO, data: str) -> int: - return self._write(stream, data) - - def _write_0(self, stream: BinaryIO, data: str) -> int: - return self._write(stream, data + "\x00") - - def default(self) -> str: - return "\x00" - - def default_array(self, count: int) -> str: - return "\x00" * count diff --git a/dissect/cstruct/utils.py b/dissect/cstruct/utils.py index 05b2acc..e01b628 100644 --- a/dissect/cstruct/utils.py +++ b/dissect/cstruct/utils.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import pprint import string -from typing import List, Tuple +import sys +from typing import Iterator -from dissect.cstruct.types import Instance, Structure +from dissect.cstruct.types.structure import Structure COLOR_RED = "\033[1;31m" COLOR_GREEN = "\033[1;32m" @@ -24,16 +27,18 @@ PRINTABLE = string.digits + string.ascii_letters + string.punctuation + " " ENDIANNESS_MAP = { - "network": "big", + "@": sys.byteorder, + "=": sys.byteorder, "<": "little", ">": "big", "!": "big", + "network": "big", } -Palette = List[Tuple[str, str]] +Palette = list[tuple[str, str]] -def _hexdump(data: bytes, palette: Palette = None, offset: int = 0, prefix: str = ""): +def _hexdump(data: bytes, palette: Palette | None = None, offset: int = 0, prefix: str = "") -> Iterator[str]: """Hexdump some data. Args: @@ -100,7 +105,9 @@ def _hexdump(data: bytes, palette: Palette = None, offset: int = 0, prefix: str yield f"{prefix}{offset + i:08x} {values:48s} {chars}" -def hexdump(data: bytes, palette=None, offset: int = 0, prefix: str = "", output: str = "print"): +def hexdump( + data: bytes, palette: Palette | None = None, offset: int = 0, prefix: str = "", output: str = "print" +) -> Iterator[str] | str | None: """Hexdump some data. Args: @@ -123,12 +130,11 @@ def hexdump(data: bytes, palette=None, offset: int = 0, prefix: str = "", output def _dumpstruct( structure: Structure, - instance: Instance, data: bytes, offset: int, color: bool, output: str, -): +) -> str | None: palette = [] colors = [ (COLOR_RED, COLOR_BG_RED), @@ -140,18 +146,18 @@ def _dumpstruct( (COLOR_WHITE, COLOR_BG_WHITE), ] ci = 0 - out = [f"struct {structure.name}:"] + out = [f"struct {structure.__class__.__name__}:"] foreground, background = None, None - for field in instance._type.lookup.values(): + for field in structure.__class__.__fields__: if getattr(field.type, "anonymous", False): continue if color: foreground, background = colors[ci % len(colors)] - palette.append((instance._size(field.name), background)) + palette.append((structure._sizes[field.name], background)) ci += 1 - value = getattr(instance, field.name) + value = getattr(structure, field.name) if isinstance(value, str): value = repr(value) elif isinstance(value, int): @@ -177,24 +183,30 @@ def _dumpstruct( return "\n".join(["", hexdump(data, palette, offset=offset, output="string"), "", out]) -def dumpstruct(obj, data: bytes = None, offset: int = 0, color: bool = True, output: str = "print"): +def dumpstruct( + obj: Structure | type[Structure], + data: bytes | None = None, + offset: int = 0, + color: bool = True, + output: str = "print", +) -> str | None: """Dump a structure or parsed structure instance. Prints a colorized hexdump and parsed structure output. Args: - obj: Structure or Instance to dump. - data: Bytes to parse the Structure on, if obj is not a parsed Instance. + obj: Structure to dump. + data: Bytes to parse the Structure on, if obj is not a parsed Structure already. offset: Byte offset of the hexdump. output: Output format, can be 'print' or 'string'. """ if output not in ("print", "string"): raise ValueError(f"Invalid output argument: {output!r} (should be 'print' or 'string').") - if isinstance(obj, Instance): - return _dumpstruct(obj._type, obj, obj.dumps(), offset, color, output) - elif isinstance(obj, Structure) and data: - return _dumpstruct(obj, obj(data), data, offset, color, output) + if isinstance(obj, Structure): + return _dumpstruct(obj, obj.dumps(), offset, color, output) + elif issubclass(obj, Structure) and data: + return _dumpstruct(obj(data), data, offset, color, output) else: raise ValueError("Invalid arguments") diff --git a/examples/protobuf.py b/examples/protobuf.py new file mode 100644 index 0000000..84a09cd --- /dev/null +++ b/examples/protobuf.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any, BinaryIO + +from dissect.cstruct import cstruct +from dissect.cstruct.types import BaseType + + +class ProtobufVarint(BaseType): + """Implements a protobuf integer type for dissect.cstruct that can span a variable amount of bytes. + + Mainly follows the BaseType implementation with minor tweaks + to support protobuf's msb varint implementation. + + Resources: + - https://protobuf.dev/programming-guides/encoding/ + - https://github.com/protocolbuffers/protobuf/blob/main/python/google/protobuf/internal/decoder.py + """ + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> int: + return decode_varint(stream) + + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> int: + return stream.write(encode_varint(data)) + + +def decode_varint(stream: BinaryIO) -> int: + """Reads a varint from the provided buffer stream. + + If we have not reached the end of a varint, the msb will be 1. + We read every byte from our current position until the msb is 0. + """ + result = 0 + i = 0 + while True: + byte = stream.read(1) + result |= (byte[0] & 0x7F) << (i * 7) + i += 1 + if byte[0] & 0x80 == 0: + break + + return result + + +def encode_varint(number: int) -> bytes: + """Encode a decoded protobuf varint to its original bytes.""" + buf = [] + while True: + towrite = number & 0x7F + number >>= 7 + if number: + buf.append(towrite | 0x80) + else: + buf.append(towrite) + break + return bytes(buf) + + +if __name__ == "__main__": + cdef = """ + struct foo { + uint32 foo; + varint size; + char bar[size]; + }; + """ + + cs = cstruct(endian=">") + cs.add_custom_type("varint", ProtobufVarint) + cs.load(cdef, compiled=False) + + aaa = b"a" * 123456 + buf = b"\x00\x00\x00\x01\xc0\xc4\x07" + aaa + foo = cs.foo(buf + b"\x01\x02\x03") + assert foo.foo == 1 + assert foo.size == 123456 + assert foo.bar == aaa + assert foo.dumps() == buf + + assert cs.varint[2](b"\x80\x01\x80\x02") == [128, 256] + assert cs.varint[2].dumps([128, 256]) == b"\x80\x01\x80\x02" diff --git a/examples/secdesc.py b/examples/secdesc.py index 5c5e83f..3f610d9 100644 --- a/examples/secdesc.py +++ b/examples/secdesc.py @@ -52,8 +52,8 @@ struct ACCESS_ALLOWED_OBJECT_ACE { uint32 Mask; uint32 Flags; - char ObjectType[Flags & 1 * 16]; - char InheritedObjectType[Flags & 2 * 8]; + char ObjectType[(Flags & 1) * 16]; + char InheritedObjectType[(Flags & 2) * 8]; LDAP_SID Sid; }; """ @@ -97,12 +97,9 @@ def __init__(self, fh=None, in_obj=None): self.ldap_sid = in_obj def __repr__(self): - return "S-{}-{}-{}".format( - self.ldap_sid.Revision, - bytearray(self.ldap_sid.IdentifierAuthority.Value)[5], - "-".join(["{}".format(v) for v in self.ldap_sid.SubAuthority]), - ) - + authority = bytearray(self.ldap_sid.IdentifierAuthority.Value)[5] + sub_authority = "-".join(f"{v}" for v in self.ldap_sid.SubAuthority) + return f"S-{self.ldap_sid.Revision}-{authority}-{sub_authority}" class ACL: def __init__(self, fh): diff --git a/tests/conftest.py b/tests/conftest.py index 3f57405..40e2410 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,12 @@ import pytest +from dissect.cstruct.cstruct import cstruct + + +@pytest.fixture +def cs() -> cstruct: + return cstruct() + @pytest.fixture(params=[True, False]) def compiled(request): diff --git a/tests/test_align.py b/tests/test_align.py index 859a879..820a7ec 100644 --- a/tests/test_align.py +++ b/tests/test_align.py @@ -1,10 +1,11 @@ from io import BytesIO -from dissect import cstruct +from dissect.cstruct import cstruct +from tests.utils import verify_compiled -def test_align_struct(): - d = """ +def test_align_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint32 a; // 0x00 uint64 b; // 0x08 @@ -14,13 +15,14 @@ def test_align_struct(): uint16 f; // 0x1a }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) - fields = c.test.fields - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 32 + assert verify_compiled(cs.test, compiled) + + fields = cs.test.__fields__ + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 32 assert fields[0].offset == 0x00 assert fields[1].offset == 0x08 assert fields[2].offset == 0x10 @@ -35,28 +37,27 @@ def test_align_struct(): buf = bytes.fromhex(buf) fh = BytesIO(buf) - obj = c.test(fh) + obj = cs.test(fh) assert fh.tell() == 32 for name, value in obj._values.items(): - assert c.test.lookup[name].offset == value + assert cs.test.fields[name].offset == value assert obj.dumps() == buf -def test_align_union(): - d = """ +def test_align_union(cs: cstruct) -> None: + cdef = """ union test { uint32 a; uint64 b; }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, align=True) - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 8 + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 8 buf = """ 00 00 00 01 00 00 00 02 @@ -64,7 +65,7 @@ def test_align_union(): buf = bytes.fromhex(buf) fh = BytesIO(buf) - obj = c.test(fh) + obj = cs.test(fh) assert fh.tell() == 8 assert obj.a == 0x01000000 assert obj.b == 0x0200000001000000 @@ -72,23 +73,22 @@ def test_align_union(): assert obj.dumps() == buf -def test_align_union_tail(): - d = """ +def test_align_union_tail(cs: cstruct) -> None: + cdef = """ union test { uint64 a; uint32 b[3]; }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, align=True) - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 16 + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 16 -def test_align_array(): - d = """ +def test_align_array(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint32 a; // 0x00 uint64 b[4]; // 0x08 @@ -97,13 +97,14 @@ def test_align_array(): uint64 e; // 0x38 }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) + + assert verify_compiled(cs.test, compiled) - fields = c.test.fields - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 64 + fields = cs.test.__fields__ + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 64 assert fields[0].offset == 0x00 assert fields[1].offset == 0x08 assert fields[2].offset == 0x28 @@ -118,7 +119,7 @@ def test_align_array(): """ buf = bytes.fromhex(buf) - obj = c.test(buf) + obj = cs.test(buf) assert obj.a == 0x00 assert obj.b == [0x08, 0x10, 0x18, 0x20] @@ -129,8 +130,8 @@ def test_align_array(): assert obj.dumps() == buf -def test_align_struct_array(): - d = """ +def test_align_struct_array(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint32 a; // 0x00 uint64 b; // 0x08 @@ -140,13 +141,15 @@ def test_align_struct_array(): test a[4]; }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) + + assert verify_compiled(cs.test, compiled) + assert verify_compiled(cs.array, compiled) - fields = c.test.fields - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 16 + fields = cs.test.__fields__ + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 16 assert fields[0].offset == 0x00 assert fields[1].offset == 0x08 @@ -158,7 +161,7 @@ def test_align_struct_array(): """ buf = bytes.fromhex(buf) - obj = c.array(buf) + obj = cs.array(buf) assert obj.a[0].a == 0x00 assert obj.a[0].b == 0x08 @@ -172,8 +175,8 @@ def test_align_struct_array(): assert obj.dumps() == buf -def test_align_dynamic(): - d = """ +def test_align_dynamic(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint8 a; // 0x00 (value is 6 in test case) uint16 b[a]; // 0x02 @@ -184,10 +187,11 @@ def test_align_dynamic(): uint64 g; // 0x?? (0x30 in test case) }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) - fields = c.test.fields + assert verify_compiled(cs.test, compiled) + + fields = cs.test.__fields__ assert fields[0].offset == 0 assert fields[1].offset == 2 assert fields[2].offset is None @@ -203,7 +207,7 @@ def test_align_dynamic(): 30 00 00 00 00 00 00 00 """ buf = bytes.fromhex(buf) - obj = c.test(buf) + obj = cs.test(buf) assert obj.a == 0x06 assert obj.b == [0x02, 0x04, 0x06, 0x08, 0x0A, 0x0C] @@ -216,8 +220,8 @@ def test_align_dynamic(): assert obj.dumps() == buf -def test_align_nested_struct(): - d = """ +def test_align_nested_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint32 a; // 0x00 struct { @@ -227,10 +231,11 @@ def test_align_nested_struct(): uint64 d; // 0x18 }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) + + assert verify_compiled(cs.test, compiled) - fields = c.test.fields + fields = cs.test.__fields__ assert fields[0].offset == 0x00 assert fields[1].offset == 0x08 assert fields[2].offset == 0x18 @@ -240,7 +245,7 @@ def test_align_nested_struct(): 10 00 00 00 00 00 00 00 18 00 00 00 00 00 00 00 """ buf = bytes.fromhex(buf) - obj = c.test(buf) + obj = cs.test(buf) assert obj.a == 0x00 assert obj.nested.b == 0x08 @@ -250,8 +255,8 @@ def test_align_nested_struct(): assert obj.dumps() == buf -def test_align_bitfield(): - d = """ +def test_align_bitfield(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint16 a:4; // 0x00 uint16 b:4; @@ -262,10 +267,11 @@ def test_align_bitfield(): uint64 g; // 0x18 }; """ - c = cstruct.cstruct() - c.load(d, align=True) + cs.load(cdef, compiled=compiled, align=True) + + assert verify_compiled(cs.test, compiled) - fields = c.test.fields + fields = cs.test.__fields__ assert fields[0].offset == 0x00 assert fields[1].offset is None assert fields[2].offset == 0x08 @@ -279,7 +285,7 @@ def test_align_bitfield(): 10 00 00 00 02 00 00 00 18 00 00 00 00 00 00 00 """ buf = bytes.fromhex(buf) - obj = c.test(buf) + obj = cs.test(buf) assert obj.a == 0b10 assert obj.b == 0b01 @@ -292,8 +298,8 @@ def test_align_bitfield(): assert obj.dumps() == buf -def test_align_pointer(): - d = """ +def test_align_pointer(cs: cstruct, compiled: bool) -> None: + cdef = """ struct test { uint32 a; uint32 *b; @@ -301,15 +307,15 @@ def test_align_pointer(): uint16 d; }; """ - c = cstruct.cstruct(pointer="uint64") - c.load(d, align=True) + cs.pointer = cs.uint64 + cs.load(cdef, compiled=compiled, align=True) - assert c.pointer is c.uint64 + assert verify_compiled(cs.test, compiled) - fields = c.test.fields - assert c.test.align - assert c.test.alignment == 8 - assert c.test.size == 24 + fields = cs.test.__fields__ + assert cs.test.__align__ + assert cs.test.alignment == 8 + assert cs.test.size == 24 assert fields[0].offset == 0x00 assert fields[1].offset == 0x08 assert fields[2].offset == 0x10 @@ -320,7 +326,7 @@ def test_align_pointer(): 10 00 12 00 00 00 00 00 18 00 00 00 """ buf = bytes.fromhex(buf) - obj = c.test(buf) + obj = cs.test(buf) assert obj.a == 0x00 assert obj.b.dereference() == 0x18 diff --git a/tests/test_basic.py b/tests/test_basic.py index 43ad44e..fd3dfca 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,109 +1,64 @@ +from __future__ import annotations + import os +from io import BytesIO +from typing import BinaryIO import pytest -from dissect import cstruct +from dissect.cstruct.cstruct import cstruct from dissect.cstruct.exceptions import ArraySizeError, ParserError, ResolveError -from dissect.cstruct.types import Array, Pointer +from dissect.cstruct.types import BaseType from .utils import verify_compiled -def test_simple_types(): - cs = cstruct.cstruct() - assert cs.uint32(b"\x01\x00\x00\x00") == 1 - assert cs.uint32[10](b"A" * 20 + b"B" * 20) == [0x41414141] * 5 + [0x42424242] * 5 - assert cs.uint32[None](b"A" * 20 + b"\x00" * 4) == [0x41414141] * 5 - - with pytest.raises(EOFError): - cs.char[None](b"aaa") - - with pytest.raises(EOFError): - cs.wchar[None](b"a\x00a\x00a") - - -def test_write(): - cs = cstruct.cstruct() - - assert cs.uint32.dumps(1) == b"\x01\x00\x00\x00" - assert cs.uint16.dumps(255) == b"\xff\x00" - assert cs.int8.dumps(-10) == b"\xf6" - assert cs.uint8[4].dumps([1, 2, 3, 4]) == b"\x01\x02\x03\x04" - assert cs.uint24.dumps(300) == b"\x2c\x01\x00" - assert cs.int24.dumps(-1337) == b"\xc7\xfa\xff" - assert cs.uint24[4].dumps([1, 2, 3, 4]) == b"\x01\x00\x00\x02\x00\x00\x03\x00\x00\x04\x00\x00" - assert cs.uint24[None].dumps([1, 2]) == b"\x01\x00\x00\x02\x00\x00\x00\x00\x00" - assert cs.char.dumps(0x61) == b"a" - assert cs.wchar.dumps("lala") == b"l\x00a\x00l\x00a\x00" - assert cs.uint32[None].dumps([1]) == b"\x01\x00\x00\x00\x00\x00\x00\x00" - - -def test_write_be(): - cs = cstruct.cstruct(endian=">") - - assert cs.uint32.dumps(1) == b"\x00\x00\x00\x01" - assert cs.uint16.dumps(255) == b"\x00\xff" - assert cs.int8.dumps(-10) == b"\xf6" - assert cs.uint8[4].dumps([1, 2, 3, 4]) == b"\x01\x02\x03\x04" - assert cs.uint24.dumps(300) == b"\x00\x01\x2c" - assert cs.int24.dumps(-1337) == b"\xff\xfa\xc7" - assert cs.uint24[4].dumps([1, 2, 3, 4]) == b"\x00\x00\x01\x00\x00\x02\x00\x00\x03\x00\x00\x04" - assert cs.char.dumps(0x61) == b"a" - assert cs.wchar.dumps("lala") == b"\x00l\x00a\x00l\x00a" - - -def test_duplicate_type(compiled): +def test_duplicate_type(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint32 a; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) with pytest.raises(ValueError): cs.load(cdef) -def test_load_file(compiled): +def test_load_file(cs: cstruct, compiled: bool) -> None: path = os.path.join(os.path.dirname(__file__), "data/testdef.txt") - cs = cstruct.cstruct() cs.loadfile(path, compiled=compiled) assert "test" in cs.typedefs -def test_read_type_name(): - cs = cstruct.cstruct() +def test_read_type_name(cs: cstruct) -> None: cs.read("uint32", b"\x01\x00\x00\x00") == 1 -def test_type_resolve(): - cs = cstruct.cstruct() - +def test_type_resolve(cs: cstruct) -> None: assert cs.resolve("BYTE") == cs.uint8 - with pytest.raises(cstruct.ResolveError) as excinfo: + with pytest.raises(ResolveError) as excinfo: cs.resolve("fake") assert "Unknown type" in str(excinfo.value) - cs.addtype("ref0", "uint32") + cs.add_type("ref0", "uint32") for i in range(1, 15): # Recursion limit is currently 10 - cs.addtype(f"ref{i}", f"ref{i - 1}") + cs.add_type(f"ref{i}", f"ref{i - 1}") - with pytest.raises(cstruct.ResolveError) as excinfo: + with pytest.raises(ResolveError) as excinfo: cs.resolve("ref14") assert "Recursion limit exceeded" in str(excinfo.value) -def test_constants(): +def test_constants(cs: cstruct) -> None: cdef = """ #define a 1 #define b 0x2 #define c "test" #define d 1 << 1 """ - cs = cstruct.cstruct() cs.load(cdef) assert cs.a == 1 @@ -112,8 +67,7 @@ def test_constants(): assert cs.d == 2 -def test_duplicate_types(): - cs = cstruct.cstruct() +def test_duplicate_types(cs: cstruct) -> None: cdef = """ struct A { uint32 a; @@ -137,29 +91,27 @@ def test_duplicate_types(): assert "Duplicate type" in str(excinfo.value) -def test_typedef(): +def test_typedef(cs: cstruct) -> None: cdef = """ typedef uint32 test; """ - cs = cstruct.cstruct() cs.load(cdef) assert cs.test == cs.uint32 assert cs.resolve("test") == cs.uint32 -def test_lookups(compiled): +def test_lookups(cs: cstruct, compiled: bool) -> None: cdef = """ #define test_1 1 #define test_2 2 $a = {'test_1': 3, 'test_2': 4} """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert cs.lookups["a"] == {1: 3, 2: 4} -def test_default_constructors(compiled): +def test_default_constructors(cs: cstruct, compiled: bool) -> None: cdef = """ enum Enum { a = 0, @@ -186,7 +138,6 @@ def test_default_constructors(compiled): Flag t_flag_array[2]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -207,8 +158,11 @@ def test_default_constructors(compiled): assert obj.dumps() == b"\x00" * 54 + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + -def test_default_constructors_dynamic(compiled): +def test_default_constructors_dynamic(cs: cstruct, compiled: bool) -> None: cdef = """ enum Enum { a = 0, @@ -234,10 +188,12 @@ def test_default_constructors_dynamic(compiled): Flag t_flag_array_d[x]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) + assert verify_compiled(cs.test, compiled) + obj = cs.test() + assert obj.t_int_array_n == obj.t_int_array_d == [] assert obj.t_bytesint_array_n == obj.t_bytesint_array_d == [] assert obj.t_char_array_n == obj.t_char_array_d == b"" @@ -246,8 +202,11 @@ def test_default_constructors_dynamic(compiled): assert obj.t_flag_array_n == obj.t_flag_array_d == [] assert obj.dumps() == b"\x00" * 19 + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) -def test_config_flag_nocompile(compiled): + +def test_config_flag_nocompile(cs: cstruct, compiled: bool) -> None: cdef = """ struct compiled_global { @@ -260,21 +219,19 @@ def test_config_flag_nocompile(compiled): uint32 a; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.compiled_global, compiled) assert verify_compiled(cs.never_compiled, False) -def test_compiler_slicing_multiple(compiled): +def test_compiler_slicing_multiple(cs: cstruct, compiled: bool) -> None: cdef = """ struct compile_slicing { char single; char multiple[2]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.compile_slicing, compiled) @@ -284,13 +241,12 @@ def test_compiler_slicing_multiple(compiled): assert obj.multiple == b"\x02\x03" -def test_underscores_attribute(compiled): +def test_underscores_attribute(cs: cstruct, compiled: bool) -> None: cdef = """ struct __test { uint32 test_val; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.__test, compiled) @@ -300,28 +256,24 @@ def test_underscores_attribute(compiled): assert obj.test_val == 1337 -def test_half_compiled_struct(): - from dissect.cstruct import RawType - - class OffByOne(RawType): - def __init__(self, cstruct_obj): - self._t = cstruct_obj.uint64 - super().__init__(cstruct_obj, "OffByOne", 8) - - def _read(self, stream, context=None): - return self._t._read(stream, context) + 1 +def test_half_compiled_struct(cs: cstruct) -> None: + class OffByOne(int, BaseType): + type: BaseType - def _write(self, stream, data): - return self._t._write(stream, data - 1) + @classmethod + def _read(cls, stream: BinaryIO, context: dict | None = None) -> OffByOne: + return cls(cls.type._read(stream, context) + 1) - def default(self): - return 0 + @classmethod + def _write(cls, stream: BinaryIO, data: int) -> OffByOne: + return cls(cls.type._write(stream, data - 1)) - cs = cstruct.cstruct() # Add an unsupported type for the cstruct compiler # so that it returns the original struct, # only partially compiling the struct. - cs.addtype("offbyone", OffByOne(cs)) + offbyone = cs._make_type("offbyone", (OffByOne,), 8, attrs={"type": cs.uint64}) + cs.add_type("offbyone", offbyone) + cdef = """ struct uncompiled { uint32 a; @@ -351,20 +303,19 @@ def default(self): assert obj.dumps() == buf -def test_cstruct_bytearray(): +def test_cstruct_bytearray(cs: cstruct) -> None: cdef = """ struct test { uint8 a; }; """ - cs = cstruct.cstruct() cs.load(cdef) obj = cs.test(bytearray([10])) assert obj.a == 10 -def test_multipart_type_name(): +def test_multipart_type_name(cs: cstruct) -> None: cdef = """ enum TestEnum : unsigned int { A = 0, @@ -376,20 +327,18 @@ def test_multipart_type_name(): unsigned long long b; }; """ - cs = cstruct.cstruct() cs.load(cdef) assert cs.TestEnum.type == cs.resolve("unsigned int") - assert cs.test.fields[0].type == cs.resolve("unsigned int") - assert cs.test.fields[1].type == cs.resolve("unsigned long long") + assert cs.test.__fields__[0].type == cs.resolve("unsigned int") + assert cs.test.__fields__[1].type == cs.resolve("unsigned long long") with pytest.raises(ResolveError) as exc: cdef = """ - struct test { + struct test1 { unsigned long long unsigned a; }; """ - cs = cstruct.cstruct() cs.load(cdef) with pytest.raises(ResolveError) as exc: @@ -399,20 +348,19 @@ def test_multipart_type_name(): B = 1 }; """ - cs = cstruct.cstruct() cs.load(cdef) assert str(exc.value) == "Unknown type unsigned int and more" -def test_dunder_bytes(): +def test_dunder_bytes(cs: cstruct) -> None: cdef = """ struct test { DWORD a; QWORD b; }; """ - cs = cstruct.cstruct(endian=">") + cs.endian = ">" cs.load(cdef) a = cs.test(a=0xBADC0DE, b=0xACCE55ED) @@ -421,15 +369,13 @@ def test_dunder_bytes(): assert bytes(a) == b"\x0b\xad\xc0\xde\x00\x00\x00\x00\xac\xce\x55\xed" -@pytest.mark.parametrize("compiled", [True, False]) -def test_array_of_null_terminated_strings(compiled): +def test_array_of_null_terminated_strings(cs: cstruct, compiled: bool) -> None: cdef = """ struct args { uint32 argc; char argv[argc][]; } """ - cs = cstruct.cstruct(endian="<") cs.load(cdef, compiled=compiled) assert verify_compiled(cs.args, compiled) @@ -443,7 +389,7 @@ def test_array_of_null_terminated_strings(compiled): with pytest.raises(ParserError) as exc: cdef = """ - struct args { + struct args2 { uint32 argc; char argv[][argc]; } @@ -453,15 +399,13 @@ def test_array_of_null_terminated_strings(compiled): assert str(exc.value) == "Depth required for multi-dimensional array" -@pytest.mark.parametrize("compiled", [True, False]) -def test_array_of_size_limited_strings(compiled): +def test_array_of_size_limited_strings(cs: cstruct, compiled: bool) -> None: cdef = """ struct args { uint32 argc; char argv[argc][8]; } """ - cs = cstruct.cstruct(endian="<") cs.load(cdef, compiled=compiled) assert verify_compiled(cs.args, compiled) @@ -476,14 +420,12 @@ def test_array_of_size_limited_strings(compiled): assert obj.argv[3] == b"sit amet" -@pytest.mark.parametrize("compiled", [True, False]) -def test_array_three_dimensional(compiled): +def test_array_three_dimensional(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint8 a[2][2][2]; } """ - cs = cstruct.cstruct(endian="<") cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -503,8 +445,7 @@ def test_array_three_dimensional(compiled): assert obj.dumps() == buf -@pytest.mark.parametrize("compiled", [True, False]) -def test_nested_array_of_variable_size(compiled: bool): +def test_nested_array_of_variable_size(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint8 outer; @@ -513,7 +454,6 @@ def test_nested_array_of_variable_size(compiled: bool): uint8 a[outer][medior][inner]; } """ - cs = cstruct.cstruct(endian="<") cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -534,13 +474,12 @@ def test_nested_array_of_variable_size(compiled: bool): assert obj.dumps() == buf -def test_report_array_size_mismatch(): +def test_report_array_size_mismatch(cs: cstruct) -> None: cdef = """ struct test { uint8 a[2]; }; """ - cs = cstruct.cstruct(endian=">") cs.load(cdef) a = cs.test(a=[1, 2, 3]) @@ -549,8 +488,7 @@ def test_report_array_size_mismatch(): a.dumps() -@pytest.mark.parametrize("compiled", [True, False]) -def test_reserved_keyword(compiled: bool): +def test_reserved_keyword(cs: cstruct, compiled: bool) -> None: cdef = """ struct in { uint8 a; @@ -564,7 +502,6 @@ def test_reserved_keyword(compiled: bool): uint8 a; }; """ - cs = cstruct.cstruct(endian="<") cs.load(cdef, compiled=compiled) for name in ["in", "class", "for"]: @@ -574,29 +511,74 @@ def test_reserved_keyword(compiled: bool): assert cs.resolve(name)(b"\x01").a == 1 -def test_typedef_types(): +def test_array_class_name(cs: cstruct) -> None: cdef = """ - typedef char uuid_t[16]; - typedef uint32 *ptr; - struct test { - uuid_t uuid; - ptr ptr; + uint8 a[2]; + }; + + struct test2 { + uint8 a; + uint8 b[a + 1]; }; """ - cs = cstruct.cstruct(pointer="uint8") cs.load(cdef) - assert isinstance(cs.uuid_t, Array) - assert cs.uuid_t(b"\x01" * 16) == b"\x01" * 16 + assert cs.test.__fields__[0].type.__name__ == "uint8[2]" + assert cs.test2.__fields__[1].type.__name__ == "uint8[a + 1]" + + +def test_size_and_aligment(cs: cstruct) -> None: + test = cs._make_int_type("test", 1, False, alignment=8) + assert test.size == 1 + assert test.alignment == 8 + + test = cs._make_packed_type("test", "B", int, alignment=8) + assert test.size == 1 + assert test.alignment == 8 + + +def test_dynamic_substruct_size(cs: cstruct) -> None: + cdef = """ + struct { + int32 len; + char str[len]; + } sub; + + struct { + sub data[1]; + } test; + """ + cs.load(cdef) + + assert cs.sub.dynamic + assert cs.test.dynamic + + +def test_dumps_write_overload(cs: cstruct) -> None: + assert cs.uint8.dumps(1) == cs.uint8(1).dumps() == b"\x01" + + fh = BytesIO() + cs.uint8.write(fh, 1) + assert fh.getvalue() == b"\x01" + cs.uint8(2).write(fh) + assert fh.getvalue() == b"\x01\x02" + + +def test_linked_list(cs: cstruct) -> None: + cdef = """ + struct node { + uint16 data; + node* next; + }; + """ + cs.pointer = cs.uint16 + cs.load(cdef) - assert isinstance(cs.ptr, Pointer) - assert cs.ptr(b"\x01AAAA") == 1 - assert cs.ptr(b"\x01AAAA").dereference() == 0x41414141 + assert cs.node.__fields__[1].type.type == cs.node - obj = cs.test(b"\x01" * 16 + b"\x11AAAA") - assert obj.uuid == b"\x01" * 16 - assert obj.ptr.dereference() == 0x41414141 + obj = cs.node(b"\x01\x00\x04\x00\x02\x00\x00\x00") + assert repr(obj) == ">" - with pytest.raises(ParserError, match="line 1: typedefs cannot have bitfields"): - cs.load("""typedef uint8 with_bits : 4;""") + assert obj.data == 1 + assert obj.next.data == 2 diff --git a/tests/test_bitbuffer.py b/tests/test_bitbuffer.py index 88db91f..e3e556d 100644 --- a/tests/test_bitbuffer.py +++ b/tests/test_bitbuffer.py @@ -1,14 +1,12 @@ from io import BytesIO import pytest -from dissect import cstruct from dissect.cstruct.bitbuffer import BitBuffer +from dissect.cstruct.cstruct import cstruct -def test_bitbuffer_read(): - cs = cstruct.cstruct() - +def test_bitbuffer_read(cs: cstruct) -> None: bb = BitBuffer(BytesIO(b"\xff"), "<") assert bb.read(cs.uint8, 8) == 0b11111111 diff --git a/tests/test_bitfield.py b/tests/test_bitfield.py index 4590b12..8013d5d 100644 --- a/tests/test_bitfield.py +++ b/tests/test_bitfield.py @@ -1,10 +1,11 @@ import pytest -from dissect import cstruct + +from dissect.cstruct.cstruct import cstruct from .utils import verify_compiled -def test_bitfield(compiled): +def test_bitfield(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:4; @@ -17,7 +18,6 @@ def test_bitfield(compiled): uint32 h; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -37,7 +37,7 @@ def test_bitfield(compiled): assert obj.dumps() == buf -def test_bitfield_consecutive(compiled): +def test_bitfield_consecutive(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:4; @@ -52,7 +52,6 @@ def test_bitfield_consecutive(compiled): uint16 _pad2; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -72,7 +71,7 @@ def test_bitfield_consecutive(compiled): assert obj.dumps() == buf -def test_struct_after_bitfield(compiled): +def test_struct_after_bitfield(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:4; @@ -89,7 +88,6 @@ def test_struct_after_bitfield(compiled): uint32 h; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -109,7 +107,7 @@ def test_struct_after_bitfield(compiled): assert obj.dumps() == buf -def test_bitfield_be(compiled): +def test_bitfield_be(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:4; @@ -123,7 +121,7 @@ def test_bitfield_be(compiled): uint32 i; }; """ - cs = cstruct.cstruct(endian=">") + cs.endian = ">" cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -144,7 +142,7 @@ def test_bitfield_be(compiled): assert obj.dumps() == buf -def test_bitfield_straddle(compiled): +def test_bitfield_straddle(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:12; @@ -156,7 +154,6 @@ def test_bitfield_straddle(compiled): uint32 g; }; """ - cs = cstruct.cstruct() with pytest.raises(ValueError) as exc: cs.load(cdef, compiled=compiled) @@ -164,7 +161,7 @@ def test_bitfield_straddle(compiled): assert str(exc.value) == "Straddled bit fields are unsupported" -def test_bitfield_write(compiled): +def test_bitfield_write(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:1; @@ -174,7 +171,6 @@ def test_bitfield_write(compiled): uint16 e:3; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) obj = cs.test() @@ -187,7 +183,7 @@ def test_bitfield_write(compiled): assert obj.dumps() == b"\x03\x00\xff\x00\x00\x00\x1f\x00" -def test_bitfield_write_be(compiled): +def test_bitfield_write_be(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a:1; @@ -197,7 +193,7 @@ def test_bitfield_write_be(compiled): uint16 e:3; }; """ - cs = cstruct.cstruct(endian=">") + cs.endian = ">" cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -212,7 +208,7 @@ def test_bitfield_write_be(compiled): assert obj.dumps() == b"\xc0\x00\x00\x00\x00\xff\xf8\x00" -def test_bitfield_with_enum_or_flag(compiled): +def test_bitfield_with_enum_or_flag(cs: cstruct, compiled: bool) -> None: cdef = """ flag Flag8 : uint8 { A = 1, @@ -231,7 +227,7 @@ def test_bitfield_with_enum_or_flag(compiled): Flag8 d:4; }; """ - cs = cstruct.cstruct(endian=">") + cs.endian = ">" cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -247,7 +243,7 @@ def test_bitfield_with_enum_or_flag(compiled): assert obj.dumps() == buf -def test_bitfield_char(compiled): +def test_bitfield_char(cs: cstruct, compiled: bool) -> None: cdef = """ struct test { uint16 a : 4; @@ -256,8 +252,6 @@ def test_bitfield_char(compiled): char d[4]; }; """ - - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) diff --git a/tests/test_bytesinteger.py b/tests/test_bytesinteger.py deleted file mode 100644 index 0f26fe2..0000000 --- a/tests/test_bytesinteger.py +++ /dev/null @@ -1,260 +0,0 @@ -import pytest -from dissect import cstruct - -from dissect.cstruct.types import BytesInteger - -from .utils import verify_compiled - - -def test_bytesinteger_unsigned(): - cs = cstruct.cstruct() - - assert cs.uint24(b"AAA") == 0x414141 - assert cs.uint24(b"\xff\xff\xff") == 0xFFFFFF - assert cs.uint24[4](b"AAABBBCCCDDD") == [0x414141, 0x424242, 0x434343, 0x444444] - assert cs.uint48(b"AAAAAA") == 0x414141414141 - assert cs.uint48(b"\xff\xff\xff\xff\xff\xff") == 0xFFFFFFFFFFFF - assert cs.uint48[4](b"AAAAAABBBBBBCCCCCCDDDDDD") == [0x414141414141, 0x424242424242, 0x434343434343, 0x444444444444] - - uint40 = BytesInteger(cs, "uint40", 5, signed=False) - assert uint40(b"AAAAA") == 0x4141414141 - assert uint40(b"\xff\xff\xff\xff\xff") == 0xFFFFFFFFFF - assert uint40[2](b"AAAAABBBBB") == [0x4141414141, 0x4242424242] - assert uint40[None](b"AAAAA\x00") == [0x4141414141] - - uint128 = BytesInteger(cs, "uint128", 16, signed=False) - assert uint128(b"A" * 16) == 0x41414141414141414141414141414141 - assert uint128(b"\xff" * 16) == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF - assert uint128[2](b"A" * 16 + b"B" * 16) == [0x41414141414141414141414141414141, 0x42424242424242424242424242424242] - assert uint128[None](b"AAAAAAAAAAAAAAAA\x00") == [0x41414141414141414141414141414141] - - -def test_bytesinteger_signed(): - cs = cstruct.cstruct() - - assert cs.int24(b"\xff\x00\x00") == 255 - assert cs.int24(b"\xff\xff\xff") == -1 - assert cs.int24[4](b"\xff\xff\xff\xfe\xff\xff\xfd\xff\xff\xfc\xff\xff") == [-1, -2, -3, -4] - - int40 = BytesInteger(cs, "int40", 5, signed=True) - assert int40(b"AAAAA") == 0x4141414141 - assert int40(b"\xff\xff\xff\xff\xff") == -1 - assert int40[2](b"\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xff") == [-1, -2] - - int128 = BytesInteger(cs, "int128", 16, signed=True) - assert int128(b"A" * 16) == 0x41414141414141414141414141414141 - assert int128(b"\xff" * 16) == -1 - assert int128[2](b"\xff" * 16 + b"\xfe" + b"\xff" * 15) == [-1, -2] - - -def test_bytesinteger_unsigned_be(): - cs = cstruct.cstruct() - cs.endian = ">" - - assert cs.uint24(b"\x00\x00\xff") == 255 - assert cs.uint24(b"\xff\xff\xff") == 0xFFFFFF - assert cs.uint24[3](b"\x00\x00\xff\x00\x00\xfe\x00\x00\xfd") == [255, 254, 253] - - uint40 = BytesInteger(cs, "uint40", 5, signed=False) - assert uint40(b"\x00\x00\x00\x00\xff") == 255 - assert uint40(b"\xff\xff\xff\xff\xff") == 0xFFFFFFFFFF - assert uint40[2](b"\x00\x00\x00\x00A\x00\x00\x00\x00B") == [0x41, 0x42] - - uint128 = BytesInteger(cs, "uint128", 16, signed=False) - assert uint128(b"\x00" * 15 + b"\xff") == 255 - assert uint128(b"\xff" * 16) == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF - assert uint128[2](b"\x00" * 15 + b"A" + b"\x00" * 15 + b"B") == [0x41, 0x42] - - -def test_bytesinteger_signed_be(): - cs = cstruct.cstruct() - cs.endian = ">" - - assert cs.int24(b"\x00\x00\xff") == 255 - assert cs.int24(b"\xff\xff\x01") == -255 - assert cs.int24[3](b"\xff\xff\x01\xff\xff\x02\xff\xff\x03") == [-255, -254, -253] - - int40 = BytesInteger(cs, "int40", 5, signed=True) - assert int40(b"\x00\x00\x00\x00\xff") == 255 - assert int40(b"\xff\xff\xff\xff\xff") == -1 - assert int40(b"\xff\xff\xff\xff\x01") == -255 - assert int40[2](b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe") == [-1, -2] - - int128 = BytesInteger(cs, "int128", 16, signed=True) - assert int128(b"\x00" * 15 + b"\xff") == 255 - assert int128(b"\xff" * 16) == -1 - assert int128(b"\xff" * 15 + b"\x01") == -255 - assert int128[2](b"\xff" * 16 + b"\xff" * 15 + b"\xfe") == [-1, -2] - - -def test_bytesinteger_struct_signed(compiled): - cdef = """ - struct test { - int24 a; - int24 b[2]; - int24 len; - int24 dync[len]; - int24 c; - int24 d[3]; - int128 e; - int128 f[2]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"AAABBBCCC\x02\x00\x00DDDEEE\xff\xff\xff\x01\xff\xff\x02\xff\xff\x03\xff\xff" - buf += b"A" * 16 - buf += b"\xff" * 16 + b"\x01" + b"\xff" * 15 - obj = cs.test(buf) - - assert obj.a == 0x414141 - assert obj.b == [0x424242, 0x434343] - assert obj.len == 0x02 - assert obj.dync == [0x444444, 0x454545] - assert obj.c == -1 - assert obj.d == [-255, -254, -253] - assert obj.e == 0x41414141414141414141414141414141 - assert obj.f == [-1, -255] - assert obj.dumps() == buf - - -def test_bytesinteger_struct_unsigned(compiled): - cdef = """ - struct test { - uint24 a; - uint24 b[2]; - uint24 len; - uint24 dync[len]; - uint24 c; - uint128 d; - uint128 e[2]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"AAABBBCCC\x02\x00\x00DDDEEE\xff\xff\xff" - buf += b"A" * 16 - buf += b"A" + b"\x00" * 15 + b"B" + b"\x00" * 15 - obj = cs.test(buf) - - assert obj.a == 0x414141 - assert obj.b == [0x424242, 0x434343] - assert obj.len == 0x02 - assert obj.dync == [0x444444, 0x454545] - assert obj.c == 0xFFFFFF - assert obj.d == 0x41414141414141414141414141414141 - assert obj.e == [0x41, 0x42] - assert obj.dumps() == buf - - -def test_bytesinteger_struct_signed_be(compiled): - cdef = """ - struct test { - int24 a; - int24 b[2]; - int24 len; - int24 dync[len]; - int24 c; - int24 d[3]; - int128 e; - int128 f[2]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - cs.endian = ">" - - assert verify_compiled(cs.test, compiled) - - buf = b"AAABBBCCC\x00\x00\x02DDDEEE\xff\xff\xff\xff\xff\x01\xff\xff\x02\xff\xff\x03" - buf += b"A" * 16 - buf += b"\x00" * 15 + b"A" + b"\x00" * 15 + b"B" - obj = cs.test(buf) - - assert obj.a == 0x414141 - assert obj.b == [0x424242, 0x434343] - assert obj.len == 0x02 - assert obj.dync == [0x444444, 0x454545] - assert obj.c == -1 - assert obj.d == [-255, -254, -253] - assert obj.e == 0x41414141414141414141414141414141 - assert obj.f == [0x41, 0x42] - assert obj.dumps() == buf - - -def test_bytesinteger_struct_unsigned_be(compiled): - cdef = """ - struct test { - uint24 a; - uint24 b[2]; - uint24 len; - uint24 dync[len]; - uint24 c; - uint128 d; - uint128 e[2]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - cs.endian = ">" - - assert verify_compiled(cs.test, compiled) - - buf = b"AAABBBCCC\x00\x00\x02DDDEEE\xff\xff\xff" - buf += b"\xff" * 16 - buf += b"\x00" * 14 + b"AA" + b"\x00" * 14 + b"BB" - obj = cs.test(buf) - - assert obj.a == 0x414141 - assert obj.b == [0x424242, 0x434343] - assert obj.len == 0x02 - assert obj.dync == [0x444444, 0x454545] - assert obj.c == 0xFFFFFF - assert obj.d == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF - assert obj.e == [0x4141, 0x4242] - assert obj.dumps() == buf - - -def test_bytesinteger_range(): - cs = cstruct.cstruct() - int8 = BytesInteger(cs, "int8", 1, signed=True) - uint8 = BytesInteger(cs, "uint8", 1, signed=False) - int16 = BytesInteger(cs, "int16", 2, signed=True) - int24 = BytesInteger(cs, "int24", 3, signed=True) - int128 = BytesInteger(cs, "int128", 16, signed=True) - int8.dumps(127) - int8.dumps(-128) - uint8.dumps(255) - uint8.dumps(0) - int16.dumps(-32768) - int16.dumps(32767) - int24.dumps(-8388608) - int24.dumps(8388607) - int128.dumps(-(2**127) + 1) - int128.dumps(2**127 - 1) - with pytest.raises(OverflowError): - int8.dumps(-129) - with pytest.raises(OverflowError): - int8.dumps(128) - with pytest.raises(OverflowError): - uint8.dumps(-1) - with pytest.raises(OverflowError): - uint8.dumps(256) - with pytest.raises(OverflowError): - int16.dumps(-32769) - with pytest.raises(OverflowError): - int16.dumps(32768) - with pytest.raises(OverflowError): - int24.dumps(-8388609) - with pytest.raises(OverflowError): - int24.dumps(8388608) - with pytest.raises(OverflowError): - int128.dumps(-(2**127) - 1) - with pytest.raises(OverflowError): - int128.dumps(2**127) diff --git a/tests/test_compiler.py b/tests/test_compiler.py new file mode 100644 index 0000000..8946673 --- /dev/null +++ b/tests/test_compiler.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +from operator import itemgetter +from textwrap import dedent +from typing import Iterator +from unittest.mock import Mock + +import pytest + +from dissect.cstruct import compiler +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.types.base import MetaType +from dissect.cstruct.types.enum import Enum +from dissect.cstruct.types.structure import Field + + +def f(field_type: MetaType, offset: int | None = 0, name: str = "") -> Field: + return Field(name, field_type, offset=offset) + + +def strip_fields(info: Iterator[tuple[Field, int, str]]) -> list[tuple[int, str]]: + return list(map(itemgetter(1, 2), info)) + + +def mkfmt(info: Iterator[tuple[Field, int, str]]) -> str: + return "".join(f"{count}{char}" for _, count, char in info) + + +@pytest.fixture +def TestEnum(cs: cstruct) -> type[Enum]: + return cs._make_enum("Test", cs.uint8, {"a": 1}) + + +def test_generate_struct_info(cs: cstruct, TestEnum: type[Enum]) -> None: + fields = [f(cs.uint8), f(cs.int16), f(cs.uint32), f(cs.int64)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B"), (1, "h"), (1, "I"), (1, "q")] + + fields = [f(cs.uint8[4]), f(cs.int16[4]), f(cs.uint32[4]), f(cs.int64[4])] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(4, "B"), (4, "h"), (4, "I"), (4, "q")] + + fields = [f(cs.char), f(cs.wchar), f(cs.uint24), f(cs.int128)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "x"), (2, "x"), (3, "x"), (16, "x")] + + fields = [f(cs.char[2]), f(cs.wchar[2]), f(cs.uint24[2]), f(cs.int128[2])] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(2, "x"), (4, "x"), (6, "x"), (32, "x")] + + fields = [f(cs.char), f(cs.char[2]), f(cs.char)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "x"), (2, "x"), (1, "x")] + + fields = [f(cs.uint8), f(cs.void), f(cs.int16)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B"), (1, "h")] + + fields = [f(cs.uint8), f(cs.uint16), f(cs.char[0])] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B"), (1, "H"), (0, "x")] + + cs.pointer = cs.uint64 + TestPointer = cs._make_pointer(TestEnum) + fields = [f(TestEnum), f(TestPointer)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B"), (1, "Q")] + + +def test_generate_struct_info_offsets(cs: cstruct) -> None: + fields = [f(cs.uint8, 0), f(cs.uint8, 4), f(cs.uint8[2], 5), f(cs.uint8, 8)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B"), (3, "x"), (1, "B"), (2, "B"), (1, "x"), (1, "B")] + + # Different starting offsets are handled in the field reading loop of the compilation + fields = [f(cs.uint8, 4)] + fmt = strip_fields(compiler._generate_struct_info(cs, fields)) + assert fmt == [(1, "B")] + + +@pytest.mark.parametrize( + "fields, fmt", + [ + ([(None, 1, "B"), (None, 3, "B")], "4B"), + ([(None, 1, "B"), (None, 3, "B"), (None, 2, "H")], "4B2H"), + ([(None, 1, "B"), (None, 0, "x")], "B"), + ([(None, 1, "B"), (None, 0, "x"), (None, 2, "H")], "B2H"), + ([(None, 1, "B"), (None, 0, "x"), (None, 2, "x"), (None, 1, "H")], "B2xH"), + ], +) +def test_optimize_struct_fmt(fields: list[tuple], fmt: str) -> None: + assert compiler._optimize_struct_fmt(fields) == fmt + + +def test_generate_packed_read(cs: cstruct) -> None: + fields = [ + f(cs.uint8, name="a"), + f(cs.int16, name="b"), + f(cs.uint32, name="c"), + f(cs.int64, name="d"), + ] + code = next(compiler._ReadSourceGenerator(cs, fields)._generate_packed(fields)) + + expected = """ + buf = stream.read(15) + if len(buf) != 15: raise EOFError() + data = _struct(cls.cs.endian, "BhIq").unpack(buf) + + r["a"] = type.__call__(_0, data[0]) + s["a"] = 1 + + r["b"] = type.__call__(_1, data[1]) + s["b"] = 2 + + r["c"] = type.__call__(_2, data[2]) + s["c"] = 4 + + r["d"] = type.__call__(_3, data[3]) + s["d"] = 8 + """ + + assert code == dedent(expected) + + +def test_generate_packed_read_array(cs: cstruct) -> None: + fields = [ + f(cs.uint8[2], name="a"), + f(cs.int16[3], name="b"), + f(cs.uint32[4], name="c"), + f(cs.int64[5], name="d"), + ] + code = next(compiler._ReadSourceGenerator(cs, fields)._generate_packed(fields)) + + expected = """ + buf = stream.read(64) + if len(buf) != 64: raise EOFError() + data = _struct(cls.cs.endian, "2B3h4I5q").unpack(buf) + + _t = _0 + _et = _t.type + r["a"] = type.__call__(_t, [type.__call__(_et, e) for e in data[0:2]]) + s["a"] = 2 + + _t = _1 + _et = _t.type + r["b"] = type.__call__(_t, [type.__call__(_et, e) for e in data[2:5]]) + s["b"] = 6 + + _t = _2 + _et = _t.type + r["c"] = type.__call__(_t, [type.__call__(_et, e) for e in data[5:9]]) + s["c"] = 16 + + _t = _3 + _et = _t.type + r["d"] = type.__call__(_t, [type.__call__(_et, e) for e in data[9:14]]) + s["d"] = 40 + """ + + assert code == dedent(expected) + + +def test_generate_packed_read_byte_types(cs: cstruct) -> None: + fields = [ + f(cs.char, name="a"), + f(cs.char[2], name="b"), + f(cs.wchar, name="c"), + f(cs.wchar[2], name="d"), + f(cs.int24, name="e"), + f(cs.int24[2], name="f"), + ] + code = next(compiler._ReadSourceGenerator(cs, fields)._generate_packed(fields)) + + expected = """ + buf = stream.read(18) + if len(buf) != 18: raise EOFError() + data = _struct(cls.cs.endian, "18x").unpack(buf) + + r["a"] = type.__call__(_0, buf[0:1]) + s["a"] = 1 + + r["b"] = type.__call__(_1, buf[1:3]) + s["b"] = 2 + + r["c"] = _2(buf[3:5]) + s["c"] = 2 + + r["d"] = _3(buf[5:9]) + s["d"] = 4 + + r["e"] = _4(buf[9:12]) + s["e"] = 3 + + _t = _5 + _et = _t.type + _b = buf[12:18] + r["f"] = type.__call__(_t, [_et(_b[i:i + 3]) for i in range(0, 6, 3)]) + s["f"] = 6 + """ + + assert code == dedent(expected) + + +def test_generate_packed_read_composite_types(cs: cstruct, TestEnum: type[Enum]) -> None: + cs.pointer = cs.uint64 + TestPointer = cs._make_pointer(TestEnum) + + fields = [ + f(TestEnum, name="a"), + f(TestPointer, name="b"), + f(cs.void), + f(TestEnum[2], name="c"), + ] + code = next(compiler._ReadSourceGenerator(cs, fields)._generate_packed(fields)) + + expected = """ + buf = stream.read(11) + if len(buf) != 11: raise EOFError() + data = _struct(cls.cs.endian, "BQ2B").unpack(buf) + + r["a"] = type.__call__(_0, data[0]) + s["a"] = 1 + + _pt = _1 + r["b"] = _pt.__new__(_pt, data[1], stream, r) + s["b"] = 8 + + _t = _2 + _et = _t.type + r["c"] = type.__call__(_t, [type.__call__(_et, e) for e in data[2:4]]) + s["c"] = 2 + """ + + assert code == dedent(expected) + + +def test_generate_packed_read_offsets(cs: cstruct) -> None: + fields = [ + f(cs.uint8, name="a"), + f(cs.uint8, 8, name="b"), + ] + code = next(compiler._ReadSourceGenerator(cs, fields)._generate_packed(fields)) + + expected = """ + buf = stream.read(9) + if len(buf) != 9: raise EOFError() + data = _struct(cls.cs.endian, "B7xB").unpack(buf) + + r["a"] = type.__call__(_0, data[0]) + s["a"] = 1 + + r["b"] = type.__call__(_1, data[1]) + s["b"] = 1 + """ + + assert code == dedent(expected) + + +def test_generate_structure_read(cs: cstruct) -> None: + mock_type = Mock() + mock_type.__anonymous__ = False + + field = Field("a", mock_type) + code = next(compiler._ReadSourceGenerator(cs, [field])._generate_structure(field)) + + expected = """ + _s = stream.tell() + r["a"] = _0._read(stream, context=r) + s["a"] = stream.tell() - _s + """ + + assert code == dedent(expected) + + +def test_generate_structure_read_anonymous(cs: cstruct) -> None: + mock_type = Mock() + mock_type.__anonymous__ = True + + field = Field("a", mock_type) + code = next(compiler._ReadSourceGenerator(cs, [field])._generate_structure(field)) + + expected = """ + _s = stream.tell() + r["a"] = _0._read(stream, context=r) + s["a"] = stream.tell() - _s + """ + + assert code == dedent(expected) + + +def test_generate_array_read(cs: cstruct) -> None: + field = Field("a", Mock()) + code = next(compiler._ReadSourceGenerator(cs, [field])._generate_array(field)) + + expected = """ + _s = stream.tell() + r["a"] = _0._read(stream, context=r) + s["a"] = stream.tell() - _s + """ + + assert code == dedent(expected) + + +def test_generate_bits_read(cs: cstruct, TestEnum: type[Enum]) -> None: + field = Field("a", cs.int8, 2) + code = next(compiler._ReadSourceGenerator(cs, [field])._generate_bits(field)) + + expected = """ + _t = _0 + r["a"] = type.__call__(_t, bit_reader.read(_t, 2)) + """ + + assert code == dedent(expected) + + field = Field("b", TestEnum, 2) + code = next(compiler._ReadSourceGenerator(cs, [field])._generate_bits(field)) + + expected = """ + _t = _0 + r["b"] = type.__call__(_t, bit_reader.read(_t.type, 2)) + """ + + assert code == dedent(expected) + + +# TODO: the rest of the compiler diff --git a/tests/test_ctypes.py b/tests/test_ctypes.py new file mode 100644 index 0000000..77bfef4 --- /dev/null +++ b/tests/test_ctypes.py @@ -0,0 +1,30 @@ +import ctypes as _ctypes +from typing import Any + +import pytest + +from dissect.cstruct import MetaType, cstruct, ctypes_type + +DUMMY_CSTRUCT = cstruct() + + +# TODO: test structure/union + + +@pytest.mark.parametrize( + "input, expected", + [ + (DUMMY_CSTRUCT.int8, _ctypes.c_int8), + (DUMMY_CSTRUCT.char, _ctypes.c_char), + (DUMMY_CSTRUCT.char[3], _ctypes.c_char * 3), + (DUMMY_CSTRUCT.int8[3], _ctypes.c_int8 * 3), + (DUMMY_CSTRUCT._make_pointer(DUMMY_CSTRUCT.int8), _ctypes.POINTER(_ctypes.c_int8)), + ], +) +def test_ctypes_type(input: MetaType, expected: Any) -> None: + assert expected == ctypes_type(input) + + +def test_ctypes_type_exception() -> None: + with pytest.raises(NotImplementedError): + ctypes_type(DUMMY_CSTRUCT.float16) diff --git a/tests/test_ctypes_type.py b/tests/test_ctypes_type.py deleted file mode 100644 index 31e2d95..0000000 --- a/tests/test_ctypes_type.py +++ /dev/null @@ -1,25 +0,0 @@ -import ctypes as _ctypes - -import pytest -from dissect import cstruct - -DUMMY_CSTRUCT = cstruct.cstruct() -PACKED_TYPE_INT8 = cstruct.PackedType(DUMMY_CSTRUCT, "int8", 1, "b") - - -@pytest.mark.parametrize( - "cstruct_type, ctypes_type", - [ - (PACKED_TYPE_INT8, _ctypes.c_int8), - (cstruct.CharType(DUMMY_CSTRUCT), _ctypes.c_char), - (cstruct.Array(DUMMY_CSTRUCT, PACKED_TYPE_INT8, 3), _ctypes.c_int8 * 3), - (cstruct.Pointer(DUMMY_CSTRUCT, PACKED_TYPE_INT8), _ctypes.POINTER(_ctypes.c_int8)), - ], -) -def test_ctypes_type(cstruct_type, ctypes_type): - assert ctypes_type == cstruct.ctypes_type(cstruct_type) - - -def test_ctypes_type_exception(): - with pytest.raises(NotImplementedError): - cstruct.ctypes_type("FAIL") diff --git a/tests/test_expression.py b/tests/test_expression.py index 1495ecf..11f67e4 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1,6 +1,6 @@ import pytest -from dissect import cstruct +from dissect.cstruct.cstruct import cstruct from dissect.cstruct.exceptions import ExpressionParserError, ExpressionTokenizerError from dissect.cstruct.expression import Expression @@ -106,7 +106,7 @@ def test_expression_failure(expression: str, exception: type, message: str) -> N parser.evaluate() -def test_sizeof(): +def test_sizeof(cs: cstruct) -> None: d = """ struct test { char a[sizeof(uint32)]; @@ -116,8 +116,7 @@ def test_sizeof(): char a[sizeof(test) * 2]; }; """ - c = cstruct.cstruct() - c.load(d) + cs.load(d) - assert len(c.test) == 4 - assert len(c.test2) == 8 + assert len(cs.test) == 4 + assert len(cs.test2) == 8 diff --git a/tests/test_flag.py b/tests/test_flag.py deleted file mode 100644 index 1c8c0c9..0000000 --- a/tests/test_flag.py +++ /dev/null @@ -1,149 +0,0 @@ -from dissect import cstruct - -from .utils import verify_compiled - - -def test_flag(): - cdef = """ - flag Test { - a, - b, - c, - d - }; - - flag Odd { - a = 2, - b, - c, - d = 32, e, f, - g - }; - """ - cs = cstruct.cstruct() - cs.load(cdef) - - assert cs.Test.a == 1 - assert cs.Test.b == 2 - assert cs.Test.c == 4 - assert cs.Test.d == 8 - - assert cs.Odd.a == 2 - assert cs.Odd.b == 4 - assert cs.Odd.c == 8 - assert cs.Odd.d == 32 - assert cs.Odd.e == 64 - assert cs.Odd.f == 128 - assert cs.Odd.g == 256 - - assert cs.Test.a == cs.Test.a - assert cs.Test.a != cs.Test.b - assert bool(cs.Test(0)) is False - assert bool(cs.Test(1)) is True - - assert cs.Test.a | cs.Test.b == 3 - assert str(cs.Test.c | cs.Test.d) == "Test.d|c" - assert repr(cs.Test.a | cs.Test.b) == "" - assert cs.Test(2) == cs.Test.b - assert cs.Test(3) == cs.Test.a | cs.Test.b - assert cs.Test.c & 12 == cs.Test.c - assert cs.Test.b & 12 == 0 - assert cs.Test.b ^ cs.Test.a == cs.Test.a | cs.Test.b - - assert ~cs.Test.a == -2 - assert str(~cs.Test.a) == "Test.d|c|b" - - -def test_flag_read(compiled): - cdef = """ - flag Test16 : uint16 { - A = 0x1, - B = 0x2 - }; - - flag Test24 : uint24 { - A = 0x1, - B = 0x2 - }; - - flag Test32 : uint32 { - A = 0x1, - B = 0x2 - }; - - struct test { - Test16 a16; - Test16 b16; - Test24 a24; - Test24 b24; - Test32 a32; - Test32 b32; - Test16 l[2]; - Test16 c16; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"\x01\x00\x02\x00\x01\x00\x00\x02\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x02\x00\x03\x00" - obj = cs.test(buf) - - assert obj.a16.enum == cs.Test16 and obj.a16.value == cs.Test16.A - assert obj.b16.enum == cs.Test16 and obj.b16.value == cs.Test16.B - assert obj.a24.enum == cs.Test24 and obj.a24.value == cs.Test24.A - assert obj.b24.enum == cs.Test24 and obj.b24.value == cs.Test24.B - assert obj.a32.enum == cs.Test32 and obj.a32.value == cs.Test32.A - assert obj.b32.enum == cs.Test32 and obj.b32.value == cs.Test32.B - - assert len(obj.l) == 2 - assert obj.l[0].enum == cs.Test16 and obj.l[0].value == cs.Test16.A - assert obj.l[1].enum == cs.Test16 and obj.l[1].value == cs.Test16.B - - assert obj.c16 == cs.Test16.A | cs.Test16.B - assert obj.c16 & cs.Test16.A - assert str(obj.c16) == "Test16.B|A" - - assert obj.dumps() == buf - - -def test_flag_anonymous(compiled): - cdef = """ - flag : uint16 { - A, - B, - C, - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert cs.A == 1 - assert cs.B == 2 - assert cs.C == 4 - - assert cs.A.name == "A" - assert cs.A.value == 1 - assert repr(cs.A) == "" - - assert repr(cs.A | cs.B) == "" - - -def test_flag_anonymous_struct(compiled): - cdef = """ - flag : uint32 { - nElements = 4 - }; - - struct test { - uint32 arr[nElements]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - test = cs.test - - t = test(b"\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0A\x00\x00\x00") - assert t.arr == [255, 0, 0, 10] diff --git a/tests/test_packedtype.py b/tests/test_packedtype.py deleted file mode 100644 index af6b0da..0000000 --- a/tests/test_packedtype.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest -from dissect import cstruct - -from dissect.cstruct.types import PackedType - -from .utils import verify_compiled - - -def test_packedtype_float(): - cs = cstruct.cstruct() - - assert cs.float16.dumps(420.69) == b"\x93^" - assert cs.float.dumps(31337.6969) == b"e\xd3\xf4F" - assert cs.float16.reads(b"\x69\x69") == 2770.0 - assert cs.float.reads(b"M0MS") == 881278648320.0 - - -def test_packedtype_float_struct(compiled): - cdef = """ - struct test { - float16 a; - float b; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"69\xb1U$G" - obj = cs.test(buf) - - assert obj.a == 0.6513671875 - assert obj.b == 42069.69140625 - - -def test_packedtype_float_struct_be(compiled): - cdef = """ - struct test { - float16 a; - float b; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - cs.endian = ">" - - assert verify_compiled(cs.test, compiled) - - buf = b"69G$U\xb1" - obj = cs.test(buf) - print(obj) - - assert obj.a == 0.388916015625 - assert obj.b == 42069.69140625 - - -def test_packedtype_range(): - cs = cstruct.cstruct() - float16 = PackedType(cs, "float16", 2, "e") - float16.dumps(-65519.999999999996) - float16.dumps(65519.999999999996) - with pytest.raises(OverflowError): - float16.dumps(-65519.999999999997) - with pytest.raises(OverflowError): - float16.dumps(65519.999999999997) diff --git a/tests/test_parser.py b/tests/test_parser.py index 232ba26..e81da46 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,9 +1,14 @@ from unittest.mock import Mock +import pytest + +from dissect.cstruct import cstruct +from dissect.cstruct.exceptions import ParserError from dissect.cstruct.parser import TokenParser +from dissect.cstruct.types import ArrayMetaType, Pointer -def test_preserve_comment_newlines(): +def test_preserve_comment_newlines() -> None: cdef = """ // normal comment #define normal_anchor @@ -23,3 +28,48 @@ def test_preserve_comment_newlines(): mock_token.match.start.return_value = data.index("#define multi_anchor") assert TokenParser._lineno(mock_token) == 9 + + +def test_typedef_types(cs: cstruct) -> None: + cdef = """ + typedef char uuid_t[16]; + typedef uint32 *ptr; + + struct test { + uuid_t uuid; + ptr ptr; + }; + """ + cs.pointer = cs.uint8 + cs.load(cdef) + + assert isinstance(cs.uuid_t, ArrayMetaType) + assert cs.uuid_t(b"\x01" * 16) == b"\x01" * 16 + + assert issubclass(cs.ptr, Pointer) + assert cs.ptr(b"\x01AAAA") == 1 + assert cs.ptr(b"\x01AAAA").dereference() == 0x41414141 + + obj = cs.test(b"\x01" * 16 + b"\x11AAAA") + assert obj.uuid == b"\x01" * 16 + assert obj.ptr.dereference() == 0x41414141 + + with pytest.raises(ParserError, match="line 1: typedefs cannot have bitfields"): + cs.load("""typedef uint8 with_bits : 4;""") + + +def test_dynamic_substruct_size(cs: cstruct) -> None: + cdef = """ + struct { + int32 len; + char str[len]; + } sub; + + struct { + sub data[1]; + } test; + """ + cs.load(cdef) + + assert cs.sub.dynamic + assert cs.test.dynamic diff --git a/tests/test_pointer.py b/tests/test_pointer.py deleted file mode 100644 index 7b80cbd..0000000 --- a/tests/test_pointer.py +++ /dev/null @@ -1,162 +0,0 @@ -from unittest.mock import patch - -import pytest -from dissect import cstruct - -from dissect.cstruct.types.pointer import PointerInstance - -from .utils import verify_compiled - - -@pytest.mark.parametrize("compiled", [True, False]) -def test_pointer_basic(compiled): - cdef = """ - struct ptrtest { - uint32 *ptr1; - uint32 *ptr2; - }; - """ - cs = cstruct.cstruct(pointer="uint16") - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.ptrtest, compiled) - assert cs.pointer is cs.uint16 - - buf = b"\x04\x00\x08\x00\x01\x02\x03\x04\x05\x06\x07\x08" - obj = cs.ptrtest(buf) - - assert obj.ptr1 != 0 - assert obj.ptr2 != 0 - assert obj.ptr1 != obj.ptr2 - assert obj.ptr1 == 4 - assert obj.ptr2 == 8 - assert obj.ptr1.dereference() == 0x04030201 - assert obj.ptr2.dereference() == 0x08070605 - - obj.ptr1 += 2 - obj.ptr2 -= 2 - assert obj.ptr1 == obj.ptr2 - assert obj.ptr1.dereference() == obj.ptr2.dereference() == 0x06050403 - - assert obj.dumps() == b"\x06\x00\x06\x00" - - with pytest.raises(cstruct.NullPointerDereference): - cs.ptrtest(b"\x00\x00\x00\x00").ptr1.dereference() - - -@pytest.mark.parametrize("compiled", [True, False]) -def test_pointer_struct(compiled): - cdef = """ - struct test { - char magic[4]; - wchar wmagic[4]; - uint8 a; - uint16 b; - uint32 c; - char string[]; - wchar wstring[]; - }; - - struct ptrtest { - test *ptr; - }; - """ - cs = cstruct.cstruct(pointer="uint16") - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - assert verify_compiled(cs.ptrtest, compiled) - assert cs.pointer is cs.uint16 - - buf = b"\x02\x00testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" - obj = cs.ptrtest(buf) - - assert obj.ptr != 0 - - assert obj.ptr.magic == b"test" - assert obj.ptr.wmagic == "test" - assert obj.ptr.a == 0x01 - assert obj.ptr.b == 0x0302 - assert obj.ptr.c == 0x07060504 - assert obj.ptr.string == b"lalala" - assert obj.ptr.wstring == "test" - - assert obj.dumps() == b"\x02\x00" - - with pytest.raises(cstruct.NullPointerDereference): - cs.ptrtest(b"\x00\x00\x00\x00").ptr.magic - - -@pytest.mark.parametrize("compiled", [True, False]) -def test_array_of_pointers(compiled): - cdef = """ - struct mainargs { - uint8_t argc; - char *args[4]; - } - """ - cs = cstruct.cstruct(pointer="uint16") - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.mainargs, compiled) - assert cs.pointer is cs.uint16 - - buf = b"\x02\x09\x00\x16\x00\x00\x00\x00\x00argument one\x00argument two\x00" - obj = cs.mainargs(buf) - - assert obj.argc == 2 - assert obj.args[2] == 0 - assert obj.args[3] == 0 - assert obj.args[0].dereference() == b"argument one" - assert obj.args[1].dereference() == b"argument two" - - -def test_pointer_arithmetic(): - inst = PointerInstance(None, None, 0, None) - assert inst._addr == 0 - - inst += 4 - assert inst._addr == 4 - - inst -= 2 - assert inst._addr == 2 - - inst *= 12 - assert inst._addr == 24 - - inst //= 2 - assert inst._addr == 12 - - inst %= 10 - assert inst._addr == 2 - - inst **= 4 - assert inst._addr == 16 - - inst <<= 1 - assert inst._addr == 32 - - inst >>= 2 - assert inst._addr == 8 - - inst &= 2 - assert inst._addr == 0 - - inst ^= 4 - assert inst._addr == 4 - - inst |= 8 - assert inst._addr == 12 - - -def test_pointer_sys_size(): - with patch("sys.maxsize", 2**64): - c = cstruct.cstruct() - assert c.pointer is c.uint64 - - with patch("sys.maxsize", 2**32): - c = cstruct.cstruct() - assert c.pointer is c.uint32 - - c = cstruct.cstruct(pointer="uint16") - assert c.pointer is c.uint16 diff --git a/tests/test_struct.py b/tests/test_struct.py deleted file mode 100644 index f08c0f8..0000000 --- a/tests/test_struct.py +++ /dev/null @@ -1,320 +0,0 @@ -from io import BytesIO - -import pytest -from dissect import cstruct - -from .utils import verify_compiled - - -def test_struct_simple(compiled): - cdef = """ - struct test { - char magic[4]; - wchar wmagic[4]; - uint8 a; - uint16 b; - uint32 c; - char string[]; - wchar wstring[]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" - obj = cs.test(buf) - - assert "magic" in obj - assert obj.magic == b"test" - assert obj["magic"] == obj.magic - assert obj.wmagic == "test" - assert obj.a == 0x01 - assert obj.b == 0x0302 - assert obj.c == 0x07060504 - assert obj.string == b"lalala" - assert obj.wstring == "test" - - with pytest.raises(AttributeError): - obj.nope - - assert obj._size("magic") == 4 - assert len(obj) == len(buf) - assert obj.dumps() == buf - - assert repr(obj) - - fh = BytesIO() - obj.write(fh) - assert fh.getvalue() == buf - - -def test_struct_simple_be(compiled): - cdef = """ - struct test { - char magic[4]; - wchar wmagic[4]; - uint8 a; - uint16 b; - uint32 c; - char string[]; - wchar wstring[]; - }; - """ - cs = cstruct.cstruct(endian=">") - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"test\x00t\x00e\x00s\x00t\x01\x02\x03\x04\x05\x06\x07lalala\x00\x00t\x00e\x00s\x00t\x00\x00" - obj = cs.test(buf) - - assert obj.magic == b"test" - assert obj.wmagic == "test" - assert obj.a == 0x01 - assert obj.b == 0x0203 - assert obj.c == 0x04050607 - assert obj.string == b"lalala" - assert obj.wstring == "test" - assert obj.dumps() == buf - - -def test_struct_definitions(compiled): - cdef = """ - struct _test { - uint32 a; - // uint32 comment - uint32 b; - } test, test1; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - assert cs._test == cs.test == cs.test1 - assert cs.test.name == "_test" - assert cs._test.name == "_test" - - assert "a" in cs.test.lookup - assert "b" in cs.test.lookup - - with pytest.raises(cstruct.ParserError): - cdef = """ - struct { - uint32 a; - }; - """ - cs.load(cdef) - - -def test_struct_expressions(compiled): - cdef = """ - #define const 1 - struct test { - uint8 flag; - uint8 data_1[(flag & 1) * 4]; - uint8 data_2[flag & (1 << 2)]; - uint8 data_3[const]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - obj = cs.test(b"\x01\x00\x01\x02\x03\xff") - assert obj.flag == 1 - assert obj.data_1 == [0, 1, 2, 3] - assert obj.data_2 == [] - assert obj.data_3 == [255] - - obj = cs.test(b"\x04\x04\x05\x06\x07\xff") - assert obj.flag == 4 - assert obj.data_1 == [] - assert obj.data_2 == [4, 5, 6, 7] - assert obj.data_3 == [255] - - -def test_struct_sizes(compiled): - cdef = """ - struct static { - uint32 test; - }; - - struct dynamic { - uint32 test[]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.static, compiled) - assert verify_compiled(cs.dynamic, compiled) - - assert len(cs.static) == 4 - - if not compiled: - cs.static.add_field("another", cs.uint32) - assert len(cs.static) == 8 - cs.static.add_field("atoffset", cs.uint32, offset=12) - assert len(cs.static) == 16 - - obj = cs.static(b"\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00") - assert obj.test == 1 - assert obj.another == 2 - assert obj.atoffset == 3 - - with pytest.raises(TypeError) as excinfo: - len(cs.dynamic) - assert str(excinfo.value) == "Dynamic size" - else: - with pytest.raises(NotImplementedError) as excinfo: - cs.static.add_field("another", cs.uint32) - assert str(excinfo.value) == "Can't add fields to a compiled structure" - - -def test_struct_nested(compiled): - cdef = """ - struct test_named { - char magic[4]; - struct { - uint32 a; - uint32 b; - } a; - struct { - char c[8]; - } b; - }; - - struct test_anonymous { - char magic[4]; - struct { - uint32 a; - uint32 b; - }; - struct { - char c[8]; - }; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test_named, compiled) - assert verify_compiled(cs.test_anonymous, compiled) - - assert len(cs.test_named) == len(cs.test_anonymous) == 20 - - data = b"zomg\x39\x05\x00\x00\x28\x23\x00\x00deadbeef" - obj = cs.test_named(data) - assert obj.magic == b"zomg" - assert obj.a.a == 1337 - assert obj.a.b == 9000 - assert obj.b.c == b"deadbeef" - assert obj.dumps() == data - - obj = cs.test_anonymous(data) - assert obj.magic == b"zomg" - assert obj.a == 1337 - assert obj.b == 9000 - assert obj.c == b"deadbeef" - assert obj.dumps() == data - - -def test_struct_write(compiled): - cdef = """ - struct test { - char magic[4]; - wchar wmagic[4]; - uint8 a; - uint16 b; - uint32 c; - char string[]; - wchar wstring[]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" - - obj = cs.test() - obj.magic = "test" - obj.wmagic = "test" - obj.a = 0x01 - obj.b = 0x0302 - obj.c = 0x07060504 - obj.string = b"lalala" - obj.wstring = "test" - - with pytest.raises(AttributeError): - obj.nope = 1 - - assert obj.dumps() == buf - - inst = cs.test( - magic=b"test", - wmagic="test", - a=0x01, - b=0x0302, - c=0x07060504, - string=b"lalala", - wstring="test", - ) - assert inst.dumps() == buf - - -def test_struct_write_be(compiled): - cdef = """ - struct test { - char magic[4]; - wchar wmagic[4]; - uint8 a; - uint16 b; - uint32 c; - char string[]; - wchar wstring[]; - }; - """ - cs = cstruct.cstruct(endian=">") - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"test\x00t\x00e\x00s\x00t\x01\x02\x03\x04\x05\x06\x07lalala\x00\x00t\x00e\x00s\x00t\x00\x00" - - obj = cs.test() - obj.magic = "test" - obj.wmagic = "test" - obj.a = 0x01 - obj.b = 0x0203 - obj.c = 0x04050607 - obj.string = b"lalala" - obj.wstring = "test" - - assert obj.dumps() == buf - - -def test_struct_write_anonymous(): - cdef = """ - struct test { - uint32 a; - union { - struct { - uint16 b1; - uint16 b2; - }; - uint32 b; - }; - uint32 c; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef) - - obj = cs.test(a=1, c=3) - assert obj.dumps() == b"\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00" diff --git a/tests/test_types_base.py b/tests/test_types_base.py new file mode 100644 index 0000000..18483fa --- /dev/null +++ b/tests/test_types_base.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import BinaryIO + +import pytest + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.exceptions import ArraySizeError +from dissect.cstruct.types.base import ArrayMetaType, BaseType + +from .utils import verify_compiled + + +def test_array_size_mismatch(cs: cstruct) -> None: + with pytest.raises(ArraySizeError): + cs.uint8[2]([1, 2, 3]).dumps() + + assert cs.uint8[2]([1, 2]).dumps() + + +def test_eof(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test_char { + char data[EOF]; + }; + + struct test_wchar { + wchar data[EOF]; + }; + + struct test_packed { + uint16 data[EOF]; + }; + + struct test_int { + uint24 data[EOF]; + }; + + enum Test : uint16 { + A = 1 + }; + + struct test_enum { + Test data[EOF]; + }; + + struct test_eof_field { + uint8 EOF; + char data[EOF]; + uint8 remainder; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test_char, compiled) + assert verify_compiled(cs.test_wchar, compiled) + assert verify_compiled(cs.test_packed, compiled) + assert verify_compiled(cs.test_int, compiled) + assert verify_compiled(cs.test_enum, compiled) + assert verify_compiled(cs.test_eof_field, compiled) + + test_char = cs.test_char(b"abc") + assert test_char.data == b"abc" + assert test_char.dumps() == b"abc" + + test_wchar = cs.test_wchar("abc".encode("utf-16-le")) + assert test_wchar.data == "abc" + assert test_wchar.dumps() == "abc".encode("utf-16-le") + + test_packed = cs.test_packed(b"\x01\x00\x02\x00") + assert test_packed.data == [1, 2] + assert test_packed.dumps() == b"\x01\x00\x02\x00" + + test_int = cs.test_int(b"\x01\x00\x00\x02\x00\x00") + assert test_int.data == [1, 2] + assert test_int.dumps() == b"\x01\x00\x00\x02\x00\x00" + + test_enum = cs.test_enum(b"\x01\x00") + assert test_enum.data == [cs.Test.A] + assert test_enum.dumps() == b"\x01\x00" + + test_eof_field = cs.test_eof_field(b"\x01a\x02") + assert test_eof_field.data == b"a" + assert test_eof_field.remainder == 2 + assert test_eof_field.dumps() == b"\x01a\x02" + + +def test_custom_array_type(cs: cstruct, compiled: bool) -> None: + class CustomType(BaseType): + def __init__(self, value): + self.value = value.upper() + + @classmethod + def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType: + length = stream.read(1)[0] + value = stream.read(length) + return type.__call__(cls, value) + + class ArrayType(BaseType, metaclass=ArrayMetaType): + @classmethod + def _read(cls, stream: BinaryIO, context: dict | None = None) -> CustomType.ArrayType: + value = cls.type._read(stream, context) + if str(cls.num_entries) == "lower": + value.value = value.value.lower() + + return value + + cs.add_custom_type("custom", CustomType) + + result = cs.custom(b"\x04asdf") + assert isinstance(result, CustomType) + assert result.value == b"ASDF" + + result = cs.custom["lower"](b"\x04asdf") + assert isinstance(result, CustomType) + assert result.value == b"asdf" + + cdef = """ + struct test { + custom a; + custom b[lower]; + }; + """ + cs.load(cdef, compiled=compiled) + + # We just don't want to compiler to blow up with a custom type + assert not cs.test.__compiled__ + + result = cs.test(b"\x04asdf\x04asdf") + assert isinstance(result.a, CustomType) + assert isinstance(result.b, CustomType) + assert result.a.value == b"ASDF" + assert result.b.value == b"asdf" diff --git a/tests/test_types_char.py b/tests/test_types_char.py new file mode 100644 index 0000000..bc93993 --- /dev/null +++ b/tests/test_types_char.py @@ -0,0 +1,45 @@ +import io + +import pytest + +from dissect.cstruct.cstruct import cstruct + + +def test_char_read(cs: cstruct) -> None: + assert cs.char(b"A") == b"A" + assert cs.char(b"AAAA\x00") == b"A" + assert cs.char(io.BytesIO(b"AAAA\x00")) == b"A" + + +def test_char_write(cs: cstruct) -> None: + assert cs.char(b"A").dumps() == b"A" + + +def test_char_array(cs: cstruct) -> None: + buf = b"AAAA\x00" + + assert cs.char[4](buf) == b"AAAA" + assert cs.char[4](io.BytesIO(buf)) == b"AAAA" + + assert cs.char[None](buf) == b"AAAA" + assert cs.char[None](io.BytesIO(buf)) == b"AAAA" + + +def test_char_array_write(cs: cstruct) -> None: + buf = b"AAAA\x00" + + assert cs.char[4](buf).dumps() == b"AAAA" + assert cs.char[None](buf).dumps() == b"AAAA\x00" + + +def test_char_eof(cs: cstruct) -> None: + with pytest.raises(EOFError): + cs.char(b"") + + with pytest.raises(EOFError): + cs.char[4](b"") + + with pytest.raises(EOFError): + cs.char[None](b"AAAA") + + assert cs.char[0](b"") == b"" diff --git a/tests/test_types_custom.py b/tests/test_types_custom.py new file mode 100644 index 0000000..e9fb98f --- /dev/null +++ b/tests/test_types_custom.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, BinaryIO + +import pytest + +from dissect.cstruct import cstruct +from dissect.cstruct.types import BaseType, MetaType + + +class EtwPointer(BaseType): + type: MetaType + size: int | None + + @classmethod + def _read(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> BaseType: + return cls.type._read(stream, context) + + @classmethod + def _read_0(cls, stream: BinaryIO, context: dict[str, Any] | None = None) -> list[BaseType]: + return cls.type._read_0(stream, context) + + @classmethod + def _write(cls, stream: BinaryIO, data: Any) -> int: + return cls.type._write(stream, data) + + @classmethod + def as_32bit(cls) -> None: + cls.type = cls.cs.uint32 + cls.size = 4 + + @classmethod + def as_64bit(cls) -> None: + cls.type = cls.cs.uint64 + cls.size = 8 + + +def test_adding_custom_type(cs: cstruct) -> None: + cs.add_custom_type("EtwPointer", EtwPointer) + + cs.EtwPointer.as_64bit() + assert cs.EtwPointer.type is cs.uint64 + assert len(cs.EtwPointer) == 8 + assert cs.EtwPointer(b"\xDE\xAD\xBE\xEF" * 2).dumps() == b"\xDE\xAD\xBE\xEF" * 2 + + cs.EtwPointer.as_32bit() + assert cs.EtwPointer.type is cs.uint32 + assert len(cs.EtwPointer) == 4 + assert cs.EtwPointer(b"\xDE\xAD\xBE\xEF" * 2).dumps() == b"\xDE\xAD\xBE\xEF" + + +def test_using_type_in_struct(cs: cstruct) -> None: + cs.add_custom_type("EtwPointer", EtwPointer) + + struct_definition = """ + struct test { + EtwPointer data; + uint64 data2; + }; + """ + + cs.load(struct_definition) + + cs.EtwPointer.as_64bit() + assert len(cs.test().data) == 8 + + with pytest.raises(EOFError): + # Input too small + cs.test(b"\xDE\xAD\xBE\xEF" * 3) + + cs.EtwPointer.as_32bit() + assert len(cs.test().data) == 4 + + assert cs.test(b"\xDE\xAD\xBE\xEF" * 3).data.dumps() == b"\xDE\xAD\xBE\xEF" diff --git a/tests/test_enum.py b/tests/test_types_enum.py similarity index 59% rename from tests/test_enum.py rename to tests/test_types_enum.py index 9f0ea4d..2981408 100644 --- a/tests/test_enum.py +++ b/tests/test_types_enum.py @@ -1,10 +1,92 @@ +from enum import Enum as StdEnum + import pytest -from dissect import cstruct + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.types.enum import Enum from .utils import verify_compiled -def test_enum(compiled): +@pytest.fixture +def TestEnum(cs: cstruct) -> type[Enum]: + return cs._make_enum("Test", cs.uint8, {"A": 1, "B": 2, "C": 3}) + + +def test_enum(cs: cstruct, TestEnum: type[Enum]) -> None: + assert issubclass(TestEnum, StdEnum) + assert TestEnum.cs is cs + assert TestEnum.type is cs.uint8 + assert TestEnum.size == 1 + assert TestEnum.alignment == 1 + + assert TestEnum.A == 1 + assert TestEnum.B == 2 + assert TestEnum.C == 3 + assert TestEnum(1) == TestEnum.A + assert TestEnum(2) == TestEnum.B + assert TestEnum(3) == TestEnum.C + assert TestEnum["A"] == TestEnum.A + assert TestEnum["B"] == TestEnum.B + assert TestEnum["C"] == TestEnum.C + + assert TestEnum(0) == 0 + assert TestEnum(0).name is None + assert TestEnum(0).value == 0 + + +def test_enum_read(TestEnum: type[Enum]) -> None: + assert TestEnum(b"\x02") == TestEnum.B + + +def test_enum_write(TestEnum: type[Enum]) -> None: + assert TestEnum.B.dumps() == b"\x02" + assert TestEnum(b"\x02").dumps() == b"\x02" + + +def test_enum_array_read(TestEnum: type[Enum]) -> None: + assert TestEnum[2](b"\x02\x03") == [TestEnum.B, TestEnum.C] + assert TestEnum[None](b"\x02\x03\x00") == [TestEnum.B, TestEnum.C] + + +def test_enum_array_write(TestEnum: type[Enum]) -> None: + assert TestEnum[2]([TestEnum.B, TestEnum.C]).dumps() == b"\x02\x03" + assert TestEnum[None]([TestEnum.B, TestEnum.C]).dumps() == b"\x02\x03\x00" + + +def test_enum_alias(cs: cstruct) -> None: + AliasEnum = cs._make_enum("Test", cs.uint8, {"A": 1, "B": 2, "C": 2}) + + assert AliasEnum.A == 1 + assert AliasEnum.B == 2 + assert AliasEnum.C == 2 + + assert AliasEnum.A.name == "A" + assert AliasEnum.B.name == "B" + assert AliasEnum.C.name == "C" + + assert AliasEnum.B == AliasEnum.C + + assert AliasEnum.B.dumps() == AliasEnum.C.dumps() + + +def test_enum_bad_type(cs: cstruct) -> None: + with pytest.raises(TypeError): + cs._make_enum("Test", cs.char, {"A": 1, "B": 2, "C": 3}) + + +def test_enum_eof(TestEnum: type[Enum]) -> None: + with pytest.raises(EOFError): + TestEnum(b"") + + with pytest.raises(EOFError): + TestEnum[2](b"\x01") + + with pytest.raises(EOFError): + TestEnum[None](b"\x01") + + +def test_enum_struct(cs: cstruct, compiled: bool) -> None: cdef = """ enum Test16 : uint16 { A = 0x1, @@ -40,7 +122,6 @@ def test_enum(compiled): Test16 expr[size * 2]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -50,19 +131,17 @@ def test_enum(compiled): buf = b"\x01\x00\x02\x00\x01\x00\x00\x02\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x02\x00" obj = cs.test(buf) - assert obj.a16.enum == cs.Test16 and obj.a16 == cs.Test16.A - assert obj.b16.enum == cs.Test16 and obj.b16 == cs.Test16.B - assert obj.a24.enum == cs.Test24 and obj.a24 == cs.Test24.A - assert obj.b24.enum == cs.Test24 and obj.b24 == cs.Test24.B - assert obj.a32.enum == cs.Test32 and obj.a32 == cs.Test32.A - assert obj.b32.enum == cs.Test32 and obj.b32 == cs.Test32.B + assert isinstance(obj.a16, cs.Test16) and obj.a16 == cs.Test16.A + assert isinstance(obj.b16, cs.Test16) and obj.b16 == cs.Test16.B + assert isinstance(obj.a24, cs.Test24) and obj.a24 == cs.Test24.A + assert isinstance(obj.b24, cs.Test24) and obj.b24 == cs.Test24.B + assert isinstance(obj.a32, cs.Test32) and obj.a32 == cs.Test32.A + assert isinstance(obj.b32, cs.Test32) and obj.b32 == cs.Test32.B assert len(obj.l) == 2 - assert obj.l[0].enum == cs.Test16 and obj.l[0] == cs.Test16.A - assert obj.l[1].enum == cs.Test16 and obj.l[1] == cs.Test16.B + assert isinstance(obj.l[0], cs.Test16) and obj.l[0] == cs.Test16.A + assert isinstance(obj.l[1], cs.Test16) and obj.l[1] == cs.Test16.B - assert "A" in cs.Test16 - assert "Foo" not in cs.Test16 assert cs.Test16(1) == cs.Test16["A"] assert cs.Test24(2) == cs.Test24.B assert cs.Test16.A != cs.Test24.A @@ -98,8 +177,13 @@ def test_enum(compiled): with pytest.raises(KeyError): obj[cs.Test32.A] + assert repr(cs.Test16.A) == "" + assert str(cs.Test16.A) == "Test16.A" + assert repr(cs.Test16(69)) == "" + assert str(cs.Test16(69)) == "Test16.69" + -def test_enum_comments(): +def test_enum_comments(cs: cstruct) -> None: cdef = """ enum Inline { hello=7, world, foo, bar }; // inline enum @@ -117,7 +201,6 @@ def test_enum_comments(): g // next }; """ - cs = cstruct.cstruct() cs.load(cdef) assert cs.Inline.hello == 7 @@ -142,7 +225,7 @@ def test_enum_comments(): assert cs.Test.a != cs.Test.b -def test_enum_name(compiled): +def test_enum_name(cs: cstruct, compiled: bool) -> None: cdef = """ enum Color: uint16 { RED = 1, @@ -157,7 +240,6 @@ def test_enum_name(compiled): uint32 hue; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.Pixel, compiled) @@ -173,13 +255,13 @@ def test_enum_name(compiled): assert pixel.color.value == 1 assert pixel.hue == 0xDDCCBBAA - # unknown enum values default to _ pixel = Pixel(b"\x00\x00\xFF\x00\xAA\xBB\xCC\xDD") - assert pixel.color.name == "Color_255" + assert pixel.color.name is None assert pixel.color.value == 0xFF + assert repr(pixel.color) == "" -def test_enum_write(compiled): +def test_enum_struct_write(cs: cstruct, compiled: bool) -> None: cdef = """ enum Test16 : uint16 { A = 0x1, @@ -206,7 +288,6 @@ def test_enum_write(compiled): Test16 list[2]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert verify_compiled(cs.test, compiled) @@ -223,7 +304,7 @@ def test_enum_write(compiled): assert obj.dumps() == b"\x01\x00\x02\x00\x01\x00\x00\x02\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x02\x00" -def test_enum_anonymous(compiled): +def test_enum_anonymous(cs: cstruct, compiled: bool) -> None: cdef = """ enum : uint16 { RED = 1, @@ -231,7 +312,6 @@ def test_enum_anonymous(compiled): BLUE = 3, }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert cs.RED == 1 @@ -241,9 +321,10 @@ def test_enum_anonymous(compiled): assert cs.RED.name == "RED" assert cs.RED.value == 1 assert repr(cs.RED) == "" + assert str(cs.RED) == "RED" -def test_enum_anonymous_struct(compiled): +def test_enum_anonymous_struct(cs: cstruct, compiled: bool) -> None: cdef = """ enum : uint32 { nElements = 4 @@ -253,7 +334,6 @@ def test_enum_anonymous_struct(compiled): uint32 arr[nElements]; }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) test = cs.test @@ -262,7 +342,7 @@ def test_enum_anonymous_struct(compiled): assert t.arr == [255, 0, 0, 10] -def test_enum_reference_own_member(compiled): +def test_enum_reference_own_member(cs: cstruct, compiled: bool) -> None: cdef = """ enum test { A, @@ -270,7 +350,6 @@ def test_enum_reference_own_member(compiled): C }; """ - cs = cstruct.cstruct() cs.load(cdef, compiled=compiled) assert cs.test.A == 0 diff --git a/tests/test_types_flag.py b/tests/test_types_flag.py new file mode 100644 index 0000000..a326225 --- /dev/null +++ b/tests/test_types_flag.py @@ -0,0 +1,235 @@ +from enum import Flag as StdFlag + +import pytest + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.types.enum import PY_311 +from dissect.cstruct.types.flag import Flag + +from .utils import verify_compiled + + +@pytest.fixture +def TestFlag(cs: cstruct) -> type[Flag]: + return cs._make_flag("Test", cs.uint8, {"A": 1, "B": 2}) + + +def test_flag(cs: cstruct, TestFlag: type[Flag]) -> None: + assert issubclass(TestFlag, StdFlag) + assert TestFlag.cs is cs + assert TestFlag.type is cs.uint8 + assert TestFlag.size == 1 + assert TestFlag.alignment == 1 + + assert TestFlag.A == 1 + assert TestFlag.B == 2 + assert TestFlag(1) == TestFlag.A + assert TestFlag(2) == TestFlag.B + + assert TestFlag(0) == 0 + assert TestFlag(0).name is None + assert TestFlag(0).value == 0 + + +def test_flag_read(TestFlag: type[Flag]) -> None: + assert TestFlag(b"\x02") == TestFlag.B + + +def test_flag_write(TestFlag: type[Flag]) -> None: + assert TestFlag.A.dumps() == b"\x01" + assert TestFlag(b"\x02").dumps() == b"\x02" + + +def test_flag_array_read(TestFlag: type[Flag]) -> None: + assert TestFlag[2](b"\x02\x01") == [TestFlag.B, TestFlag.A] + assert TestFlag[None](b"\x02\x01\x00") == [TestFlag.B, TestFlag.A] + + +def test_flag_array_write(TestFlag: type[Flag]) -> None: + assert TestFlag[2]([TestFlag.B, TestFlag.A]).dumps() == b"\x02\x01" + assert TestFlag[None]([TestFlag.B, TestFlag.A]).dumps() == b"\x02\x01\x00" + + +def test_flag_operator(TestFlag: type[Flag]) -> None: + assert TestFlag.A | TestFlag.B == 3 + assert TestFlag(3) == TestFlag.A | TestFlag.B + assert isinstance(TestFlag.A | TestFlag.B, TestFlag) + + assert TestFlag(b"\x03") == TestFlag.A | TestFlag.B + assert TestFlag[2](b"\x02\x03") == [TestFlag.B, (TestFlag.A | TestFlag.B)] + + assert (TestFlag.A | TestFlag.B).dumps() == b"\x03" + assert TestFlag[2]([TestFlag.B, (TestFlag.A | TestFlag.B)]).dumps() == b"\x02\x03" + + +def test_flag_struct(cs: cstruct) -> None: + cdef = """ + flag Test { + a, + b, + c, + d + }; + + flag Odd { + a = 2, + b, + c, + d = 32, e, f, + g + }; + """ + cs.load(cdef) + + assert cs.Test.a == 1 + assert cs.Test.b == 2 + assert cs.Test.c == 4 + assert cs.Test.d == 8 + + assert cs.Odd.a == 2 + assert cs.Odd.b == 4 + assert cs.Odd.c == 8 + assert cs.Odd.d == 32 + assert cs.Odd.e == 64 + assert cs.Odd.f == 128 + assert cs.Odd.g == 256 + + assert cs.Test.a == cs.Test.a + assert cs.Test.a != cs.Test.b + assert cs.Test.b != cs.Odd.a + assert bool(cs.Test(0)) is False + assert bool(cs.Test(1)) is True + + assert cs.Test.a | cs.Test.b == 3 + if PY_311: + assert repr(cs.Test.c | cs.Test.d) == "" + assert str(cs.Test.c | cs.Test.d) == "Test.c|d" + assert repr(cs.Test.a | cs.Test.b) == "" + assert str(cs.Test.a | cs.Test.b) == "Test.a|b" + assert repr(cs.Test(69)) == "" + assert str(cs.Test(69)) == "Test.a|c|64" + else: + assert repr(cs.Test.c | cs.Test.d) == "" + assert str(cs.Test.c | cs.Test.d) == "Test.d|c" + assert repr(cs.Test.a | cs.Test.b) == "" + assert str(cs.Test.a | cs.Test.b) == "Test.b|a" + assert repr(cs.Test(69)) == "" + assert str(cs.Test(69)) == "Test.64|c|a" + assert cs.Test(2) == cs.Test.b + assert cs.Test(3) == cs.Test.a | cs.Test.b + assert cs.Test.c & 12 == cs.Test.c + assert cs.Test.b & 12 == 0 + assert cs.Test.b ^ cs.Test.a == cs.Test.a | cs.Test.b + + # TODO: determine if we want to stay true to Python stdlib or a consistent behaviour + if PY_311: + assert ~cs.Test.a == 14 + assert repr(~cs.Test.a) == "" + else: + assert ~cs.Test.a == -2 + assert repr(~cs.Test.a) == "" + + +def test_flag_struct_read(cs: cstruct, compiled: bool) -> None: + cdef = """ + flag Test16 : uint16 { + A = 0x1, + B = 0x2 + }; + + flag Test24 : uint24 { + A = 0x1, + B = 0x2 + }; + + flag Test32 : uint32 { + A = 0x1, + B = 0x2 + }; + + struct test { + Test16 a16; + Test16 b16; + Test24 a24; + Test24 b24; + Test32 a32; + Test32 b32; + Test16 l[2]; + Test16 c16; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"\x01\x00\x02\x00\x01\x00\x00\x02\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x01\x00\x02\x00\x03\x00" + obj = cs.test(buf) + + assert isinstance(obj.a16, cs.Test16) and obj.a16.value == cs.Test16.A + assert isinstance(obj.b16, cs.Test16) and obj.b16.value == cs.Test16.B + assert isinstance(obj.a24, cs.Test24) and obj.a24.value == cs.Test24.A + assert isinstance(obj.b24, cs.Test24) and obj.b24.value == cs.Test24.B + assert isinstance(obj.a32, cs.Test32) and obj.a32.value == cs.Test32.A + assert isinstance(obj.b32, cs.Test32) and obj.b32.value == cs.Test32.B + + assert len(obj.l) == 2 + assert isinstance(obj.l[0], cs.Test16) and obj.l[0].value == cs.Test16.A + assert isinstance(obj.l[1], cs.Test16) and obj.l[1].value == cs.Test16.B + + assert obj.c16 == cs.Test16.A | cs.Test16.B + assert obj.c16 & cs.Test16.A + if PY_311: + assert repr(obj.c16) == "" + else: + assert repr(obj.c16) == "" + + assert obj.dumps() == buf + + +def test_flag_anonymous(cs: cstruct, compiled: bool) -> None: + cdef = """ + flag : uint16 { + A, + B, + C, + }; + """ + cs.load(cdef, compiled=compiled) + + assert cs.A == 1 + assert cs.B == 2 + assert cs.C == 4 + + assert cs.A.name == "A" + assert cs.A.value == 1 + assert repr(cs.A) == "" + assert str(cs.A) == "A" + + if PY_311: + assert repr(cs.A | cs.B) == "" + assert str(cs.A | cs.B) == "A|B" + assert repr(cs.A.__class__(69)) == "" + assert str(cs.A.__class__(69)) == "A|C|64" + else: + assert repr(cs.A | cs.B) == "" + assert str(cs.A | cs.B) == "B|A" + assert repr(cs.A.__class__(69)) == "<64|C|A: 69>" + assert str(cs.A.__class__(69)) == "64|C|A" + + +def test_flag_anonymous_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ + flag : uint32 { + nElements = 4 + }; + + struct test { + uint32 arr[nElements]; + }; + """ + cs.load(cdef, compiled=compiled) + + test = cs.test + + t = test(b"\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0A\x00\x00\x00") + assert t.arr == [255, 0, 0, 10] diff --git a/tests/test_types_int.py b/tests/test_types_int.py new file mode 100644 index 0000000..056d22e --- /dev/null +++ b/tests/test_types_int.py @@ -0,0 +1,402 @@ +import pytest + +from dissect.cstruct.cstruct import cstruct + +from .utils import verify_compiled + + +def test_int_unsigned_read(cs: cstruct) -> None: + assert cs.uint24(b"AAA") == 0x414141 + assert cs.uint24(b"\xff\xff\xff") == 0xFFFFFF + + assert cs.uint48(b"AAAAAA") == 0x414141414141 + assert cs.uint48(b"\xff\xff\xff\xff\xff\xff") == 0xFFFFFFFFFFFF + + assert cs.uint128(b"A" * 16) == 0x41414141414141414141414141414141 + assert cs.uint128(b"\xff" * 16) == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40(b"AAAAA") == 0x4141414141 + assert uint40(b"\xff\xff\xff\xff\xff") == 0xFFFFFFFFFF + + +def test_int_unsigned_write(cs: cstruct) -> None: + assert cs.uint24(0x414141).dumps() == b"AAA" + assert cs.uint24(0xFFFFFF).dumps() == b"\xff\xff\xff" + + assert cs.uint48(0x414141414141).dumps() == b"AAAAAA" + assert cs.uint48(0xFFFFFFFFFFFF).dumps() == b"\xff\xff\xff\xff\xff\xff" + + assert cs.uint128(0x41414141414141414141414141414141).dumps() == b"A" * 16 + assert cs.uint128(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF).dumps() == b"\xff" * 16 + + assert cs.uint128(b"A" * 16).dumps() == b"A" * 16 + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40(0x4141414141).dumps() == b"AAAAA" + assert uint40(0xFFFFFFFFFF).dumps() == b"\xff\xff\xff\xff\xff" + + +def test_int_unsigned_array_read(cs: cstruct) -> None: + assert cs.uint24[4](b"AAABBBCCCDDD") == [0x414141, 0x424242, 0x434343, 0x444444] + + assert cs.uint48[4](b"AAAAAABBBBBBCCCCCCDDDDDD") == [0x414141414141, 0x424242424242, 0x434343434343, 0x444444444444] + + assert cs.uint128[2](b"A" * 16 + b"B" * 16) == [ + 0x41414141414141414141414141414141, + 0x42424242424242424242424242424242, + ] + assert cs.uint128[None](b"AAAAAAAAAAAAAAAA" + (b"\x00" * 16)) == [0x41414141414141414141414141414141] + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40[2](b"AAAAABBBBB") == [0x4141414141, 0x4242424242] + assert uint40[None](b"AAAAA" + (b"\x00" * 5)) == [0x4141414141] + + +def test_int_unsigned_array_write(cs: cstruct) -> None: + assert cs.uint24[4]([0x414141, 0x424242, 0x434343, 0x444444]).dumps() == b"AAABBBCCCDDD" + + assert ( + cs.uint48[4]([0x414141414141, 0x424242424242, 0x434343434343, 0x444444444444]).dumps() + == b"AAAAAABBBBBBCCCCCCDDDDDD" + ) + + assert ( + cs.uint128[2]( + [ + 0x41414141414141414141414141414141, + 0x42424242424242424242424242424242, + ] + ).dumps() + == b"A" * 16 + b"B" * 16 + ) + assert cs.uint128[None]([0x41414141414141414141414141414141]).dumps() == b"AAAAAAAAAAAAAAAA" + (b"\x00" * 16) + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40[2]([0x4141414141, 0x4242424242]).dumps() == b"AAAAABBBBB" + assert uint40[None]([0x4141414141]).dumps() == b"AAAAA" + (b"\x00" * 5) + + +def test_int_signed_read(cs: cstruct) -> None: + assert cs.int24(b"\xff\x00\x00") == 255 + assert cs.int24(b"\xff\xff\xff") == -1 + + int40 = cs._make_int_type("int40", 5, True) + assert int40(b"AAAAA") == 0x4141414141 + assert int40(b"\xff\xff\xff\xff\xff") == -1 + + +def test_int_signed_write(cs: cstruct) -> None: + assert cs.int24(255).dumps() == b"\xff\x00\x00" + assert cs.int24(-1).dumps() == b"\xff\xff\xff" + + assert cs.int128(0x41414141414141414141414141414141).dumps() == b"A" * 16 + assert cs.int128(-1).dumps() == b"\xff" * 16 + + assert cs.int128(b"A" * 16).dumps() == b"A" * 16 + + int40 = cs._make_int_type("int40", 5, True) + assert int40(0x4141414141).dumps() == b"AAAAA" + assert int40(-1).dumps() == b"\xff\xff\xff\xff\xff" + + +def test_int_signed_array_read(cs: cstruct) -> None: + assert cs.int24[4](b"\xff\xff\xff\xfe\xff\xff\xfd\xff\xff\xfc\xff\xff") == [-1, -2, -3, -4] + + assert cs.int128[2](b"\xff" * 16 + b"\xfe" + b"\xff" * 15) == [-1, -2] + + int40 = cs._make_int_type("int40", 5, True) + assert int40[2](b"\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xff") == [-1, -2] + + +def test_int_signed_array_write(cs: cstruct) -> None: + assert cs.int24[4]([-1, -2, -3, -4]).dumps() == b"\xff\xff\xff\xfe\xff\xff\xfd\xff\xff\xfc\xff\xff" + assert cs.int24[None]([-1]).dumps() == b"\xff\xff\xff\x00\x00\x00" + + assert cs.int128[2]([-1, -2]).dumps() == b"\xff" * 16 + b"\xfe" + b"\xff" * 15 + + int40 = cs._make_int_type("int40", 5, True) + assert int40[2]([-1, -2]).dumps() == b"\xff\xff\xff\xff\xff\xfe\xff\xff\xff\xff" + + +def test_int_unsigned_be_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint24(b"\x00\x00\xff") == 255 + assert cs.uint24(b"\xff\xff\xff") == 0xFFFFFF + + assert cs.uint128(b"\x00" * 15 + b"\xff") == 255 + assert cs.uint128(b"\xff" * 16) == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40(b"\x00\x00\x00\x00\xff") == 255 + assert uint40(b"\xff\xff\xff\xff\xff") == 0xFFFFFFFFFF + + +def test_int_unsigned_be_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint24(255).dumps() == b"\x00\x00\xff" + assert cs.uint24(0xFFFFFF).dumps() == b"\xff\xff\xff" + + assert cs.uint128(255).dumps() == b"\x00" * 15 + b"\xff" + assert cs.uint128(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF).dumps() == b"\xff" * 16 + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40(255).dumps() == b"\x00\x00\x00\x00\xff" + assert uint40(0xFFFFFFFFFF).dumps() == b"\xff\xff\xff\xff\xff" + + +def test_int_unsigned_be_array_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint24[3](b"\x00\x00\xff\x00\x00\xfe\x00\x00\xfd") == [255, 254, 253] + + assert cs.uint24[None](b"\x00\x00\xff\x00\x00\x00") == [255] + + assert cs.uint128[2](b"\x00" * 15 + b"A" + b"\x00" * 15 + b"B") == [0x41, 0x42] + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40[2](b"\x00\x00\x00\x00A\x00\x00\x00\x00B") == [0x41, 0x42] + + +def test_int_unsigned_be_array_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint24[3]([255, 254, 253]).dumps() == b"\x00\x00\xff\x00\x00\xfe\x00\x00\xfd" + assert cs.uint24[None]([255]).dumps() == b"\x00\x00\xff\x00\x00\x00" + + assert cs.uint128[2]([0x41, 0x42]).dumps() == b"\x00" * 15 + b"A" + b"\x00" * 15 + b"B" + + uint40 = cs._make_int_type("uint40", 5, False) + assert uint40[2]([0x41, 0x42]).dumps() == b"\x00\x00\x00\x00A\x00\x00\x00\x00B" + + +def test_int_signed_be_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.int24(b"\x00\x00\xff") == 255 + assert cs.int24(b"\xff\xff\x01") == -255 + + int40 = cs._make_int_type("int40", 5, True) + assert int40(b"\x00\x00\x00\x00\xff") == 255 + assert int40(b"\xff\xff\xff\xff\xff") == -1 + assert int40(b"\xff\xff\xff\xff\x01") == -255 + + +def test_int_signed_be_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.int24(255).dumps() == b"\x00\x00\xff" + assert cs.int24(-255).dumps() == b"\xff\xff\x01" + + assert cs.int128(255).dumps() == b"\x00" * 15 + b"\xff" + assert cs.int128(-1).dumps() == b"\xff" * 16 + assert cs.int128(-255).dumps() == b"\xff" * 15 + b"\x01" + + int40 = cs._make_int_type("int40", 5, True) + assert int40(255).dumps() == b"\x00\x00\x00\x00\xff" + assert int40(-1).dumps() == b"\xff\xff\xff\xff\xff" + assert int40(-255).dumps() == b"\xff\xff\xff\xff\x01" + + +def test_int_signed_be_array_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.int24[3](b"\xff\xff\x01\xff\xff\x02\xff\xff\x03") == [-255, -254, -253] + + assert cs.int128[2](b"\xff" * 16 + b"\xff" * 15 + b"\xfe") == [-1, -2] + + int40 = cs._make_int_type("int40", 5, True) + assert int40[2](b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe") == [-1, -2] + + +def test_int_signed_be_array_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.int24[3]([-255, -254, -253]).dumps() == b"\xff\xff\x01\xff\xff\x02\xff\xff\x03" + + assert cs.int128[2]([-1, -2]).dumps() == b"\xff" * 16 + b"\xff" * 15 + b"\xfe" + + int40 = cs._make_int_type("int40", 5, True) + assert int40[2]([-1, -2]).dumps() == b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xfe" + + +def test_int_eof(cs: cstruct) -> None: + with pytest.raises(EOFError): + cs.int24(b"\x00") + + with pytest.raises(EOFError): + cs.int24[2](b"\x00\x00\x00") + + with pytest.raises(EOFError): + cs.int24[None](b"\x01\x00\x00") + + +def test_bytesinteger_range(cs: cstruct) -> None: + int8 = cs._make_int_type("int8", 1, True) + uint8 = cs._make_int_type("uint9", 1, False) + int16 = cs._make_int_type("int16", 2, True) + int24 = cs._make_int_type("int24", 3, True) + int128 = cs._make_int_type("int128", 16, True) + + int8(127).dumps() + int8(-128).dumps() + uint8(255).dumps() + uint8(0).dumps() + int16(-32768).dumps() + int16(32767).dumps() + int24(-8388608).dumps() + int24(8388607).dumps() + int128(-(2**127) + 1).dumps() + int128(2**127 - 1).dumps() + with pytest.raises(OverflowError): + int8(-129).dumps() + with pytest.raises(OverflowError): + int8(128).dumps() + with pytest.raises(OverflowError): + uint8(-1).dumps() + with pytest.raises(OverflowError): + uint8(256).dumps() + with pytest.raises(OverflowError): + int16(-32769).dumps() + with pytest.raises(OverflowError): + int16(32768).dumps() + with pytest.raises(OverflowError): + int24(-8388609).dumps() + with pytest.raises(OverflowError): + int24(8388608).dumps() + with pytest.raises(OverflowError): + int128(-(2**127) - 1).dumps() + with pytest.raises(OverflowError): + int128(2**127).dumps() + + +def test_int_struct_signed(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + int24 a; + int24 b[2]; + int24 len; + int24 dync[len]; + int24 c; + int24 d[3]; + int128 e; + int128 f[2]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"AAABBBCCC\x02\x00\x00DDDEEE\xff\xff\xff\x01\xff\xff\x02\xff\xff\x03\xff\xff" + buf += b"A" * 16 + buf += b"\xff" * 16 + b"\x01" + b"\xff" * 15 + obj = cs.test(buf) + + assert obj.a == 0x414141 + assert obj.b == [0x424242, 0x434343] + assert obj.len == 0x02 + assert obj.dync == [0x444444, 0x454545] + assert obj.c == -1 + assert obj.d == [-255, -254, -253] + assert obj.e == 0x41414141414141414141414141414141 + assert obj.f == [-1, -255] + assert obj.dumps() == buf + + +def test_int_struct_unsigned(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + uint24 a; + uint24 b[2]; + uint24 len; + uint24 dync[len]; + uint24 c; + uint128 d; + uint128 e[2]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"AAABBBCCC\x02\x00\x00DDDEEE\xff\xff\xff" + buf += b"A" * 16 + buf += b"A" + b"\x00" * 15 + b"B" + b"\x00" * 15 + obj = cs.test(buf) + + assert obj.a == 0x414141 + assert obj.b == [0x424242, 0x434343] + assert obj.len == 0x02 + assert obj.dync == [0x444444, 0x454545] + assert obj.c == 0xFFFFFF + assert obj.d == 0x41414141414141414141414141414141 + assert obj.e == [0x41, 0x42] + assert obj.dumps() == buf + + +def test_bytesinteger_struct_signed_be(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + int24 a; + int24 b[2]; + int24 len; + int24 dync[len]; + int24 c; + int24 d[3]; + int128 e; + int128 f[2]; + }; + """ + cs.load(cdef, compiled=compiled) + cs.endian = ">" + + assert verify_compiled(cs.test, compiled) + + buf = b"AAABBBCCC\x00\x00\x02DDDEEE\xff\xff\xff\xff\xff\x01\xff\xff\x02\xff\xff\x03" + buf += b"A" * 16 + buf += b"\x00" * 15 + b"A" + b"\x00" * 15 + b"B" + obj = cs.test(buf) + + assert obj.a == 0x414141 + assert obj.b == [0x424242, 0x434343] + assert obj.len == 0x02 + assert obj.dync == [0x444444, 0x454545] + assert obj.c == -1 + assert obj.d == [-255, -254, -253] + assert obj.e == 0x41414141414141414141414141414141 + assert obj.f == [0x41, 0x42] + assert obj.dumps() == buf + + +def test_bytesinteger_struct_unsigned_be(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + uint24 a; + uint24 b[2]; + uint24 len; + uint24 dync[len]; + uint24 c; + uint128 d; + uint128 e[2]; + }; + """ + cs.load(cdef, compiled=compiled) + cs.endian = ">" + + assert verify_compiled(cs.test, compiled) + + buf = b"AAABBBCCC\x00\x00\x02DDDEEE\xff\xff\xff" + buf += b"\xff" * 16 + buf += b"\x00" * 14 + b"AA" + b"\x00" * 14 + b"BB" + obj = cs.test(buf) + + assert obj.a == 0x414141 + assert obj.b == [0x424242, 0x434343] + assert obj.len == 0x02 + assert obj.dync == [0x444444, 0x454545] + assert obj.c == 0xFFFFFF + assert obj.d == 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF + assert obj.e == [0x4141, 0x4242] + assert obj.dumps() == buf diff --git a/tests/test_types_leb128.py b/tests/test_types_leb128.py index 689078f..d66972a 100644 --- a/tests/test_types_leb128.py +++ b/tests/test_types_leb128.py @@ -1,19 +1,16 @@ import io import pytest -from dissect import cstruct +from dissect.cstruct.cstruct import cstruct -def test_leb128_unsigned_read_EOF(): - cs = cstruct.cstruct() +def test_leb128_unsigned_read_EOF(cs: cstruct) -> None: with pytest.raises(EOFError, match="EOF reached, while final LEB128 byte was not yet read"): cs.uleb128(b"\x8b") -def test_leb128_unsigned_read(): - cs = cstruct.cstruct() - +def test_leb128_unsigned_read(cs: cstruct) -> None: assert cs.uleb128(b"\x02") == 2 assert cs.uleb128(b"\x8b\x25") == 4747 assert cs.uleb128(b"\xc9\x8f\xb0\x06") == 13371337 @@ -22,9 +19,7 @@ def test_leb128_unsigned_read(): assert cs.uleb128(b"\xde\xd6\xcf\x7c") == 261352286 -def test_leb128_signed_read(): - cs = cstruct.cstruct() - +def test_leb128_signed_read(cs: cstruct) -> None: assert cs.ileb128(b"\x02") == 2 assert cs.ileb128(b"\x8b\x25") == 4747 assert cs.ileb128(b"\xc9\x8f\xb0\x06") == 13371337 @@ -33,16 +28,14 @@ def test_leb128_signed_read(): assert cs.ileb128(b"\xde\xd6\xcf\x7c") == -7083170 -@pytest.mark.parametrize("compiled", [True, False]) -def test_leb128_struct_unsigned(compiled): +def test_leb128_struct_unsigned(cs: cstruct) -> None: cdef = """ struct test { uleb128 len; char data[len]; }; """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) + cs.load(cdef) buf = b"\xaf\x18" buf += b"\x41" * 3119 @@ -56,15 +49,13 @@ def test_leb128_struct_unsigned(compiled): assert obj.dumps() == buf -@pytest.mark.parametrize("compiled", [True, False]) -def test_leb128_struct_unsigned_zero(compiled): +def test_leb128_struct_unsigned_zero(cs: cstruct) -> None: cdef = """ struct test { uleb128 numbers[]; }; """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) + cs.load(cdef) buf = b"\xaf\x18\x8b\x25\xc9\x8f\xb0\x06\x00" obj = cs.test(buf) @@ -77,15 +68,13 @@ def test_leb128_struct_unsigned_zero(compiled): assert obj.dumps() == buf -@pytest.mark.parametrize("compiled", [True, False]) -def test_leb128_struct_signed_zero(compiled): +def test_leb128_struct_signed_zero(cs: cstruct) -> None: cdef = """ struct test { ileb128 numbers[]; }; """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) + cs.load(cdef) buf = b"\xaf\x18\xf5\x5a\xde\xd6\xcf\x7c\x00" obj = cs.test(buf) @@ -98,8 +87,7 @@ def test_leb128_struct_signed_zero(compiled): assert obj.dumps() == buf -@pytest.mark.parametrize("compiled", [True, False]) -def test_leb128_nested_struct_unsigned(compiled): +def test_leb128_nested_struct_unsigned(cs: cstruct) -> None: cdef = """ struct entry { uleb128 len; @@ -113,8 +101,7 @@ def test_leb128_nested_struct_unsigned(compiled): entry entries[n_entries]; }; """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) + cs.load(cdef) # Dummy file format specifying 300 entries buf = b"\x08\x54\x65\x73\x74\x66\x69\x6c\x65\xac\x02" @@ -131,8 +118,7 @@ def test_leb128_nested_struct_unsigned(compiled): assert obj.dumps() == buf -@pytest.mark.parametrize("compiled", [True, False]) -def test_leb128_nested_struct_signed(compiled): +def test_leb128_nested_struct_signed(cs: cstruct) -> None: cdef = """ struct entry { ileb128 len; @@ -146,8 +132,7 @@ def test_leb128_nested_struct_signed(compiled): entry entries[n_entries]; }; """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) + cs.load(cdef) # Dummy file format specifying 300 entries buf = b"\x08\x54\x65\x73\x74\x66\x69\x6c\x65\xac\x02" @@ -164,47 +149,43 @@ def test_leb128_nested_struct_signed(compiled): assert obj.dumps() == buf -def test_leb128_unsigned_write(): - cs = cstruct.cstruct() +def test_leb128_unsigned_write(cs: cstruct) -> None: + assert cs.uleb128(2).dumps() == b"\x02" + assert cs.uleb128(4747).dumps() == b"\x8b\x25" + assert cs.uleb128(13371337).dumps() == b"\xc9\x8f\xb0\x06" + assert cs.uleb128(126).dumps() == b"\x7e" + assert cs.uleb128(11637).dumps() == b"\xf5\x5a" + assert cs.uleb128(261352286).dumps() == b"\xde\xd6\xcf\x7c" - assert cs.uleb128.dumps(2) == b"\x02" - assert cs.uleb128.dumps(4747) == b"\x8b\x25" - assert cs.uleb128.dumps(13371337) == b"\xc9\x8f\xb0\x06" - assert cs.uleb128.dumps(126) == b"\x7e" - assert cs.uleb128.dumps(11637) == b"\xf5\x5a" - assert cs.uleb128.dumps(261352286) == b"\xde\xd6\xcf\x7c" + assert cs.uleb128(b"\xde\xd6\xcf\x7c").dumps() == b"\xde\xd6\xcf\x7c" -def test_leb128_signed_write(): - cs = cstruct.cstruct() +def test_leb128_signed_write(cs: cstruct) -> None: + assert cs.ileb128(2).dumps() == b"\x02" + assert cs.ileb128(4747).dumps() == b"\x8b\x25" + assert cs.ileb128(13371337).dumps() == b"\xc9\x8f\xb0\x06" + assert cs.ileb128(-2).dumps() == b"\x7e" + assert cs.ileb128(-4747).dumps() == b"\xf5\x5a" + assert cs.ileb128(-7083170).dumps() == b"\xde\xd6\xcf\x7c" - assert cs.ileb128.dumps(2) == b"\x02" - assert cs.ileb128.dumps(4747) == b"\x8b\x25" - assert cs.ileb128.dumps(13371337) == b"\xc9\x8f\xb0\x06" - assert cs.ileb128.dumps(-2) == b"\x7e" - assert cs.ileb128.dumps(-4747) == b"\xf5\x5a" - assert cs.ileb128.dumps(-7083170) == b"\xde\xd6\xcf\x7c" + assert cs.ileb128(b"\xde\xd6\xcf\x7c").dumps() == b"\xde\xd6\xcf\x7c" -def test_leb128_write_negatives(): - cs = cstruct.cstruct() - +def test_leb128_write_negatives(cs: cstruct) -> None: with pytest.raises(ValueError, match="Attempt to encode a negative integer using unsigned LEB128 encoding"): - cs.uleb128.dumps(-2) - assert cs.ileb128.dumps(-2) == b"\x7e" - + cs.uleb128(-2).dumps() + assert cs.ileb128(-2).dumps() == b"\x7e" -def test_leb128_unsigned_write_amount_written(): - cs = cstruct.cstruct() +def test_leb128_unsigned_write_amount_written(cs: cstruct) -> None: out1 = io.BytesIO() - bytes_written1 = cs.uleb128.write(out1, 2) + bytes_written1 = cs.uleb128(2).write(out1) assert bytes_written1 == out1.tell() out2 = io.BytesIO() - bytes_written2 = cs.uleb128.write(out2, 4747) + bytes_written2 = cs.uleb128(4747).write(out2) assert bytes_written2 == out2.tell() out3 = io.BytesIO() - bytes_written3 = cs.uleb128.write(out3, 13371337) + bytes_written3 = cs.uleb128(13371337).write(out3) assert bytes_written3 == out3.tell() diff --git a/tests/test_types_packed.py b/tests/test_types_packed.py new file mode 100644 index 0000000..9c7d667 --- /dev/null +++ b/tests/test_types_packed.py @@ -0,0 +1,171 @@ +import pytest + +from dissect.cstruct.cstruct import cstruct + +from .utils import verify_compiled + + +def test_packed_read(cs: cstruct) -> None: + assert cs.uint32(b"AAAA") == 0x41414141 + assert cs.uint32(b"\xFF\xFF\xFF\xFF") == 0xFFFFFFFF + + assert cs.int32(b"\xFF\x00\x00\x00") == 255 + assert cs.int32(b"\xFF\xFF\xFF\xFF") == -1 + + assert cs.float16(b"\x00\x3C") == 1.0 + + assert cs.float(b"\x00\x00\x80\x3f") == 1.0 + + assert cs.double(b"\x00\x00\x00\x00\x00\x00\xf0\x3f") == 1.0 + + +def test_packed_write(cs: cstruct) -> None: + assert cs.uint32(0x41414141).dumps() == b"AAAA" + assert cs.uint32(0xFFFFFFFF).dumps() == b"\xFF\xFF\xFF\xFF" + assert cs.uint32(b"AAAA").dumps() == b"AAAA" + + assert cs.int32(255).dumps() == b"\xFF\x00\x00\x00" + assert cs.int32(-1).dumps() == b"\xFF\xFF\xFF\xFF" + + assert cs.float16(1.0).dumps() == b"\x00\x3C" + + assert cs.float(1.0).dumps() == b"\x00\x00\x80\x3f" + + assert cs.double(1.0).dumps() == b"\x00\x00\x00\x00\x00\x00\xf0\x3f" + + +def test_packed_array_read(cs: cstruct) -> None: + assert cs.uint32[2](b"AAAABBBB") == [0x41414141, 0x42424242] + assert cs.uint32[None](b"AAAABBBB\x00\x00\x00\x00") == [0x41414141, 0x42424242] + + assert cs.int32[2](b"\x00\x00\x00\x00\xFF\xFF\xFF\xFF") == [0, -1] + assert cs.int32[None](b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00") == [-1] + + assert cs.float[2](b"\x00\x00\x80\x3f\x00\x00\x00\x40") == [1.0, 2.0] + assert cs.float[None](b"\x00\x00\x80\x3f\x00\x00\x00\x00") == [1.0] + + +def test_packed_array_write(cs: cstruct) -> None: + assert cs.uint32[2]([0x41414141, 0x42424242]).dumps() == b"AAAABBBB" + assert cs.uint32[None]([0x41414141, 0x42424242]).dumps() == b"AAAABBBB\x00\x00\x00\x00" + + assert cs.int32[2]([0, -1]).dumps() == b"\x00\x00\x00\x00\xFF\xFF\xFF\xFF" + assert cs.int32[None]([-1]).dumps() == b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00" + + assert cs.float[2]([1.0, 2.0]).dumps() == b"\x00\x00\x80\x3f\x00\x00\x00\x40" + assert cs.float[None]([1.0]).dumps() == b"\x00\x00\x80\x3f\x00\x00\x00\x00" + + +def test_packed_be_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint32(b"AAA\x00") == 0x41414100 + assert cs.uint32(b"\xFF\xFF\xFF\x00") == 0xFFFFFF00 + + assert cs.int32(b"\x00\x00\x00\xFF") == 255 + assert cs.int32(b"\xFF\xFF\xFF\xFF") == -1 + + assert cs.float16(b"\x3C\x00") == 1.0 + + assert cs.float(b"\x3f\x80\x00\x00") == 1.0 + + assert cs.double(b"\x3f\xf0\x00\x00\x00\x00\x00\x00") == 1.0 + + +def test_packed_be_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint32(0x41414100).dumps() == b"AAA\x00" + assert cs.uint32(0xFFFFFF00).dumps() == b"\xFF\xFF\xFF\x00" + + assert cs.int32(255).dumps() == b"\x00\x00\x00\xFF" + assert cs.int32(-1).dumps() == b"\xFF\xFF\xFF\xFF" + + assert cs.float16(1.0).dumps() == b"\x3C\x00" + + assert cs.float(1.0).dumps() == b"\x3f\x80\x00\x00" + + assert cs.double(1.0).dumps() == b"\x3f\xf0\x00\x00\x00\x00\x00\x00" + + +def test_packed_be_array_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint32[2](b"\x00\x00\x00\x01\x00\x00\x00\x02") == [1, 2] + assert cs.uint32[None](b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00") == [1, 2] + + assert cs.int32[2](b"\x00\x00\x00\x01\xFF\xFF\xFF\xFE") == [1, -2] + assert cs.int32[None](b"\xFF\xFF\xFF\xFE\x00\x00\x00\x00") == [-2] + + assert cs.float[2](b"\x3f\x80\x00\x00\x40\x00\x00\x00") == [1.0, 2.0] + assert cs.float[None](b"\x3f\x80\x00\x00\x00\x00\x00\x00") == [1.0] + + +def test_packed_be_array_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.uint32[2]([1, 2]).dumps() == b"\x00\x00\x00\x01\x00\x00\x00\x02" + assert cs.uint32[None]([1, 2]).dumps() == b"\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00" + + assert cs.int32[2]([1, -2]).dumps() == b"\x00\x00\x00\x01\xFF\xFF\xFF\xFE" + assert cs.int32[None]([-2]).dumps() == b"\xFF\xFF\xFF\xFE\x00\x00\x00\x00" + + assert cs.float[2]([1.0, 2.0]).dumps() == b"\x3f\x80\x00\x00\x40\x00\x00\x00" + assert cs.float[None]([1.0]).dumps() == b"\x3f\x80\x00\x00\x00\x00\x00\x00" + + +def test_packed_eof(cs: cstruct) -> None: + with pytest.raises(EOFError): + cs.uint32(b"\x00") + + with pytest.raises(EOFError): + cs.uint32[2](b"\x00\x00\x00\x00") + + with pytest.raises(EOFError): + cs.uint32[None](b"\x00\x00\x00\x01") + + +def test_packed_range(cs) -> None: + cs.float16(-65519.999999999996).dumps() + cs.float16(65519.999999999996).dumps() + with pytest.raises(OverflowError): + cs.float16(-65519.999999999997).dumps() + with pytest.raises(OverflowError): + cs.float16(65519.999999999997).dumps() + + +def test_packed_float_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + float16 a; + float b; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"69\xb1U$G" + obj = cs.test(buf) + + assert obj.a == 0.6513671875 + assert obj.b == 42069.69140625 + + +def test_packed_float_struct_be(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + float16 a; + float b; + }; + """ + cs.load(cdef, compiled=compiled) + cs.endian = ">" + + assert verify_compiled(cs.test, compiled) + + buf = b"69G$U\xb1" + obj = cs.test(buf) + + assert obj.a == 0.388916015625 + assert obj.b == 42069.69140625 diff --git a/tests/test_types_pointer.py b/tests/test_types_pointer.py new file mode 100644 index 0000000..d5940d4 --- /dev/null +++ b/tests/test_types_pointer.py @@ -0,0 +1,237 @@ +from unittest.mock import patch + +import pytest + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.exceptions import NullPointerDereference +from dissect.cstruct.types.pointer import Pointer + +from .utils import verify_compiled + + +def test_pointer(cs: cstruct) -> None: + cs.pointer = cs.uint8 + + ptr = cs._make_pointer(cs.uint8) + assert ptr.__name__ == "uint8*" + + obj = ptr(b"\x01\xFF") + assert repr(obj) == "" + + assert obj == 1 + assert obj.dumps() == b"\x01" + assert obj.dereference() == 255 + assert str(obj) == "255" + + with pytest.raises(NullPointerDereference): + ptr(0, None).dereference() + + +def test_pointer_char(cs: cstruct) -> None: + cs.pointer = cs.uint8 + + ptr = cs._make_pointer(cs.char) + assert ptr.__name__ == "char*" + + obj = ptr(b"\x02\x00asdf\x00") + assert repr(obj) == "" + + assert obj == 2 + assert obj.dereference() == b"asdf" + assert str(obj) == "b'asdf'" + + +def test_pointer_operator(cs: cstruct) -> None: + cs.pointer = cs.uint8 + + ptr = cs._make_pointer(cs.uint8) + obj = ptr(b"\x01\x00\xFF") + + assert obj == 1 + assert obj.dumps() == b"\x01" + assert obj.dereference() == 0 + + obj += 1 + assert obj == 2 + assert obj.dumps() == b"\x02" + assert obj.dereference() == 255 + + obj -= 2 + assert obj == 0 + + obj += 4 + assert obj == 4 + + obj -= 2 + assert obj == 2 + + obj *= 12 + assert obj == 24 + + obj //= 2 + assert obj == 12 + + obj %= 10 + assert obj == 2 + + obj **= 4 + assert obj == 16 + + obj <<= 1 + assert obj == 32 + + obj >>= 2 + assert obj == 8 + + obj &= 2 + assert obj == 0 + + obj ^= 4 + assert obj == 4 + + obj |= 8 + assert obj == 12 + + +def test_pointer_eof(cs: cstruct) -> None: + cs.pointer = cs.uint8 + + ptr = cs._make_pointer(cs.uint8) + obj = ptr(b"\x01") + + with pytest.raises(EOFError): + obj.dereference() + + +def test_pointer_struct(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct ptrtest { + uint32 *ptr1; + uint32 *ptr2; + }; + """ + cs.pointer = cs.uint16 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.ptrtest, compiled) + assert cs.pointer is cs.uint16 + + buf = b"\x04\x00\x08\x00\x01\x02\x03\x04\x05\x06\x07\x08" + obj = cs.ptrtest(buf) + + assert repr(obj) == " ptr2=>" + + assert obj.ptr1 != 0 + assert obj.ptr2 != 0 + assert obj.ptr1 != obj.ptr2 + assert obj.ptr1 == 4 + assert obj.ptr2 == 8 + assert obj.ptr1.dereference() == 0x04030201 + assert obj.ptr2.dereference() == 0x08070605 + + obj.ptr1 += 2 + obj.ptr2 -= 2 + assert obj.ptr1 == obj.ptr2 + assert obj.ptr1.dereference() == obj.ptr2.dereference() == 0x06050403 + + assert obj.dumps() == b"\x06\x00\x06\x00" + + with pytest.raises(NullPointerDereference): + cs.ptrtest(b"\x00\x00\x00\x00").ptr1.dereference() + + +def test_pointer_struct_pointer(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + wchar wmagic[4]; + uint8 a; + uint16 b; + uint32 c; + char string[]; + wchar wstring[]; + }; + + struct ptrtest { + test *ptr; + }; + """ + cs.pointer = cs.uint16 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + assert verify_compiled(cs.ptrtest, compiled) + assert cs.pointer is cs.uint16 + + buf = b"\x02\x00testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" + obj = cs.ptrtest(buf) + + assert obj.ptr != 0 + + assert obj.ptr.magic == b"test" + assert obj.ptr.wmagic == "test" + assert obj.ptr.a == 0x01 + assert obj.ptr.b == 0x0302 + assert obj.ptr.c == 0x07060504 + assert obj.ptr.string == b"lalala" + assert obj.ptr.wstring == "test" + + assert obj.dumps() == b"\x02\x00" + + with pytest.raises(NullPointerDereference): + cs.ptrtest(b"\x00\x00\x00\x00").ptr.magic + + +def test_pointer_array(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct mainargs { + uint8_t argc; + char *args[4]; + } + """ + cs.pointer = cs.uint16 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.mainargs, compiled) + assert cs.pointer is cs.uint16 + + buf = b"\x02\x09\x00\x16\x00\x00\x00\x00\x00argument one\x00argument two\x00" + obj = cs.mainargs(buf) + + assert obj.argc == 2 + assert obj.args[2] == 0 + assert obj.args[3] == 0 + assert obj.args[0].dereference() == b"argument one" + assert obj.args[1].dereference() == b"argument two" + + +def test_pointer_sys_size() -> None: + with patch("sys.maxsize", 2**64): + cs = cstruct() + assert cs.pointer is cs.uint64 + + with patch("sys.maxsize", 2**32): + cs = cstruct() + assert cs.pointer is cs.uint32 + + cs = cstruct(pointer="uint16") + assert cs.pointer is cs.uint16 + + +def test_pointer_of_pointer(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + uint32 **ptr; + }; + """ + cs.pointer = cs.uint8 + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + obj = cs.test(b"\x01\x02AAAA") + assert isinstance(obj.ptr, Pointer) + assert isinstance(obj.ptr.dereference(), Pointer) + assert obj.ptr == 1 + assert obj.ptr.dereference() == 2 + assert obj.ptr.dereference().dereference() == 0x41414141 diff --git a/tests/test_types_structure.py b/tests/test_types_structure.py new file mode 100644 index 0000000..d592d31 --- /dev/null +++ b/tests/test_types_structure.py @@ -0,0 +1,531 @@ +import inspect +from io import BytesIO +from types import MethodType +from unittest.mock import MagicMock, call, patch + +import pytest + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.exceptions import ParserError +from dissect.cstruct.types.base import Array, BaseType +from dissect.cstruct.types.pointer import Pointer +from dissect.cstruct.types.structure import Field, Structure + +from .utils import verify_compiled + + +@pytest.fixture +def TestStruct(cs: cstruct) -> type[Structure]: + return cs._make_struct( + "TestStruct", + [Field("a", cs.uint32), Field("b", cs.uint32)], + ) + + +def test_structure(TestStruct: type[Structure]) -> None: + assert issubclass(TestStruct, Structure) + assert len(TestStruct.fields) == 2 + assert TestStruct.fields["a"].name == "a" + assert TestStruct.fields["b"].name == "b" + + assert TestStruct.size == 8 + assert TestStruct.alignment == 4 + + spec = inspect.getfullargspec(TestStruct.__init__) + assert spec.args == ["self", "a", "b"] + assert spec.defaults == (None, None) + + obj = TestStruct(1, 2) + assert isinstance(obj, TestStruct) + assert obj.a == 1 + assert obj.b == 2 + assert len(obj) == 8 + + obj = TestStruct(a=1) + assert obj.a == 1 + assert obj.b is None + assert len(obj) == 8 + + +def test_structure_read(TestStruct: type[Structure]) -> None: + obj = TestStruct(b"\x01\x00\x00\x00\x02\x00\x00\x00") + + assert isinstance(obj, TestStruct) + assert obj.a == 1 + assert obj.b == 2 + + +def test_structure_write(TestStruct: type[Structure]) -> None: + buf = b"\x01\x00\x00\x00\x02\x00\x00\x00" + obj = TestStruct(buf) + + assert obj.dumps() == buf + + obj = TestStruct(a=1, b=2) + assert obj.dumps() == buf + assert bytes(obj) == buf + + obj = TestStruct(a=1) + assert obj.dumps() == b"\x01\x00\x00\x00\x00\x00\x00\x00" + + obj = TestStruct() + assert obj.dumps() == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + +def test_structure_array_read(TestStruct: type[Structure]) -> None: + TestStructArray = TestStruct[2] + + assert issubclass(TestStructArray, Array) + assert TestStructArray.num_entries == 2 + assert TestStructArray.type == TestStruct + + buf = b"\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00" + obj = TestStructArray(buf) + + assert isinstance(obj, TestStructArray) + assert len(obj) == 2 + assert obj[0].a == 1 + assert obj[0].b == 2 + assert obj[1].a == 3 + assert obj[1].b == 4 + + assert obj.dumps() == buf + assert obj == [TestStruct(1, 2), TestStruct(3, 4)] + + obj = TestStruct[None](b"\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") + assert obj == [TestStruct(1, 2)] + + +def test_structure_array_write(TestStruct: type[Structure]) -> None: + TestStructArray = TestStruct[2] + + obj = TestStructArray([TestStruct(1, 2), TestStruct(3, 4)]) + + assert len(obj) == 2 + assert obj.dumps() == b"\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00" + + obj = TestStruct[None]([TestStruct(1, 2)]) + assert obj.dumps() == b"\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + +def test_structure_modify(cs: cstruct) -> None: + TestStruct = cs._make_struct("Test", [Field("a", cs.char)]) + + assert len(TestStruct.fields) == len(TestStruct.lookup) == 1 + assert len(TestStruct) == 1 + spec = inspect.getfullargspec(TestStruct.__init__) + assert spec.args == ["self", "a"] + assert spec.defaults == (None,) + + TestStruct.add_field("b", cs.char) + + assert len(TestStruct.fields) == len(TestStruct.lookup) == 2 + assert len(TestStruct) == 2 + spec = inspect.getfullargspec(TestStruct.__init__) + assert spec.args == ["self", "a", "b"] + assert spec.defaults == (None, None) + + with TestStruct.start_update(): + TestStruct.add_field("c", cs.char) + TestStruct.add_field("d", cs.char) + + assert len(TestStruct.fields) == len(TestStruct.lookup) == 4 + assert len(TestStruct) == 4 + spec = inspect.getfullargspec(TestStruct.__init__) + assert spec.args == ["self", "a", "b", "c", "d"] + assert spec.defaults == (None, None, None, None) + + obj = TestStruct(b"abcd") + assert obj.a == b"a" + assert obj.b == b"b" + assert obj.c == b"c" + assert obj.d == b"d" + + +def test_structure_single_byte_field(cs: cstruct) -> None: + TestStruct = cs._make_struct("TestStruct", [Field("a", cs.char)]) + + obj = TestStruct(b"aaaa") + assert obj.a == b"a" + + cs.char._read = MagicMock() + + obj = TestStruct(b"a") + assert obj.a == b"a" + cs.char._read.assert_not_called() + + +def test_structure_same_name_method(cs: cstruct) -> None: + TestStruct = cs._make_struct("TestStruct", [Field("add_field", cs.char)]) + + assert isinstance(TestStruct.add_field, MethodType) + + obj = TestStruct(b"a") + assert obj.add_field == b"a" + + +def test_structure_bool(TestStruct: type[Structure]) -> None: + assert bool(TestStruct(1, 2)) is True + assert bool(TestStruct()) is False + assert bool(TestStruct(0, 0)) is False + + +def test_structure_cmp(TestStruct: type[Structure]) -> None: + assert TestStruct(1, 2) == TestStruct(1, 2) + assert TestStruct(1, 2) != TestStruct(2, 3) + + +def test_structure_repr(TestStruct: type[Structure]) -> None: + obj = TestStruct(1, 2) + assert repr(obj) == f"<{TestStruct.__name__} a=0x1 b=0x2>" + + +def test_structure_eof(TestStruct: type[Structure]) -> None: + with pytest.raises(EOFError): + TestStruct(b"") + + with pytest.raises(EOFError): + TestStruct[2](b"\x01\x00\x00\x00\x02\x00\x00\x00") + + with pytest.raises(EOFError): + TestStruct[None](b"\x01\x00\x00\x00\x02\x00\x00\x00") + + +def test_structure_definitions(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct _test { + uint32 a; + // uint32 comment + uint32 b; + } test, test1; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + assert cs._test == cs.test == cs.test1 + assert cs.test.__name__ == "_test" + assert cs._test.__name__ == "_test" + + assert "a" in cs.test.fields + assert "b" in cs.test.fields + + with pytest.raises(ParserError): + cdef = """ + struct { + uint32 a; + }; + """ + cs.load(cdef) + + +def test_structure_definition_simple(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + wchar wmagic[4]; + uint8 a; + uint16 b; + uint32 c; + char string[]; + wchar wstring[]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" + obj = cs.test(buf) + + assert obj.magic == b"test" + assert obj["magic"] == obj.magic + assert obj.wmagic == "test" + assert obj.a == 0x01 + assert obj.b == 0x0302 + assert obj.c == 0x07060504 + assert obj.string == b"lalala" + assert obj.wstring == "test" + + with pytest.raises(AttributeError): + obj.nope + + assert obj._sizes["magic"] == 4 + assert len(obj) == len(buf) + assert obj.dumps() == buf + + assert repr(obj) + + fh = BytesIO() + obj.write(fh) + assert fh.getvalue() == buf + + +def test_structure_definition_simple_be(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + wchar wmagic[4]; + uint8 a; + uint16 b; + uint32 c; + char string[]; + wchar wstring[]; + }; + """ + cs.endian = ">" + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"test\x00t\x00e\x00s\x00t\x01\x02\x03\x04\x05\x06\x07lalala\x00\x00t\x00e\x00s\x00t\x00\x00" + obj = cs.test(buf) + + assert obj.magic == b"test" + assert obj.wmagic == "test" + assert obj.a == 0x01 + assert obj.b == 0x0203 + assert obj.c == 0x04050607 + assert obj.string == b"lalala" + assert obj.wstring == "test" + assert obj.dumps() == buf + + for name in obj.fields.keys(): + assert isinstance(getattr(obj, name), BaseType) + + +def test_structure_definition_expressions(cs: cstruct, compiled: bool) -> None: + cdef = """ + #define const 1 + struct test { + uint8 flag; + uint8 data_1[(flag & 1) * 4]; + uint8 data_2[flag & (1 << 2)]; + uint8 data_3[const]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + obj = cs.test(b"\x01\x00\x01\x02\x03\xff") + assert obj.flag == 1 + assert obj.data_1 == [0, 1, 2, 3] + assert obj.data_2 == [] + assert obj.data_3 == [255] + + obj = cs.test(b"\x04\x04\x05\x06\x07\xff") + assert obj.flag == 4 + assert obj.data_1 == [] + assert obj.data_2 == [4, 5, 6, 7] + assert obj.data_3 == [255] + + +def test_structure_definition_sizes(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct static { + uint32 test; + }; + + struct dynamic { + uint32 test[]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.static, compiled) + assert verify_compiled(cs.dynamic, compiled) + + assert len(cs.static) == 4 + + cs.static.add_field("another", cs.uint32) + assert len(cs.static) == 8 + cs.static.add_field("atoffset", cs.uint32, offset=12) + assert len(cs.static) == 16 + + obj = cs.static(b"\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00") + assert obj.test == 1 + assert obj.another == 2 + assert obj.atoffset == 3 + + with pytest.raises(TypeError) as excinfo: + len(cs.dynamic) + assert str(excinfo.value) == "Dynamic size" + + +def test_structure_definition_nested(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test_named { + char magic[4]; + struct { + uint32 a; + uint32 b; + } a; + struct { + char c[8]; + } b; + }; + + struct test_anonymous { + char magic[4]; + struct { + uint32 a; + uint32 b; + }; + struct { + char c[8]; + }; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test_named, compiled) + assert verify_compiled(cs.test_anonymous, compiled) + + assert len(cs.test_named) == len(cs.test_anonymous) == 20 + + data = b"zomg\x39\x05\x00\x00\x28\x23\x00\x00deadbeef" + obj = cs.test_named(data) + assert obj.magic == b"zomg" + assert obj.a.a == 1337 + assert obj.a.b == 9000 + assert obj.b.c == b"deadbeef" + assert obj.dumps() == data + + obj = cs.test_anonymous(data) + assert obj.magic == b"zomg" + assert obj.a == 1337 + assert obj.b == 9000 + assert obj.c == b"deadbeef" + assert obj.dumps() == data + + +def test_structure_definition_write(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + wchar wmagic[4]; + uint8 a; + uint16 b; + uint32 c; + char string[]; + wchar wstring[]; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"testt\x00e\x00s\x00t\x00\x01\x02\x03\x04\x05\x06\x07lalala\x00t\x00e\x00s\x00t\x00\x00\x00" + + obj = cs.test() + obj.magic = "test" + obj.wmagic = "test" + obj.a = 0x01 + obj.b = 0x0302 + obj.c = 0x07060504 + obj.string = b"lalala" + obj.wstring = "test" + + with pytest.raises(AttributeError): + obj.nope + + assert obj.dumps() == buf + + inst = cs.test( + magic=b"test", + wmagic="test", + a=0x01, + b=0x0302, + c=0x07060504, + string=b"lalala", + wstring="test", + ) + assert inst.dumps() == buf + + +def test_structure_definition_write_be(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + wchar wmagic[4]; + uint8 a; + uint16 b; + uint32 c; + char string[]; + wchar wstring[]; + }; + """ + cs.endian = ">" + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"test\x00t\x00e\x00s\x00t\x01\x02\x03\x04\x05\x06\x07lalala\x00\x00t\x00e\x00s\x00t\x00\x00" + + obj = cs.test() + obj.magic = "test" + obj.wmagic = "test" + obj.a = 0x01 + obj.b = 0x0203 + obj.c = 0x04050607 + obj.string = b"lalala" + obj.wstring = "test" + + assert obj.dumps() == buf + + +def test_structure_definition_write_anonymous(cs: cstruct) -> None: + cdef = """ + struct test { + uint32 a; + union { + struct { + uint16 b1; + uint16 b2; + }; + uint32 b; + }; + uint32 c; + }; + """ + cs.load(cdef) + + obj = cs.test(a=1, c=3) + assert obj.dumps() == b"\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00" + + +def test_structure_field_discard(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + uint8 a; + uint8 _; + uint16 b; + uint16 _; + uint16 c; + char d; + char _; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + with patch.object(cs.char, "__new__") as mock_char_new: + cs.test(b"\x01\x02\x03\x00\x04\x00\x05\x00ab") + + assert len(mock_char_new.mock_calls) == 2 + mock_char_new.assert_has_calls([call(cs.char, b"a"), call(cs.char, b"b")]) + + +def test_structure_definition_self(cs: cstruct) -> None: + cdef = """ + struct test { + uint32 a; + struct test * b; + }; + """ + cs.load(cdef) + + assert issubclass(cs.test.fields["b"].type, Pointer) + assert cs.test.fields["b"].type.type is cs.test diff --git a/tests/test_types_union.py b/tests/test_types_union.py new file mode 100644 index 0000000..51a40f8 --- /dev/null +++ b/tests/test_types_union.py @@ -0,0 +1,338 @@ +import inspect + +import pytest + +from dissect.cstruct.cstruct import cstruct +from dissect.cstruct.types.base import Array +from dissect.cstruct.types.structure import Field, Union + +from .utils import verify_compiled + + +@pytest.fixture +def TestUnion(cs: cstruct) -> type[Union]: + return cs._make_union( + "TestUnion", + [Field("a", cs.uint32), Field("b", cs.uint16)], + ) + + +def test_union(TestUnion: type[Union]) -> None: + assert issubclass(TestUnion, Union) + assert len(TestUnion.fields) == 2 + assert TestUnion.fields["a"].name == "a" + assert TestUnion.fields["b"].name == "b" + + assert TestUnion.size == 4 + assert TestUnion.alignment == 4 + + spec = inspect.getfullargspec(TestUnion.__init__) + assert spec.args == ["self", "a", "b"] + assert spec.defaults == (None, None) + + obj = TestUnion(1, 2) + assert isinstance(obj, TestUnion) + assert obj.a == 1 + assert obj.b == 2 + assert len(obj) == 4 + + obj = TestUnion(a=1) + assert obj.a == 1 + assert obj.b is None + assert len(obj) == 4 + + +def test_union_read(TestUnion: type[Union]) -> None: + obj = TestUnion(b"\x01\x00\x00\x00") + + assert isinstance(obj, TestUnion) + assert obj.a == 1 + assert obj.b == 1 + + +def test_union_write(TestUnion: type[Union]) -> None: + buf = b"\x01\x00\x00\x00" + obj = TestUnion(buf) + + assert obj.dumps() == buf + + obj = TestUnion(a=1, b=2) + assert obj.dumps() == buf + assert bytes(obj) == buf + + obj = TestUnion(b=1) + assert obj.dumps() == b"\x01\x00\x00\x00" + + obj = TestUnion() + assert obj.dumps() == b"\x00\x00\x00\x00" + + +def test_union_array_read(TestUnion: type[Union]) -> None: + TestUnionArray = TestUnion[2] + + assert issubclass(TestUnionArray, Array) + assert TestUnionArray.num_entries == 2 + assert TestUnionArray.type == TestUnion + + buf = b"\x01\x00\x00\x00\x02\x00\x00\x00" + obj = TestUnionArray(buf) + + assert isinstance(obj, TestUnionArray) + assert len(obj) == 2 + assert obj[0].a == 1 + assert obj[0].b == 1 + assert obj[1].a == 2 + assert obj[1].b == 2 + + assert obj.dumps() == buf + assert obj == [TestUnion(1), TestUnion(2)] + + obj = TestUnion[None](b"\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00") + assert obj == [TestUnion(1), TestUnion(2)] + + +def test_union_array_write(TestUnion: type[Union]) -> None: + TestUnionArray = TestUnion[2] + + obj = TestUnionArray([TestUnion(1), TestUnion(2)]) + + assert len(obj) == 2 + assert obj.dumps() == b"\x01\x00\x00\x00\x02\x00\x00\x00" + + obj = TestUnion[None]([TestUnion(1)]) + assert obj.dumps() == b"\x01\x00\x00\x00\x00\x00\x00\x00" + + +def test_union_modify(cs: cstruct) -> None: + TestUnion = cs._make_union("Test", [Field("a", cs.char)]) + + assert len(TestUnion.fields) == len(TestUnion.lookup) == 1 + assert len(TestUnion) == 1 + spec = inspect.getfullargspec(TestUnion.__init__) + assert spec.args == ["self", "a"] + assert spec.defaults == (None,) + + TestUnion.add_field("b", cs.uint32) + + assert len(TestUnion.fields) == len(TestUnion.lookup) == 2 + assert len(TestUnion) == 4 + spec = inspect.getfullargspec(TestUnion.__init__) + assert spec.args == ["self", "a", "b"] + assert spec.defaults == (None, None) + + with TestUnion.start_update(): + TestUnion.add_field("c", cs.uint16) + TestUnion.add_field("d", cs.uint8) + + assert len(TestUnion.fields) == len(TestUnion.lookup) == 4 + assert len(TestUnion) == 4 + spec = inspect.getfullargspec(TestUnion.__init__) + assert spec.args == ["self", "a", "b", "c", "d"] + assert spec.defaults == (None, None, None, None) + + obj = TestUnion(b"\x01\x02\x03\x04") + assert obj.a == b"\x01" + assert obj.b == 0x04030201 + assert obj.c == 0x0201 + assert obj.d == 0x01 + + +def test_union_bool(TestUnion: type[Union]) -> None: + assert bool(TestUnion(1, 2)) is True + assert bool(TestUnion(1, 1)) is True + assert bool(TestUnion()) is False + assert bool(TestUnion(0, 0)) is False + + +def test_union_cmp(TestUnion: type[Union]) -> None: + assert TestUnion(1) == TestUnion(1) + assert TestUnion(1, 2) == TestUnion(1, 2) + assert TestUnion(1, 2) != TestUnion(2, 3) + assert TestUnion(b=2) == TestUnion(a=2) + + +def test_union_repr(TestUnion: type[Union]) -> None: + obj = TestUnion(1, 2) + assert repr(obj) == f"<{TestUnion.__name__} a=0x1 b=0x2>" + + +def test_union_eof(TestUnion: type[Union]) -> None: + with pytest.raises(EOFError): + TestUnion(b"") + + with pytest.raises(EOFError): + TestUnion[2](b"\x01\x00\x00\x00") + + with pytest.raises(EOFError): + TestUnion[None](b"\x01\x00\x00\x00\x02\x00\x00\x00") + + +def test_union_definition(cs: cstruct) -> None: + cdef = """ + union test { + uint32 a; + char b[8]; + }; + """ + cs.load(cdef, compiled=False) + + assert len(cs.test) == 8 + + buf = b"zomgbeef" + obj = cs.test(buf) + + assert obj.a == 0x676D6F7A + assert obj.b == b"zomgbeef" + + assert obj.dumps() == buf + assert cs.test().dumps() == b"\x00\x00\x00\x00\x00\x00\x00\x00" + + +def test_union_definition_nested(cs: cstruct, compiled: bool) -> None: + cdef = """ + struct test { + char magic[4]; + union { + struct { + uint32 a; + uint32 b; + } a; + struct { + char b[8]; + } b; + } c; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + assert len(cs.test) == 12 + + buf = b"zomgholybeef" + obj = cs.test(buf) + + assert obj.magic == b"zomg" + assert obj.c.a.a == 0x796C6F68 + assert obj.c.a.b == 0x66656562 + assert obj.c.b.b == b"holybeef" + + assert obj.dumps() == buf + + +def test_union_definition_anonymous(cs: cstruct, compiled: bool) -> None: + cdef = """ + typedef struct test + { + union + { + uint32 a; + struct + { + char b[3]; + char c; + }; + }; + uint32 d; + } + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"\x01\x01\x02\x02\x03\x03\x04\x04" + obj = cs.test(buf) + + assert obj.a == 0x02020101 + assert obj.b == b"\x01\x01\x02" + assert obj.c == b"\x02" + assert obj.d == 0x04040303 + assert obj.dumps() == buf + + +def test_union_definition_dynamic(cs: cstruct) -> None: + cdef = """ + struct dynamic { + uint8 size; + char data[size]; + }; + + union test { + dynamic a; + uint64 b; + }; + """ + cs.load(cdef, compiled=False) + + buf = b"\x09aaaaaaaaa" + obj = cs.test(buf) + + assert obj.a.size == 9 + assert obj.a.data == b"aaaaaaaaa" + assert obj.b == 0x6161616161616109 + + +def test_union_update(cs: cstruct) -> None: + cdef = """ + union test { + uint8 a; + uint16 b; + }; + """ + cs.load(cdef) + + obj = cs.test() + obj.a = 1 + assert obj.b == 1 + obj.b = 2 + assert obj.a == 2 + obj.b = 0xFFFF + assert obj.a == 0xFF + assert obj.dumps() == b"\xFF\xFF" + + +def test_union_nested_update(cs: cstruct) -> None: + cdef = """ + struct test { + char magic[4]; + union { + struct { + uint32 a; + uint32 b; + } a; + struct { + char b[8]; + } b; + } c; + }; + """ + cs.load(cdef) + + obj = cs.test() + obj.magic = b"1337" + obj.c.b.b = b"ABCDEFGH" + assert obj.c.a.a == 0x44434241 + assert obj.c.a.b == 0x48474645 + assert obj.dumps() == b"1337ABCDEFGH" + + +def test_union_anonymous_update(cs: cstruct) -> None: + cdef = """ + typedef struct test + { + union { + uint32 a; + struct + { + char b[3]; + char c; + }; + }; + uint32 d; + } + """ + cs.load(cdef) + + obj = cs.test() + obj.a = 0x41414141 + assert obj.b == b"AAA" diff --git a/tests/test_types_void.py b/tests/test_types_void.py new file mode 100644 index 0000000..6393f9b --- /dev/null +++ b/tests/test_types_void.py @@ -0,0 +1,13 @@ +import io + +from dissect.cstruct.cstruct import cstruct + + +def test_void(cs: cstruct) -> None: + assert not cs.void + + stream = io.BytesIO(b"AAAA") + assert not cs.void(stream) + + assert stream.tell() == 0 + assert cs.void().dumps() == b"" diff --git a/tests/test_types_wchar.py b/tests/test_types_wchar.py new file mode 100644 index 0000000..a8086d6 --- /dev/null +++ b/tests/test_types_wchar.py @@ -0,0 +1,77 @@ +import io + +import pytest + +from dissect.cstruct.cstruct import cstruct + + +def test_wchar_read(cs: cstruct) -> None: + buf = b"A\x00A\x00A\x00A\x00\x00\x00" + + assert cs.wchar("A") == "A" + assert cs.wchar(buf) == "A" + assert cs.wchar(io.BytesIO(buf)) == "A" + + +def test_wchar_write(cs: cstruct) -> None: + assert cs.wchar("A").dumps() == b"A\x00" + assert cs.wchar(b"A\x00").dumps() == b"A\x00" + + +def test_wchar_array(cs: cstruct) -> None: + buf = b"A\x00A\x00A\x00A\x00\x00\x00" + + assert cs.wchar[4]("AAAA") == "AAAA" + assert cs.wchar[4](buf) == "AAAA" + assert cs.wchar[4](io.BytesIO(buf)) == "AAAA" + assert cs.wchar[None](io.BytesIO(buf)) == "AAAA" + + +def test_wchar_array_write(cs: cstruct) -> None: + buf = b"A\x00A\x00A\x00A\x00\x00\x00" + + assert cs.wchar[4](buf).dumps() == b"A\x00A\x00A\x00A\x00" + assert cs.wchar[None](buf).dumps() == b"A\x00A\x00A\x00A\x00\x00\x00" + + +def test_wchar_be_read(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.wchar(b"\x00A\x00A\x00A\x00A\x00\x00") == "A" + + +def test_wchar_be_write(cs: cstruct) -> None: + cs.endian = ">" + + assert cs.wchar("A").dumps() == b"\x00A" + + +def test_wchar_be_array(cs: cstruct) -> None: + cs.endian = ">" + + buf = b"\x00A\x00A\x00A\x00A\x00\x00" + + assert cs.wchar[4](buf) == "AAAA" + assert cs.wchar[None](buf) == "AAAA" + + +def test_wchar_be_array_write(cs: cstruct) -> None: + cs.endian = ">" + + buf = b"\x00A\x00A\x00A\x00A\x00\x00" + + assert cs.wchar[4](buf).dumps() == b"\x00A\x00A\x00A\x00A" + assert cs.wchar[None](buf).dumps() == buf + + +def test_wchar_eof(cs: cstruct) -> None: + with pytest.raises(EOFError): + cs.wchar(b"A") + + with pytest.raises(EOFError): + cs.wchar[4](b"") + + with pytest.raises(EOFError): + cs.wchar[None](b"A\x00A\x00A\x00A\x00") + + assert cs.wchar[0](b"") == "" diff --git a/tests/test_union.py b/tests/test_union.py deleted file mode 100644 index fc423f8..0000000 --- a/tests/test_union.py +++ /dev/null @@ -1,89 +0,0 @@ -from dissect import cstruct - -from .utils import verify_compiled - - -def test_union(): - cdef = """ - union test { - uint32 a; - char b[8]; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=False) - - assert len(cs.test) == 8 - - buf = b"zomgbeef" - obj = cs.test(buf) - - assert obj.a == 0x676D6F7A - assert obj.b == b"zomgbeef" - - assert obj.dumps() == buf - assert cs.test().dumps() == b"\x00\x00\x00\x00\x00\x00\x00\x00" - - -def test_union_nested(compiled): - cdef = """ - struct test { - char magic[4]; - union { - struct { - uint32 a; - uint32 b; - } a; - struct { - char b[8]; - } b; - } c; - }; - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - assert len(cs.test) == 12 - - buf = b"zomgholybeef" - obj = cs.test(buf) - - assert obj.magic == b"zomg" - assert obj.c.a.a == 0x796C6F68 - assert obj.c.a.b == 0x66656562 - assert obj.c.b.b == b"holybeef" - - assert obj.dumps() == buf - - -def test_union_anonymous(compiled): - cdef = """ - typedef struct test - { - union - { - uint32 a; - struct - { - char b[3]; - char c; - }; - }; - uint32 d; - } - """ - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"\x01\x01\x02\x02\x03\x03\x04\x04" - obj = cs.test(buf) - - assert obj.a == 0x02020101 - assert obj.b == b"\x01\x01\x02" - assert obj.c == b"\x02" - assert obj.d == 0x04040303 - assert obj.dumps() == buf diff --git a/tests/test_util.py b/tests/test_util.py deleted file mode 100644 index 2ee3b93..0000000 --- a/tests/test_util.py +++ /dev/null @@ -1,90 +0,0 @@ -import pytest -from dissect import cstruct - -from dissect.cstruct.utils import dumpstruct, hexdump - -from .utils import verify_compiled - - -def test_hexdump(capsys): - hexdump(b"\x00" * 16) - captured = capsys.readouterr() - assert captured.out == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................\n" - - out = hexdump(b"\x00" * 16, output="string") - assert out == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................" - - out = hexdump(b"\x00" * 16, output="generator") - assert next(out) == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................" - - with pytest.raises(ValueError) as excinfo: - hexdump("b\x00", output="str") - assert str(excinfo.value) == "Invalid output argument: 'str' (should be 'print', 'generator' or 'string')." - - -def test_dumpstruct(capsys, compiled): - cdef = """ - struct test { - uint32 testval; - }; - """ - - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"\x39\x05\x00\x00" - obj = cs.test(buf) - - dumpstruct(cs.test, buf) - captured_1 = capsys.readouterr() - - dumpstruct(obj) - captured_2 = capsys.readouterr() - - assert captured_1.out == captured_2.out - - out_1 = dumpstruct(cs.test, buf, output="string") - out_2 = dumpstruct(obj, output="string") - - assert out_1 == out_2 - - with pytest.raises(ValueError) as excinfo: - dumpstruct(obj, output="generator") - assert str(excinfo.value) == "Invalid output argument: 'generator' (should be 'print' or 'string')." - - -def test_dumpstruct_anonymous(capsys, compiled): - cdef = """ - struct test { - struct { - uint32 testval; - }; - }; - """ - - cs = cstruct.cstruct() - cs.load(cdef, compiled=compiled) - - assert verify_compiled(cs.test, compiled) - - buf = b"\x39\x05\x00\x00" - obj = cs.test(buf) - - dumpstruct(cs.test, buf) - captured_1 = capsys.readouterr() - - dumpstruct(obj) - captured_2 = capsys.readouterr() - - assert captured_1.out == captured_2.out - - out_1 = dumpstruct(cs.test, buf, output="string") - out_2 = dumpstruct(obj, output="string") - - assert out_1 == out_2 - - with pytest.raises(ValueError) as excinfo: - dumpstruct(obj, output="generator") - assert str(excinfo.value) == "Invalid output argument: 'generator' (should be 'print' or 'string')." diff --git a/tests/test_utils.py b/tests/test_utils.py index 981c908..c733ab9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,92 @@ +import pytest + from dissect.cstruct import utils +from dissect.cstruct.cstruct import cstruct + +from .utils import verify_compiled + + +def test_hexdump(capsys: pytest.CaptureFixture) -> None: + utils.hexdump(b"\x00" * 16) + captured = capsys.readouterr() + assert captured.out == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................\n" + + out = utils.hexdump(b"\x00" * 16, output="string") + assert out == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................" + + out = utils.hexdump(b"\x00" * 16, output="generator") + assert next(out) == "00000000 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 ................" + + with pytest.raises(ValueError) as excinfo: + utils.hexdump("b\x00", output="str") + assert str(excinfo.value) == "Invalid output argument: 'str' (should be 'print', 'generator' or 'string')." + + +def test_dumpstruct(cs: cstruct, capsys: pytest.CaptureFixture, compiled: bool) -> None: + cdef = """ + struct test { + uint32 testval; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"\x39\x05\x00\x00" + obj = cs.test(buf) + + utils.dumpstruct(cs.test, buf) + captured_1 = capsys.readouterr() + + utils.dumpstruct(obj) + captured_2 = capsys.readouterr() + + assert captured_1.out == captured_2.out + + out_1 = utils.dumpstruct(cs.test, buf, output="string") + out_2 = utils.dumpstruct(obj, output="string") + + assert out_1 == out_2 + + with pytest.raises(ValueError) as excinfo: + utils.dumpstruct(obj, output="generator") + assert str(excinfo.value) == "Invalid output argument: 'generator' (should be 'print' or 'string')." + + +def test_dumpstruct_anonymous(cs: cstruct, capsys: pytest.CaptureFixture, compiled: bool) -> None: + cdef = """ + struct test { + struct { + uint32 testval; + }; + }; + """ + cs.load(cdef, compiled=compiled) + + assert verify_compiled(cs.test, compiled) + + buf = b"\x39\x05\x00\x00" + obj = cs.test(buf) + + utils.dumpstruct(cs.test, buf) + captured_1 = capsys.readouterr() + + utils.dumpstruct(obj) + captured_2 = capsys.readouterr() + + assert captured_1.out == captured_2.out + + out_1 = utils.dumpstruct(cs.test, buf, output="string") + out_2 = utils.dumpstruct(obj, output="string") + + assert out_1 == out_2 + + with pytest.raises(ValueError) as excinfo: + utils.dumpstruct(obj, output="generator") + assert str(excinfo.value) == "Invalid output argument: 'generator' (should be 'print' or 'string')." -def test_pack_unpack(): +def test_pack_unpack() -> None: endian = "little" sign = False assert utils.p8(1, endian) == b"\x01" @@ -64,7 +149,7 @@ def test_pack_unpack(): assert utils.unpack(b"^K\xc0\x0c") == 213928798 -def test_swap(): +def test_swap() -> None: assert utils.swap16(0x0001) == 0x0100 assert utils.swap32(0x00000001) == 0x01000000 assert utils.swap64(0x0000000000000001) == 0x0100000000000000 diff --git a/tests/utils.py b/tests/utils.py index 8c252c6..fb9c20b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,2 +1,7 @@ -def verify_compiled(struct, compiled): - return ("+compiled" in repr(struct)) == compiled +from __future__ import annotations + +from dissect.cstruct import Structure + + +def verify_compiled(struct: type[Structure], compiled: bool) -> bool: + return struct.__compiled__ == compiled