diff --git a/dissect/cstruct/types/base.py b/dissect/cstruct/types/base.py index e55e138..572348c 100644 --- a/dissect/cstruct/types/base.py +++ b/dissect/cstruct/types/base.py @@ -1,7 +1,8 @@ from __future__ import annotations +import functools from io import BytesIO -from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union +from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional, Union from dissect.cstruct.exceptions import ArraySizeError @@ -83,6 +84,31 @@ def read(cls, obj: Union[BinaryIO, bytes]) -> BaseType: return cls._read(obj) + def write(cls, stream: BinaryIO, value: Any) -> int: + """Write a value to a writable file-like object. + + Args: + stream: File-like objects that supports writing. + value: Value to write. + + Returns: + The amount of bytes written. + """ + return cls._write(stream, value) + + def dumps(cls, value: Any) -> bytes: + """Dump a value to a byte string. + + Args: + value: Value to dump. + + Returns: + The raw bytes of this type. + """ + out = BytesIO() + cls._write(out, value) + return out.getvalue() + def _read(cls, stream: BinaryIO, context: dict[str, Any] = None) -> BaseType: """Internal function for reading value. @@ -152,29 +178,33 @@ def _write_0(cls, stream: BinaryIO, array: list[BaseType]) -> int: return cls._write_array(stream, array + [cls()]) -class BaseType(metaclass=MetaType): - """Base class for cstruct type classes.""" +class _overload: + """Descriptor to use on the ``write`` and ``dumps`` methods on cstruct types. - def dumps(self) -> bytes: - """Dump this value to a byte string. + Allows for calling these methods on both the type and instance. - Returns: - The raw bytes of this type. - """ - out = BytesIO() - self.__class__._write(out, self) - return out.getvalue() + Example: + >>> int32.dumps(123) + b'\\x7b\\x00\\x00\\x00' + >>> int32(123).dumps() + b'\\x7b\\x00\\x00\\x00' + """ - def write(self, stream: BinaryIO) -> int: - """Write this value to a writable file-like object. + def __init__(self, func: Callable[[Any], Any]) -> None: + self.func = func - Args: - fh: File-like objects that supports writing. + def __get__(self, instance: Optional[BaseType], owner: MetaType) -> Callable[[Any], bytes]: + if instance is None: + return functools.partial(self.func, owner) + else: + return functools.partial(self.func, instance.__class__, value=instance) - Returns: - The amount of bytes written. - """ - return self.__class__._write(stream, self) + +class BaseType(metaclass=MetaType): + """Base class for cstruct type classes.""" + + dumps = _overload(MetaType.dumps) + write = _overload(MetaType.write) class ArrayMetaType(MetaType): diff --git a/tests/test_basic.py b/tests/test_basic.py index 598f3c3..b86f98b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,4 +1,5 @@ import os +from io import BytesIO import pytest @@ -520,3 +521,13 @@ def test_dynamic_substruct_size(cs: cstruct): assert cs.sub.dynamic assert cs.test.dynamic + + +def test_dumps_write_overload(cs: cstruct): + assert cs.uint8.dumps(1) == cs.uint8(1).dumps() == b"\x01" + + fh = BytesIO() + cs.uint8.write(fh, 1) + assert fh.getvalue() == b"\x01" + cs.uint8(2).write(fh) + assert fh.getvalue() == b"\x01\x02"