Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Introduce integer constraints #3699

Merged
merged 2 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FloatAttr,
IndexType,
IntegerAttr,
IntegerType,
MemRefType,
ModuleOp,
UnitAttr,
Expand All @@ -31,12 +32,14 @@
from xdsl.irdl import (
AllOf,
AnyAttr,
AnyInt,
AttrSizedOperandSegments,
AttrSizedRegionSegments,
AttrSizedResultSegments,
BaseAttr,
EqAttrConstraint,
GenericAttrConstraint,
IntVarConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
Expand All @@ -48,6 +51,7 @@
VarOperand,
VarOpResult,
attr_def,
eq,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -2473,6 +2477,32 @@ class RangeVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
assert isinstance(my_op, RangeVarOp)


def test_int_var_inference():
@irdl_op_definition
class IntVarOp(IRDLOperation):
name = "test.int_var"
T: ClassVar = IntVarConstraint("T", AnyInt())
ins = var_operand_def(RangeOf(eq(IndexType()), length=T))
outs = var_result_def(RangeOf(eq(IntegerType(64)), length=T))

assembly_format = "$ins attr-dict"

ctx = MLContext()
ctx.load_op(IntVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent("""\
%in0, %in1 = "test.op"() : () -> (index, index)
%out0, %out1 = test.int_var %in0, %in1
""")

parser = Parser(ctx, program)
test_op = parser.parse_optional_operation()
assert isinstance(test_op, test.Operation)
my_op = parser.parse_optional_operation()
assert isinstance(my_op, IntVarOp)
assert my_op.result_types == (IntegerType(64), IntegerType(64))


################################################################################
# Declarative Format Verification #
################################################################################
Expand Down
147 changes: 137 additions & 10 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator, Sequence, Set
from dataclasses import dataclass, field
from dataclasses import KW_ONLY, dataclass, field
from inspect import isclass
from typing import Generic, TypeAlias, TypeVar, cast

Expand Down Expand Up @@ -32,18 +32,27 @@ class ConstraintContext:
_range_variables: dict[str, tuple[Attribute, ...]] = field(default_factory=dict)
"""The assignment of constraint range variables."""

_int_variables: dict[str, int] = field(default_factory=dict)
"""The assignment of constraint int variables."""

def get_variable(self, key: str) -> Attribute | None:
return self._variables.get(key)

def get_range_variable(self, key: str) -> tuple[Attribute, ...] | None:
return self._range_variables.get(key)

def get_int_variable(self, key: str) -> int | None:
return self._int_variables.get(key)

def set_variable(self, key: str, attr: Attribute):
self._variables[key] = attr

def set_range_variable(self, key: str, attr: tuple[Attribute, ...]):
self._range_variables[key] = attr

def set_int_variable(self, key: str, i: int):
self._int_variables[key] = i

@property
def variables(self) -> Sequence[str]:
return tuple(self._variables.keys())
Expand All @@ -52,17 +61,26 @@ def variables(self) -> Sequence[str]:
def range_variables(self) -> Sequence[str]:
return tuple(self._range_variables.keys())

@property
def int_variables(self) -> Sequence[str]:
return tuple(self._int_variables.keys())

def copy(self):
return ConstraintContext(self._variables.copy(), self._range_variables.copy())
return ConstraintContext(
self._variables.copy(),
self._range_variables.copy(),
self._int_variables.copy(),
)

def update(self, other: ConstraintContext):
self._variables.update(other._variables)
self._range_variables.update(other._range_variables)
self._int_variables.update(other._int_variables)


_AttributeCovT = TypeVar("_AttributeCovT", bound=Attribute, covariant=True)

ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute]
ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute] | int
"""
Possible types that a constraint variable can have.
"""
Expand Down Expand Up @@ -651,6 +669,102 @@ def infer(self, context: InferenceContext) -> AttributeCovT:
return self.constr.infer(context)


@dataclass(frozen=True)
class IntConstraint(ABC):
"""Constrain an integer to certain values."""

@abstractmethod
def verify(
self,
i: int,
constraint_context: ConstraintContext,
) -> None:
"""
Check if the integer satisfies the constraint, or raise an exception otherwise.
"""
...

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
"""
Get a dictionary of variables that can be solved from this attribute.
"""
return dict()

def can_infer(self, var_constraint_names: Set[str]) -> bool:
"""
Check if there is enough information to infer the integer given the
constraint variables that are already set.
"""
# By default, we cannot infer anything.
return False

def infer(self, context: InferenceContext) -> int:
"""
Infer the attribute given the the values for all variables.

Raises an exception if the attribute cannot be inferred. If `can_infer`
returns `True` with the given constraint variables, this method should
not raise an exception.
"""
raise ValueError("Cannot infer attribute from constraint")


class AnyInt(IntConstraint):
"""
Constraint that is verified by all integers.
"""

def verify(self, i: int, constraint_context: ConstraintContext) -> None:
pass


@dataclass(frozen=True)
class IntVarConstraint(IntConstraint):
"""
Constrain an integer with the given constraint, and constrain all occurences
of this constraint (i.e, sharing the same name) to be equal.
"""

name: str
"""The variable name. All uses of that name refer to the same variable."""

constraint: IntConstraint
"""The constraint that the variable must satisfy."""

def verify(
self,
i: int,
constraint_context: ConstraintContext,
) -> None:
if self.name in constraint_context.int_variables:
if i != constraint_context.get_int_variable(self.name):
raise VerifyException(
f"integer {constraint_context.get_int_variable(self.name)} expected from int variable "
f"'{self.name}', but got {i}"
)
else:
self.constraint.verify(i, constraint_context)
constraint_context.set_int_variable(self.name, i)

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
return {self.name: IdExtractor()}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.name in var_constraint_names

def infer(
self,
context: InferenceContext,
) -> int:
v = context.variables[self.name]
assert isinstance(v, int)
return v


@dataclass(frozen=True)
class GenericRangeConstraint(Generic[AttributeCovT], ABC):
"""Constrain a range of attributes to certain values."""
Expand All @@ -677,6 +791,14 @@ def get_variable_extractors(
"""
return {}

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
"""
Get a dictionary of variables that can be solved using the length of the range.
"""
return dict()

def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
"""
Check if there is enough information to infer the attribute given the
Expand All @@ -699,10 +821,6 @@ def infer(
"""
raise ValueError("Cannot infer attribute from constraint")

def get_unique_base(self) -> type[Attribute] | None:
"""Get the unique base type that can satisfy the constraint, if any."""
return None


RangeConstraint: TypeAlias = GenericRangeConstraint[Attribute]

Expand Down Expand Up @@ -758,6 +876,8 @@ class RangeOf(GenericRangeConstraint[AttributeCovT]):
"""

constr: GenericAttrConstraint[AttributeCovT]
_: KW_ONLY
length: IntConstraint = field(default_factory=AnyInt)

def verify(
self,
Expand All @@ -766,17 +886,24 @@ def verify(
) -> None:
for a in attrs:
self.constr.verify(a, constraint_context)
self.length.verify(len(attrs), constraint_context)

def get_length_extractors(self) -> dict[str, VarExtractor[int]]:
return self.length.get_length_extractors()

def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
return length_known and self.constr.can_infer(var_constraint_names)
return (
length_known or self.length.can_infer(var_constraint_names)
) and self.constr.can_infer(var_constraint_names)

def infer(
self,
context: InferenceContext,
*,
length: int | None,
) -> Sequence[AttributeCovT]:
assert length is not None
if length is None:
length = self.length.infer(context)
attr = self.constr.infer(context)
return (attr,) * length

Expand Down Expand Up @@ -834,7 +961,7 @@ def range_constr_coercion(
) -> GenericRangeConstraint[AttributeCovT]:
if isinstance(attr, GenericRangeConstraint):
return attr
return RangeOf(attr_constr_coercion(attr))
return RangeOf(attr_constr_coercion(attr), length=AnyInt())


def single_range_constr_coercion(
Expand Down
15 changes: 15 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]):
)
return

@dataclass(frozen=True)
class _OperandLengthResolver(VarExtractor[ParsingState]):
idx: int
inner: VarExtractor[int]

def extract_var(self, a: ParsingState) -> ConstraintVariableType:
assert isinstance(ops := a.operands[self.idx], Sequence)
return self.inner.extract_var(len(ops))

@dataclass(frozen=True)
class _OperandResultExtractor(VarExtractor[ParsingState]):
idx: int
Expand Down Expand Up @@ -275,6 +284,12 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]:
"""
extractor_dicts: list[dict[str, VarExtractor[ParsingState]]] = []
for i, (_, operand_def) in enumerate(self.op_def.operands):
extractor_dicts.append(
{
v: self._OperandLengthResolver(i, r)
for v, r in operand_def.constr.get_length_extractors().items()
}
)
if self.seen_operand_types[i]:
extractor_dicts.append(
{
Expand Down
Loading