Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: use default-valued type elements for constrs #3821

Merged
merged 1 commit into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 16 additions & 24 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from collections.abc import Iterator, Sequence
from enum import auto
from itertools import product
from typing import Any, ClassVar, Generic, TypeVar, cast
from typing import Any, ClassVar, Generic, cast

from typing_extensions import Self
from typing_extensions import Self, TypeVar

from xdsl.dialects import memref
from xdsl.dialects.builtin import (
Expand All @@ -39,7 +39,6 @@
from xdsl.irdl import (
AnyAttr,
AttrSizedOperandSegments,
BaseAttr,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
Expand All @@ -66,8 +65,9 @@
from xdsl.utils.hints import isa
from xdsl.utils.str_enum import StrEnum

_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True)
_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute)
_StreamTypeElement = TypeVar(
"_StreamTypeElement", bound=Attribute, covariant=True, default=Attribute
)


@irdl_attr_definition
Expand All @@ -87,20 +87,16 @@ def get_element_type(self) -> _StreamTypeElement:
def __init__(self, element_type: _StreamTypeElement):
super().__init__([element_type])

@staticmethod
@classmethod
def constr(
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]](
cls,
element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(),
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]:
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]](
ReadableStreamType, (element_type,)
)


AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]](
ReadableStreamType
)


@irdl_attr_definition
class WritableStreamType(
Generic[_StreamTypeElement],
Expand All @@ -118,20 +114,16 @@ def get_element_type(self) -> _StreamTypeElement:
def __init__(self, element_type: _StreamTypeElement):
super().__init__([element_type])

@staticmethod
@classmethod
def constr(
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]](
cls,
element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(),
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]:
return ParamAttrConstraint[WritableStreamType[_StreamTypeElement]](
WritableStreamType, (element_type,)
)


AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]](
WritableStreamType
)


class IteratorType(StrEnum):
"Iterator type for memref_stream Attribute"

Expand Down Expand Up @@ -471,7 +463,7 @@ class GenericOp(IRDLOperation):
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be read.
"""
outputs = var_operand_def(AnyMemRefTypeConstr | AnyWritableStreamTypeConstr)
outputs = var_operand_def(AnyMemRefTypeConstr | WritableStreamType.constr())
"""
Pointers to memory buffers or streams to be operated on. The corresponding stride
pattern defines the order in which the elements of the input buffers will be written
Expand Down
41 changes: 18 additions & 23 deletions xdsl/dialects/snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from abc import ABC
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
from typing import Generic

from typing_extensions import TypeVar

from xdsl.dialects.builtin import ContainerType, IntAttr
from xdsl.dialects.riscv import IntRegisterType
Expand All @@ -26,7 +28,7 @@
TypeAttribute,
)
from xdsl.irdl import (
BaseAttr,
AnyAttr,
GenericAttrConstraint,
IRDLOperation,
ParamAttrConstraint,
Expand All @@ -39,8 +41,9 @@
)
from xdsl.utils.exceptions import VerifyException

_StreamTypeElement = TypeVar("_StreamTypeElement", bound=Attribute, covariant=True)
_StreamTypeElementConstrT = TypeVar("_StreamTypeElementConstrT", bound=Attribute)
_StreamTypeElement = TypeVar(
"_StreamTypeElement", bound=Attribute, covariant=True, default=Attribute
)


@irdl_attr_definition
Expand All @@ -60,20 +63,16 @@ def get_element_type(self) -> _StreamTypeElement:
def __init__(self, element_type: _StreamTypeElement):
super().__init__([element_type])

@staticmethod
@classmethod
def constr(
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElementConstrT]](
cls,
element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(),
) -> ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]]:
return ParamAttrConstraint[ReadableStreamType[_StreamTypeElement]](
ReadableStreamType, (element_type,)
)


AnyReadableStreamTypeConstr = BaseAttr[ReadableStreamType[Attribute]](
ReadableStreamType
)


@irdl_attr_definition
class WritableStreamType(
Generic[_StreamTypeElement],
Expand All @@ -91,20 +90,16 @@ def get_element_type(self) -> _StreamTypeElement:
def __init__(self, element_type: _StreamTypeElement):
super().__init__([element_type])

@staticmethod
@classmethod
def constr(
element_type: GenericAttrConstraint[_StreamTypeElementConstrT],
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]]:
return ParamAttrConstraint[WritableStreamType[_StreamTypeElementConstrT]](
cls,
element_type: GenericAttrConstraint[_StreamTypeElement] = AnyAttr(),
) -> ParamAttrConstraint[WritableStreamType[_StreamTypeElement]]:
return ParamAttrConstraint[WritableStreamType[_StreamTypeElement]](
WritableStreamType, (element_type,)
)


AnyWritableStreamTypeConstr = BaseAttr[WritableStreamType[Attribute]](
WritableStreamType
)


@dataclass(frozen=True)
class SnitchResources:
"""
Expand Down Expand Up @@ -222,7 +217,7 @@ class SsrEnableOp(IRDLOperation):

name = "snitch.ssr_enable"

streams = var_result_def(AnyReadableStreamTypeConstr | AnyWritableStreamTypeConstr)
streams = var_result_def(ReadableStreamType.constr() | WritableStreamType.constr())

def __init__(self, stream_types: Sequence[Attribute]):
super().__init__(result_types=[stream_types])
Expand Down