Skip to content
This repository has been archived by the owner on Nov 8, 2024. It is now read-only.

Commit

Permalink
IR and start writing
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Oct 8, 2024
1 parent be07108 commit 90b5b1c
Show file tree
Hide file tree
Showing 8 changed files with 304 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/sqomega/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

del importlib

from ._sqw import SQW
from ._sqw import Sqw
from ._bytes import Byteorder
from ._models import SqwFileType, SqwFileHeader, SqwMainHeader

__all__ = ["Byteorder", "SqwMainHeader", "SQW", "SqwFileType", "SqwFileHeader"]
__all__ = ["Byteorder", "SqwMainHeader", "Sqw", "SqwFileType", "SqwFileHeader"]
28 changes: 8 additions & 20 deletions src/sqomega/_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import enum
import sys
from typing import Literal


Expand All @@ -13,36 +14,23 @@ class Byteorder(enum.Enum):

@classmethod
def parse(
cls, value: Byteorder | Literal["little", "big"] | None = None
cls, value: Byteorder | Literal["native", "little", "big"]
) -> Byteorder | None:
if value is None:
return None
if isinstance(value, Byteorder):
return value
if isinstance(value, str):
if value == "native":
return cls.native()
return cls[value]
raise ValueError(f"Invalid Byteorder: {value}")

@classmethod
def native(cls) -> Byteorder:
    """Return the member whose name equals ``sys.byteorder``."""
    host_order = sys.byteorder
    return cls[host_order]

def get(self) -> Literal["little", "big"]:
    """Return this byteorder as the string ``"little"`` or ``"big"``."""
    if self == Byteorder.little:
        return "little"
    if self == Byteorder.big:
        return "big"
    # Any other value falls through without an explicit return.


class TypeTag(enum.Enum):
    """Enumeration of type tags.

    The values are not contiguous; the numbers that are skipped
    correspond to unsupported types.
    """

    logical = 0
    char = 1
    # 2 is skipped
    f64 = 3
    f32 = 4
    i8 = 5
    u8 = 6
    # 7 and 8 are skipped
    i32 = 9
    u32 = 10
    i64 = 11
    u64 = 12
    # 13 through 22 are skipped
    cell = 23
    struct = 24
    # objects that 'serialize themselves'
    serializable = 32
127 changes: 127 additions & 0 deletions src/sqomega/_ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)

"""Intermediate representation for SQW objects."""

from __future__ import annotations

import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from typing import ClassVar, TypeVar

# Generic type variable; used to annotate methods that return ``self``.
_T = TypeVar("_T")


class TypeTag(enum.Enum):
    """One-byte tags that identify types in SQW files.

    The values are not contiguous; the numbers that are skipped
    correspond to unsupported types.
    """

    logical = 0
    char = 1
    # 2 is skipped
    f64 = 3
    f32 = 4
    i8 = 5
    u8 = 6
    # 7 and 8 are skipped
    i32 = 9
    u32 = 10
    i64 = 11
    u64 = 12
    # 13 through 22 are skipped
    cell = 23
    struct = 24
    # objects that 'serialize themselves'
    serializable = 32


@dataclass(kw_only=True)
class ObjectArray:
ty: TypeTag
shape: tuple[int, ...]
data: list[Object]


@dataclass(kw_only=True)
class CellArray:
    """IR node for a cell array; tagged as ``TypeTag.cell``.

    All fields must be passed by keyword.
    """

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.cell

    # Extent of the array along each dimension.
    shape: tuple[int, ...]
    # nested object array to encode types of each item
    data: list[ObjectArray]


@dataclass(kw_only=True)
class Struct:
    """IR node for a struct; tagged as ``TypeTag.struct``.

    Holds the field names and a cell array with the field values.
    All fields must be passed by keyword.
    """

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.struct

    field_names: tuple[str, ...]
    field_values: CellArray


@dataclass
class String:
    """IR wrapper around a ``str``; tagged as ``TypeTag.char``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.char

    value: str


@dataclass
class F64:
    """IR wrapper around a ``float``; tagged as ``TypeTag.f64``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.f64

    value: float


@dataclass
class U64:
    """IR wrapper around an ``int``; tagged as ``TypeTag.u64``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.u64

    value: int


@dataclass
class U32:
    """IR wrapper around an ``int``; tagged as ``TypeTag.u32``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.u32

    value: int


@dataclass
class U8:
    """IR wrapper around an ``int``; tagged as ``TypeTag.u8``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.u8

    value: int


@dataclass
class Logical:
    """IR wrapper around a ``bool``; tagged as ``TypeTag.logical``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.logical

    value: bool


# SQW has no datetime type. This wrapper exists only to simplify
# serialization; it carries the same tag as a string (``TypeTag.char``).
@dataclass
class Datetime:
    """IR wrapper around a ``datetime``; tagged as ``TypeTag.char``."""

    # Class-level tag; not a dataclass field.
    ty: ClassVar[TypeTag] = TypeTag.char

    value: datetime


# Union of the IR node types that can appear as a field value.
Object = (
    Struct
    | String
    | F64
    | U64
    | U32
    | U8
    | Logical
    | Datetime
)


class Serializable(ABC):
    """Abstract base for objects that can be turned into an IR ``Struct``.

    Subclasses provide their fields through ``_serialize_to_dict``;
    ``serialize_to_ir`` packs those fields into a ``Struct``.
    """

    @abstractmethod
    def _serialize_to_dict(self) -> dict[str, Object]:
        """Return the IR objects to serialize, keyed by field name."""

    def serialize_to_ir(self) -> Struct:
        """Build a ``Struct`` from the result of ``_serialize_to_dict``.

        Each field value is wrapped in its own single-element
        ``ObjectArray`` that carries the tag of the field.
        """
        fields = self._serialize_to_dict()
        names = tuple(fields)

        cells = []
        for item in fields.values():
            # The tag is read from the field as given, before conversion.
            wrapped = ObjectArray(
                ty=item.ty, shape=(1,), data=[_serialize_field(item)]
            )
            cells.append(wrapped)

        # HORACE uses a 2D array
        values = CellArray(shape=(len(fields), 1), data=cells)
        return Struct(field_names=names, field_values=values)

    def prepare_for_serialization(self: _T) -> _T:
        """Return the object to serialize; this implementation returns ``self``."""
        return self


def _serialize_field(field: Object) -> Object:
    """Convert ``field`` into the object that gets serialized.

    A ``Datetime`` becomes a ``String`` holding its ISO 8601 text with
    second precision. Every other object is returned unchanged.
    """
    if not isinstance(field, Datetime):
        return field
    iso_text = field.value.isoformat(timespec='seconds')
    return String(value=iso_text)
66 changes: 66 additions & 0 deletions src/sqomega/_low_level_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,36 @@ def _add_note_to_read_exception(exc: Exception, sqw_io: LowLevelSqw, ty: str) ->
)


def _annotate_write_exception(
    ty: str,
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Add a note with file-information to exceptions from write_* functions.

    Parameters
    ----------
    ty:
        Name of the type being written. It is included in the note.

    Returns
    -------
    :
        A decorator for functions whose first positional argument is the
        ``LowLevelSqw`` being written to. If the decorated function raises
        ``ValueError``, ``UnicodeEncodeError``, or ``OverflowError``,
        a note is attached with ``_add_note_to_write_exception``
        and the same exception is re-raised.
    """

    def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]:
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            try:
                return func(*args, **kwargs)
            except (ValueError, UnicodeEncodeError, OverflowError) as exc:
                # The decorated functions are methods, so args[0] is the
                # object that owns the file.
                sqw_io: LowLevelSqw = args[0]  # type: ignore[assignment]
                # Use the write-specific note so that the message says
                # "When writing ..." and not "When reading ...".
                _add_note_to_write_exception(exc, sqw_io, ty)
                raise

        return wrapper

    return decorator


def _add_note_to_write_exception(exc: Exception, sqw_io: LowLevelSqw, ty: str) -> None:
    """Attach a note to ``exc`` describing a failed write.

    The note names the type ``ty``, the file (its path, or a generic
    description when ``sqw_io.path`` is None), and ``sqw_io.position``.
    """
    if sqw_io.path is None:
        target = "in-memory SQW file"
    else:
        target = f"SQW file '{sqw_io.path}'"
    note = f"When writing a {ty} to {target} at position {sqw_io.position}"
    _add_note(exc, note)


def _add_note(exc: Exception, note: str) -> None:
try:
exc.add_note(note) # type: ignore[attr-defined]
Expand Down Expand Up @@ -100,6 +130,42 @@ def read_char_array(self) -> str:
def read_n_chars(self, n: int) -> str:
    """Call ``read(n)`` on the underlying file and decode the result as UTF-8."""
    raw = self._file.read(n)
    return raw.decode('utf-8')

@_annotate_write_exception("logical")
def write_logical(self, value: bool) -> None:
    """Write ``value`` to the file as a single byte."""
    order = self._byteorder.get()
    encoded = value.to_bytes(1, order)
    self._file.write(encoded)

@_annotate_write_exception("u8")
def write_u8(self, value: int) -> None:
    """Write ``value`` to the file as a 1-byte unsigned integer."""
    order = self._byteorder.get()
    encoded = value.to_bytes(1, order)
    self._file.write(encoded)

@_annotate_write_exception("u32")
def write_u32(self, value: int) -> None:
    """Write ``value`` to the file as a 4-byte unsigned integer."""
    order = self._byteorder.get()
    encoded = value.to_bytes(4, order)
    self._file.write(encoded)

@_annotate_write_exception("u64")
def write_u64(self, value: int) -> None:
    """Write ``value`` to the file as an 8-byte unsigned integer."""
    order = self._byteorder.get()
    encoded = value.to_bytes(8, order)
    self._file.write(encoded)

@_annotate_write_exception("f64")
def write_f64(self, value: float) -> None:
match self._byteorder:
case Byteorder.little:
bo = "<"
case Byteorder.big:
bo = ">"
self._file.write(struct.pack(bo + "d", value))

@_annotate_write_exception("char array")
def write_char_array(self, value: str) -> None:
    """Write ``value`` as a length-prefixed UTF-8 string.

    The prefix is the number of encoded bytes, written as a u32.
    """
    payload = value.encode('utf-8')
    n_bytes = len(payload)
    self.write_u32(n_bytes)
    self._file.write(payload)

@_annotate_write_exception("n chars")
def write_chars(self, value: str) -> None:
    """Write ``value`` UTF-8 encoded, without a length prefix."""
    self._file.write(value.encode('utf-8'))

def seek(self, pos: int) -> None:
    """Forward ``pos`` to ``seek`` of the underlying file.

    The return value of the file's ``seek`` is discarded.
    """
    underlying = self._file
    underlying.seek(pos)

Expand Down
21 changes: 17 additions & 4 deletions src/sqomega/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)

import enum
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
from typing import ClassVar

from . import _ir as ir


class SqwFileType(enum.Enum):
DND = 0
Expand Down Expand Up @@ -36,11 +38,22 @@ class SqwDataBlockDescriptor:


@dataclass(kw_only=True, slots=True)
class SqwMainHeader:
class SqwMainHeader(ir.Serializable):
full_filename: str
title: str
nfiles: int # f64
creation_date: datetime # char_array
nfiles: int
creation_date: datetime

serial_name: ClassVar[str] = "main_header_cl"
version: ClassVar[float] = 2.0

def _serialize_to_dict(self) -> dict[str, ir.Object]:
return {
"serial_name": ir.String(self.serial_name),
"version": ir.F64(self.version),
"full_filename": ir.String(self.full_filename),
"title": ir.String(self.title),
"nfiles": ir.F64(self.nfiles),
"creation_date": ir.Datetime(self.creation_date),
"creation_date_defined_privately": ir.Logical(False),
}
Loading

0 comments on commit 90b5b1c

Please sign in to comment.