Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support for List[EIP712Type] #56

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ class Mail(EIP712Message):
```

# Initialize a Person object as you would normally

```python
person = Person(name="Joe", wallet="0xa27CEF8aF2B6575903b676e5644657FAe96F491F")
```
48 changes: 42 additions & 6 deletions eip712/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Message classes for typed structured data hashing and signing in Ethereum.
"""

from typing import Any, Optional
from typing import Any, Optional, List, get_origin, get_args
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use lower-case list now (no import needed).


from dataclassy import asdict, dataclass, fields
from eth_abi.abi import is_encodable_type # type: ignore[import-untyped]
Expand Down Expand Up @@ -50,7 +50,33 @@ def _types_(self) -> dict:

for field in fields(self.__class__):
value = getattr(self, field)
if isinstance(value, EIP712Type):
field_type = self.__annotations__[field]

if get_origin(field_type) is list:
elem_type = get_args(field_type)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this also be a list, like list[list[EIP712Type]]?


if issubclass(elem_type, EIP712Type):
if len(value) > 0:
value0 = value[0]
types[repr(self)].append({"name": field, "type": f"{repr(value0)}[]"})
types.update(value0._types_)
else:
if isinstance(elem_type, str):
if not is_encodable_type(elem_type):
raise ValidationError(f"'{field}: list[{elem_type}]' is not a valid ABI type")

elif issubclass(elem_type, EIP712Type):
elem_type = repr(elem_type)

else:
raise ValidationError(
f"'{field}' type annotation must either be a subclass of "
f"`EIP712Type` or valid ABI Type string, not list[{elem_type.__name__}]"
)

types[repr(self)].append({"name": field, "type": f"{elem_type}[]"})

elif isinstance(value, EIP712Type):
types[repr(self)].append({"name": field, "type": repr(value)})
types.update(value._types_)
else:
Expand Down Expand Up @@ -120,15 +146,18 @@ def _domain_(self) -> dict:
@property
def _body_(self) -> dict:
"""The EIP-712 structured message to be used for serialization and hashing."""

return {
"domain": self._domain_["domain"],
"types": dict(self._types_, **self._domain_["types"]),
"primaryType": repr(self),
"message": {
key: getattr(self, key)
for key in fields(self.__class__)
if not key.startswith("_") or not key.endswith("_")
field: (
[field_elm._body_['message'] for field_elm in getattr(self, field)]
if isinstance(getattr(self, field), list) and not is_encodable_type(self.__annotations__[field])
else getattr(self, field)
)
for field in fields(self.__class__)
if not field.startswith("_") or not field.endswith("_")
},
}

Expand Down Expand Up @@ -167,6 +196,13 @@ def _prepare_data_for_hashing(data: dict) -> dict:
item = asdict(value)
elif isinstance(value, dict):
item = _prepare_data_for_hashing(item)
elif isinstance(value, list):
elms = []
for elm in item:
if isinstance(elm, dict):
elm = _prepare_data_for_hashing(elm)
elms.append(elm)
item = elms

result[key] = item

Expand Down
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from eip712.common import create_permit_def
from eip712.messages import EIP712Message, EIP712Type

from typing import List

PERMIT_NAME = "Yearn Vault"
PERMIT_VERSION = "0.3.5"
PERMIT_CHAIN_ID = 1
Expand Down Expand Up @@ -47,6 +49,40 @@ class InvalidMessageMissingDomainFields(EIP712Message):
value: "uint256" # type: ignore


class NestedType(EIP712Message):
field1: "string" # type: ignore
field2: "uint256" # type: ignore

def __post_init__(self):
self._name_ = "NestedType"
self._version_ = "1"


class MainType(EIP712Message):
name: "string" # type: ignore
age: "uint256" # type: ignore
nested: List[NestedType]

def __post_init__(self):
self._name_ = "MainType"
self._version_ = "1"


@pytest.fixture
def nested_instance_1():
return NestedType(field1="nested1", field2=100)


@pytest.fixture
def nested_instance_2():
return NestedType(field1="nested2", field2=200)


@pytest.fixture
def main_instance(nested_instance_1, nested_instance_2):
return MainType(name="Alice", age=30, nested=[nested_instance_1, nested_instance_2])


@pytest.fixture
def valid_message_with_name_domain_field():
return ValidMessageWithNameDomainField(value=1, sub=SubType(inner=2))
Expand Down
28 changes: 17 additions & 11 deletions tests/test_messages.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
import pytest
from eth_account.messages import ValidationError
from eip712.messages import calculate_hash

from .conftest import (
InvalidMessageMissingDomainFields,
MessageWithCanonicalDomainFieldOrder,
MessageWithNonCanonicalDomainFieldOrder,
)

def test_nested_list_message(main_instance):
msg = main_instance.signable_message
assert msg.version.hex() == "0x01"
assert msg.header.hex() == "0x0b2559348f55bf512d3cbed07914b9042c10f07034f553a05e0259103cca9156"
assert msg.body.hex() == "0x0802629e7fba836d4ab3791efd660448d4a23371201ed299e3ffd9bdd6adffaf"

def test_multilevel_message(valid_message_with_name_domain_field):
msg = valid_message_with_name_domain_field.signable_message
assert msg.version.hex() == "01"
assert msg.header.hex() == "336a9d2b32d1ab7ea7bbbd2565eca1910e54b74843858dec7a81f772a3c17e17"
assert msg.body.hex() == "306af87567fa87e55d2bd925d9a3ed2b1ec2c3e71b142785c053dc60b6ca177b"
# Verify hash calculation
message_hash = calculate_hash(msg)
assert message_hash.hex() == "0x2cf8ef0524314a5c218e235d774fb448453b619c124c3bcd66e4b2806291544d"


def test_invalid_message_without_domain_fields():
with pytest.raises(ValidationError):
InvalidMessageMissingDomainFields(value=1)
MainType(age=30, nested=[])


def test_multilevel_message(valid_message_with_name_domain_field):
msg = valid_message_with_name_domain_field.signable_message
assert msg.version.hex() == "0x01"
assert msg.header.hex() == "0x336a9d2b32d1ab7ea7bbbd2565eca1910e54b74843858dec7a81f772a3c17e17"
assert msg.body.hex() == "0x306af87567fa87e55d2bd925d9a3ed2b1ec2c3e71b142785c053dc60b6ca177b"


def test_yearn_vaults_message(permit, permit_raw_data):
Expand Down
Loading