Skip to content

Commit

Permalink
Fix serialization for reCirq objects (#337)
Browse files Browse the repository at this point in the history
- Somewhere along the line, the json_serializable_dataclass stoped
  working.
- This is because the functions need to be classmethods or else
  TaskClass._json_namespace_() fails (requiring positional argument obj)
  • Loading branch information
dstrain115 authored Jan 26, 2024
1 parent c3dc8c1 commit 7216fa8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
6 changes: 3 additions & 3 deletions recirq/serialization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def wrap(cls):
unsafe_hash=unsafe_hash,
frozen=frozen)

cls._json_namespace_ = lambda obj: namespace
cls._json_namespace_ = classmethod(lambda obj: namespace)

cls._json_dict_ = lambda obj: cirq.obj_to_dict_helper(
obj, [f.name for f in dataclasses.fields(cls)])
cls._json_dict_ = classmethod(lambda obj: cirq.obj_to_dict_helper(
obj, [f.name for f in dataclasses.fields(cls)]))

if registry is not None:
if namespace is not None:
Expand Down
32 changes: 25 additions & 7 deletions recirq/serialization_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@
import cirq


@recirq.json_serializable_dataclass(
namespace="recirq.test_task", registry=recirq.Registry, frozen=True
)
class TestTask:
dataset_id: str

@property
def fn(self):
return f"{self.dataset_id}"


def test_bits_roundtrip():
bitstring = np.asarray([0, 1, 0, 1, 1, 1, 1, 0, 0, 1])
b = recirq.BitArray(bitstring)
Expand All @@ -36,13 +47,16 @@ def test_bits_roundtrip():

buffer.seek(0)
text = buffer.read()
assert text == """{
assert (
text
== """{
"cirq_type": "recirq.BitArray",
"shape": [
10
],
"packedbits": "5e40"
}"""
)

buffer.seek(0)
b2 = recirq.read_json(buffer)
Expand Down Expand Up @@ -81,11 +95,11 @@ def test_bitstrings_roundtrip_big():
def test_numpy_roundtrip(tmpdir):
re = np.random.uniform(0, 1, 100)
im = np.random.uniform(0, 1, 100)
a = re + 1.j * im
a = re + 1.0j * im
a = np.reshape(a, (10, 10))
ba = recirq.NumpyArray(a)

fn = f'{tmpdir}/hello.json'
fn = f"{tmpdir}/hello.json"
cirq.to_json(ba, fn)
ba2 = recirq.read_json(fn)

Expand All @@ -94,9 +108,13 @@ def test_numpy_roundtrip(tmpdir):

def test_str_and_repr():
bits = np.array([0, 1, 0, 1])
assert str(recirq.BitArray(bits)) == 'recirq.BitArray([0 1 0 1])'
assert repr(recirq.BitArray(bits)) == 'recirq.BitArray(array([0, 1, 0, 1]))'
assert str(recirq.BitArray(bits)) == "recirq.BitArray([0 1 0 1])"
assert repr(recirq.BitArray(bits)) == "recirq.BitArray(array([0, 1, 0, 1]))"

nums = np.array([1, 2, 3, 4])
assert str(recirq.NumpyArray(nums)) == 'recirq.NumpyArray([1 2 3 4])'
assert repr(recirq.NumpyArray(nums)) == 'recirq.NumpyArray(array([1, 2, 3, 4]))'
assert str(recirq.NumpyArray(nums)) == "recirq.NumpyArray([1 2 3 4])"
assert repr(recirq.NumpyArray(nums)) == "recirq.NumpyArray(array([1, 2, 3, 4]))"


def test_json_serializable_class():
assert TestTask._json_namespace_() == "recirq.test_task"

0 comments on commit 7216fa8

Please sign in to comment.