From 48605621fea5979a81b6ea4a67032631a544ef61 Mon Sep 17 00:00:00 2001 From: Schamper <1254028+Schamper@users.noreply.github.com> Date: Mon, 4 Sep 2023 11:24:24 +0200 Subject: [PATCH] Address review comments --- dissect/cstruct/compiler.py | 2 +- dissect/cstruct/types/char.py | 2 +- dissect/cstruct/types/enum.py | 1 + dissect/cstruct/types/structure.py | 4 +--- dissect/cstruct/types/wchar.py | 6 +++--- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/dissect/cstruct/compiler.py b/dissect/cstruct/compiler.py index d57968e..a56dac8 100644 --- a/dissect/cstruct/compiler.py +++ b/dissect/cstruct/compiler.py @@ -407,7 +407,7 @@ def _generate_struct_info(cs: cstruct, fields: list[Field], align: bool = False) current_offset += size -def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]): +def _optimize_struct_fmt(info: Iterator[tuple[Field, int, str]]) -> str: chars = [] current_count = 0 diff --git a/dissect/cstruct/types/char.py b/dissect/cstruct/types/char.py index c873361..ab84cad 100644 --- a/dissect/cstruct/types/char.py +++ b/dissect/cstruct/types/char.py @@ -15,7 +15,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Char: @classmethod def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Char: if count == 0: - return b"" + return type.__call__(cls, b"") data = stream.read(-1 if count == EOF else count) if count != EOF and len(data) != count: diff --git a/dissect/cstruct/types/enum.py b/dissect/cstruct/types/enum.py index 2ae1fc8..e138837 100644 --- a/dissect/cstruct/types/enum.py +++ b/dissect/cstruct/types/enum.py @@ -72,6 +72,7 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: def _fix_alias_members(cls: type[Enum]): # Emulate aenum NoAlias behaviour + # https://github.com/ethanfurman/aenum/blob/master/aenum/doc/aenum.rst if len(cls._member_names_) == len(cls._member_map_): return diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index 877245e..b4d0fb9 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -457,9 +457,7 @@ class Union(Structure, metaclass=UnionMetaType): """Base class for cstruct union type classes.""" def __eq__(self, other: Any) -> bool: - if self.__class__ is other.__class__: - return bytes(self) == bytes(other) - return False + return self.__class__ is other.__class__ and bytes(self) == bytes(other) def _codegen(func: FunctionType) -> FunctionType: diff --git a/dissect/cstruct/types/wchar.py b/dissect/cstruct/types/wchar.py index 7c285d8..a823e8e 100644 --- a/dissect/cstruct/types/wchar.py +++ b/dissect/cstruct/types/wchar.py @@ -24,7 +24,7 @@ def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: @classmethod def _read_array(cls, stream: BinaryIO, count: int, context: dict[str, Any] = None) -> Wchar: if count == 0: - return "" + return type.__call__(cls, "") if count != EOF: count *= 2 @@ -40,8 +40,8 @@ def _read_0(cls, stream: BinaryIO, context: dict[str, Any] = None) -> Wchar: buf = [] while True: point = stream.read(2) - if len(point) != 2: - raise EOFError("Read 0 bytes, but expected 2") + if (bytes_read := len(point)) != 2: + raise EOFError(f"Read {bytes_read} bytes, but expected 2") if point == b"\x00\x00": break