Skip to content

Commit

Permalink
core: Introduce integer constraints (#3699)
Browse files Browse the repository at this point in the history
Adds integer variables to the constraint system, allowing the length of a range to be specified.
  • Loading branch information
alexarice authored Jan 24, 2025
1 parent 2d39e73 commit bac3115
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 10 deletions.
30 changes: 30 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
FloatAttr,
IndexType,
IntegerAttr,
IntegerType,
MemRefType,
ModuleOp,
UnitAttr,
Expand All @@ -31,12 +32,14 @@
from xdsl.irdl import (
AllOf,
AnyAttr,
AnyInt,
AttrSizedOperandSegments,
AttrSizedRegionSegments,
AttrSizedResultSegments,
BaseAttr,
EqAttrConstraint,
GenericAttrConstraint,
IntVarConstraint,
IRDLOperation,
ParamAttrConstraint,
ParameterDef,
Expand All @@ -48,6 +51,7 @@
VarOperand,
VarOpResult,
attr_def,
eq,
irdl_attr_definition,
irdl_op_definition,
operand_def,
Expand Down Expand Up @@ -2466,6 +2470,32 @@ class RangeVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
assert isinstance(my_op, RangeVarOp)


def test_int_var_inference():
@irdl_op_definition
class IntVarOp(IRDLOperation):
name = "test.int_var"
T: ClassVar = IntVarConstraint("T", AnyInt())
ins = var_operand_def(RangeOf(eq(IndexType()), length=T))
outs = var_result_def(RangeOf(eq(IntegerType(64)), length=T))

assembly_format = "$ins attr-dict"

ctx = MLContext()
ctx.load_op(IntVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent("""\
%in0, %in1 = "test.op"() : () -> (index, index)
%out0, %out1 = test.int_var %in0, %in1
""")

parser = Parser(ctx, program)
test_op = parser.parse_optional_operation()
assert isinstance(test_op, test.Operation)
my_op = parser.parse_optional_operation()
assert isinstance(my_op, IntVarOp)
assert my_op.result_types == (IntegerType(64), IntegerType(64))


################################################################################
# Declarative Format Verification #
################################################################################
Expand Down
147 changes: 137 additions & 10 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterator, Sequence, Set
from dataclasses import dataclass, field
from dataclasses import KW_ONLY, dataclass, field
from inspect import isclass
from typing import Generic, TypeAlias, TypeVar, cast

Expand Down Expand Up @@ -32,18 +32,27 @@ class ConstraintContext:
_range_variables: dict[str, tuple[Attribute, ...]] = field(default_factory=dict)
"""The assignment of constraint range variables."""

_int_variables: dict[str, int] = field(default_factory=dict)
"""The assignment of constraint int variables."""

def get_variable(self, key: str) -> Attribute | None:
return self._variables.get(key)

def get_range_variable(self, key: str) -> tuple[Attribute, ...] | None:
return self._range_variables.get(key)

def get_int_variable(self, key: str) -> int | None:
return self._int_variables.get(key)

def set_variable(self, key: str, attr: Attribute):
self._variables[key] = attr

def set_range_variable(self, key: str, attr: tuple[Attribute, ...]):
self._range_variables[key] = attr

def set_int_variable(self, key: str, i: int):
self._int_variables[key] = i

@property
def variables(self) -> Sequence[str]:
return tuple(self._variables.keys())
Expand All @@ -52,17 +61,26 @@ def variables(self) -> Sequence[str]:
def range_variables(self) -> Sequence[str]:
return tuple(self._range_variables.keys())

@property
def int_variables(self) -> Sequence[str]:
return tuple(self._int_variables.keys())

def copy(self):
return ConstraintContext(self._variables.copy(), self._range_variables.copy())
return ConstraintContext(
self._variables.copy(),
self._range_variables.copy(),
self._int_variables.copy(),
)

def update(self, other: ConstraintContext):
self._variables.update(other._variables)
self._range_variables.update(other._range_variables)
self._int_variables.update(other._int_variables)


_AttributeCovT = TypeVar("_AttributeCovT", bound=Attribute, covariant=True)

ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute]
ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute] | int
"""
Possible types that a constraint variable can have.
"""
Expand Down Expand Up @@ -651,6 +669,102 @@ def infer(self, context: InferenceContext) -> AttributeCovT:
return self.constr.infer(context)


@dataclass(frozen=True)
class IntConstraint(ABC):
"""Constrain an integer to certain values."""

@abstractmethod
def verify(
self,
i: int,
constraint_context: ConstraintContext,
) -> None:
"""
Check if the integer satisfies the constraint, or raise an exception otherwise.
"""
...

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
"""
Get a dictionary of variables that can be solved from this attribute.
"""
return dict()

def can_infer(self, var_constraint_names: Set[str]) -> bool:
"""
Check if there is enough information to infer the integer given the
constraint variables that are already set.
"""
# By default, we cannot infer anything.
return False

def infer(self, context: InferenceContext) -> int:
"""
Infer the attribute given the the values for all variables.
Raises an exception if the attribute cannot be inferred. If `can_infer`
returns `True` with the given constraint variables, this method should
not raise an exception.
"""
raise ValueError("Cannot infer attribute from constraint")


class AnyInt(IntConstraint):
"""
Constraint that is verified by all integers.
"""

def verify(self, i: int, constraint_context: ConstraintContext) -> None:
pass


@dataclass(frozen=True)
class IntVarConstraint(IntConstraint):
"""
Constrain an integer with the given constraint, and constrain all occurences
of this constraint (i.e, sharing the same name) to be equal.
"""

name: str
"""The variable name. All uses of that name refer to the same variable."""

constraint: IntConstraint
"""The constraint that the variable must satisfy."""

def verify(
self,
i: int,
constraint_context: ConstraintContext,
) -> None:
if self.name in constraint_context.int_variables:
if i != constraint_context.get_int_variable(self.name):
raise VerifyException(
f"integer {constraint_context.get_int_variable(self.name)} expected from int variable "
f"'{self.name}', but got {i}"
)
else:
self.constraint.verify(i, constraint_context)
constraint_context.set_int_variable(self.name, i)

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
return {self.name: IdExtractor()}

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.name in var_constraint_names

def infer(
self,
context: InferenceContext,
) -> int:
v = context.variables[self.name]
assert isinstance(v, int)
return v


@dataclass(frozen=True)
class GenericRangeConstraint(Generic[AttributeCovT], ABC):
"""Constrain a range of attributes to certain values."""
Expand All @@ -677,6 +791,14 @@ def get_variable_extractors(
"""
return {}

def get_length_extractors(
self,
) -> dict[str, VarExtractor[int]]:
"""
Get a dictionary of variables that can be solved using the length of the range.
"""
return dict()

def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
"""
Check if there is enough information to infer the attribute given the
Expand All @@ -699,10 +821,6 @@ def infer(
"""
raise ValueError("Cannot infer attribute from constraint")

def get_unique_base(self) -> type[Attribute] | None:
"""Get the unique base type that can satisfy the constraint, if any."""
return None


RangeConstraint: TypeAlias = GenericRangeConstraint[Attribute]

Expand Down Expand Up @@ -758,6 +876,8 @@ class RangeOf(GenericRangeConstraint[AttributeCovT]):
"""

constr: GenericAttrConstraint[AttributeCovT]
_: KW_ONLY
length: IntConstraint = field(default_factory=AnyInt)

def verify(
self,
Expand All @@ -766,17 +886,24 @@ def verify(
) -> None:
for a in attrs:
self.constr.verify(a, constraint_context)
self.length.verify(len(attrs), constraint_context)

def get_length_extractors(self) -> dict[str, VarExtractor[int]]:
return self.length.get_length_extractors()

def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool:
return length_known and self.constr.can_infer(var_constraint_names)
return (
length_known or self.length.can_infer(var_constraint_names)
) and self.constr.can_infer(var_constraint_names)

def infer(
self,
context: InferenceContext,
*,
length: int | None,
) -> Sequence[AttributeCovT]:
assert length is not None
if length is None:
length = self.length.infer(context)
attr = self.constr.infer(context)
return (attr,) * length

Expand Down Expand Up @@ -834,7 +961,7 @@ def range_constr_coercion(
) -> GenericRangeConstraint[AttributeCovT]:
if isinstance(attr, GenericRangeConstraint):
return attr
return RangeOf(attr_constr_coercion(attr))
return RangeOf(attr_constr_coercion(attr), length=AnyInt())


def single_range_constr_coercion(
Expand Down
15 changes: 15 additions & 0 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]):
)
return

@dataclass(frozen=True)
class _OperandLengthResolver(VarExtractor[ParsingState]):
idx: int
inner: VarExtractor[int]

def extract_var(self, a: ParsingState) -> ConstraintVariableType:
assert isinstance(ops := a.operands[self.idx], Sequence)
return self.inner.extract_var(len(ops))

@dataclass(frozen=True)
class _OperandResultExtractor(VarExtractor[ParsingState]):
idx: int
Expand Down Expand Up @@ -275,6 +284,12 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]:
"""
extractor_dicts: list[dict[str, VarExtractor[ParsingState]]] = []
for i, (_, operand_def) in enumerate(self.op_def.operands):
extractor_dicts.append(
{
v: self._OperandLengthResolver(i, r)
for v, r in operand_def.constr.get_length_extractors().items()
}
)
if self.seen_operand_types[i]:
extractor_dicts.append(
{
Expand Down

0 comments on commit bac3115

Please sign in to comment.