Skip to content

Commit

Permalink
Fix recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 28, 2024
1 parent 933a5a6 commit fbb98d6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 4 deletions.
21 changes: 17 additions & 4 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,16 +263,18 @@ def make_array(
if isinstance(dtype, dtypes.CoreType):
return NotImplemented

fields = {}
fields: dict[str, ndx.Array] = {}

eager_values = None if eager_value is None else dtype._parse_input(eager_value)
for name, field_dtype in dtype._fields().items():
if eager_values is None:
field_value = None
else:
field_value = _assemble_output_recurse(field_dtype, eager_values[name])
fields[name] = field_dtype._ops.make_array(
shape,
field_dtype,
field_dtype._assemble_output(eager_values[name])
if eager_values is not None
else None,
field_value,
)
return ndx.Array._from_fields(
dtype,
Expand All @@ -291,3 +293,14 @@ def getitem(self, x: "Array", index: "IndexType") -> "Array":
index = index._core()

return x._transmute(lambda corearray: corearray[index])


def _assemble_output_recurse(dtype: "Dtype", values: dict) -> np.ndarray:
if isinstance(dtype, dtypes.CoreType):
return dtype._assemble_output(values)
else:
fields = {
name: _assemble_output_recurse(field_dtype, values[name])
for name, field_dtype in dtype._fields().items()
}
return dtype._assemble_output(fields)
24 changes: 24 additions & 0 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,27 @@ def test_create_dtype_mismatched_shape_fields_lazy():
out = x[1:2, 0, ...]

ndx.build({"x": x}, {"out": out})


def test_recursive_construction():
class MyNInt64(StructType):
def _fields(self) -> dict[str, StructType | CoreType]:
return {"x": ndx.nint64}

def _parse_input(self, x: np.ndarray) -> dict:
return {"x": ndx.nint64._parse_input(x)}

def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray:
return fields["x"]

def copy(self) -> Self:
return self

def _schema(self) -> Schema:
return Schema(type_name="my_nint64", author="me")

_ops = UniformShapeOperations()

my_nint64 = MyNInt64()
a = ndx.asarray(np.ma.masked_array([1, 2, 3], [1, 0, 1]), my_nint64)
assert_array_equal(a.to_numpy(), np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64))

0 comments on commit fbb98d6

Please sign in to comment.