From fcf96c03806907f6a4907725051675b30274c269 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 19 Nov 2024 11:18:11 +0000 Subject: [PATCH 1/2] core: Introduce integer constraints --- .../irdl/test_declarative_assembly_format.py | 29 ++++ xdsl/irdl/constraints.py | 145 ++++++++++++++++-- .../declarative_assembly_format_parser.py | 15 ++ 3 files changed, 179 insertions(+), 10 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 5af1dcbb94..cd6e33dd5a 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -17,6 +17,7 @@ FloatAttr, IndexType, IntegerAttr, + IntegerType, MemRefType, ModuleOp, UnitAttr, @@ -66,6 +67,8 @@ var_result_def, var_successor_def, ) +from xdsl.irdl import eq +from xdsl.irdl.constraints import AnyInt, IntVarConstraint from xdsl.irdl.declarative_assembly_format import ( AttrDictDirective, FormatProgram, @@ -2473,6 +2476,32 @@ class RangeVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass] assert isinstance(my_op, RangeVarOp) +def test_int_var_inference(): + @irdl_op_definition + class IntVarOp(IRDLOperation): + name = "test.int_var" + T: ClassVar = IntVarConstraint("T", AnyInt()) + ins = var_operand_def(RangeOf(eq(IndexType()), length=T)) + outs = var_result_def(RangeOf(eq(IntegerType(64)), length=T)) + + assembly_format = "$ins attr-dict" + + ctx = MLContext() + ctx.load_op(IntVarOp) + ctx.load_dialect(Test) + program = textwrap.dedent("""\ + %in0, %in1 = "test.op"() : () -> (index, index) + %out0, %out1 = test.int_var %in0, %in1 + """) + + parser = Parser(ctx, program) + test_op = parser.parse_optional_operation() + assert isinstance(test_op, test.Operation) + my_op = parser.parse_optional_operation() + assert isinstance(my_op, IntVarOp) + assert my_op.result_types == (IntegerType(64), IntegerType(64)) + + ################################################################################ # Declarative Format Verification # ################################################################################ diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 5e14f057f7..861cafdaea 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -3,7 +3,7 @@ import abc from abc import ABC, abstractmethod from collections.abc import Generator, Iterator, Sequence, Set -from dataclasses import dataclass, field +from dataclasses import KW_ONLY, dataclass, field from inspect import isclass from typing import Generic, TypeAlias, TypeVar, cast @@ -32,18 +32,27 @@ class ConstraintContext: _range_variables: dict[str, tuple[Attribute, ...]] = field(default_factory=dict) """The assignment of constraint range variables.""" + _int_variables: dict[str, int] = field(default_factory=dict) + """The assignment of constraint int variables.""" + def get_variable(self, key: str) -> Attribute | None: return self._variables.get(key) def get_range_variable(self, key: str) -> tuple[Attribute, ...] | None: return self._range_variables.get(key) + def get_int_variable(self, key: str) -> int | None: + return self._int_variables.get(key) + def set_variable(self, key: str, attr: Attribute): self._variables[key] = attr def set_range_variable(self, key: str, attr: tuple[Attribute, ...]): self._range_variables[key] = attr + def set_int_variable(self, key: str, i: int): + self._int_variables[key] = i + @property def variables(self) -> Sequence[str]: return tuple(self._variables.keys()) @@ -52,17 +61,26 @@ def variables(self) -> Sequence[str]: def range_variables(self) -> Sequence[str]: return tuple(self._range_variables.keys()) + @property + def int_variables(self) -> Sequence[str]: + return tuple(self._int_variables.keys()) + def copy(self): - return ConstraintContext(self._variables.copy(), self._range_variables.copy()) + return ConstraintContext( + self._variables.copy(), + self._range_variables.copy(), + self._int_variables.copy(), + ) def update(self, other: ConstraintContext): self._variables.update(other._variables) self._range_variables.update(other._range_variables) + self._int_variables.update(other._int_variables) _AttributeCovT = TypeVar("_AttributeCovT", bound=Attribute, covariant=True) -ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute] +ConstraintVariableType: TypeAlias = Attribute | Sequence[Attribute] | int """ Possible types that a constraint variable can have. """ @@ -651,6 +669,102 @@ def infer(self, context: InferenceContext) -> AttributeCovT: return self.constr.infer(context) +@dataclass(frozen=True) +class IntConstraint(ABC): + """Constrain an integer to certain values.""" + + @abstractmethod + def verify( + self, + i: int, + constraint_context: ConstraintContext, + ) -> None: + """ + Check if the integer satisfies the constraint, or raise an exception otherwise. + """ + ... + + def get_length_extractors( + self, + ) -> dict[str, VarExtractor[int]]: + """ + Get a dictionary of variables that can be solved from this attribute. + """ + return dict() + + def can_infer(self, var_constraint_names: Set[str]) -> bool: + """ + Check if there is enough information to infer the integer given the + constraint variables that are already set. + """ + # By default, we cannot infer anything. + return False + + def infer(self, context: InferenceContext) -> int: + """ + Infer the attribute given the the values for all variables. + + Raises an exception if the attribute cannot be inferred. If `can_infer` + returns `True` with the given constraint variables, this method should + not raise an exception. + """ + raise ValueError("Cannot infer attribute from constraint") + + +class AnyInt(IntConstraint): + """ + Constraint that is verified by all integers. + """ + + def verify(self, i: int, constraint_context: ConstraintContext) -> None: + pass + + +@dataclass(frozen=True) +class IntVarConstraint(IntConstraint): + """ + Constrain an integer with the given constraint, and constrain all occurences + of this constraint (i.e, sharing the same name) to be equal. + """ + + name: str + """The variable name. All uses of that name refer to the same variable.""" + + constraint: IntConstraint + """The constraint that the variable must satisfy.""" + + def verify( + self, + i: int, + constraint_context: ConstraintContext, + ) -> None: + if self.name in constraint_context.int_variables: + if i != constraint_context.get_int_variable(self.name): + raise VerifyException( + f"integer {constraint_context.get_int_variable(self.name)} expected from int variable " + f"'{self.name}', but got {i}" + ) + else: + self.constraint.verify(i, constraint_context) + constraint_context.set_int_variable(self.name, i) + + def get_length_extractors( + self, + ) -> dict[str, VarExtractor[int]]: + return {self.name: IdExtractor()} + + def can_infer(self, var_constraint_names: Set[str]) -> bool: + return self.name in var_constraint_names + + def infer( + self, + context: InferenceContext, + ) -> int: + v = context.variables[self.name] + assert isinstance(v, int) + return v + + @dataclass(frozen=True) class GenericRangeConstraint(Generic[AttributeCovT], ABC): """Constrain a range of attributes to certain values.""" @@ -677,6 +791,14 @@ def get_variable_extractors( """ return {} + def get_length_extractors( + self, + ) -> dict[str, VarExtractor[int]]: + """ + Get a dictionary of variables that can be solved using the length of the range. + """ + return dict() + def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool: """ Check if there is enough information to infer the attribute given the @@ -699,10 +821,6 @@ def infer( """ raise ValueError("Cannot infer attribute from constraint") - def get_unique_base(self) -> type[Attribute] | None: - """Get the unique base type that can satisfy the constraint, if any.""" - return None - RangeConstraint: TypeAlias = GenericRangeConstraint[Attribute] @@ -758,6 +876,8 @@ class RangeOf(GenericRangeConstraint[AttributeCovT]): """ constr: GenericAttrConstraint[AttributeCovT] + _: KW_ONLY + length: IntConstraint = field(default_factory=AnyInt) def verify( self, @@ -766,9 +886,13 @@ def verify( ) -> None: for a in attrs: self.constr.verify(a, constraint_context) + self.length.verify(len(attrs), constraint_context) + + def get_length_extractors(self) -> dict[str, VarExtractor[int]]: + return self.length.get_length_extractors() def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool: - return length_known and self.constr.can_infer(var_constraint_names) + return (length_known or self.length.can_infer(var_constraint_names)) and self.constr.can_infer(var_constraint_names) def infer( self, @@ -776,7 +900,8 @@ def infer( *, length: int | None, ) -> Sequence[AttributeCovT]: - assert length is not None + if length is None: + length = self.length.infer(context) attr = self.constr.infer(context) return (attr,) * length @@ -834,7 +959,7 @@ def range_constr_coercion( ) -> GenericRangeConstraint[AttributeCovT]: if isinstance(attr, GenericRangeConstraint): return attr - return RangeOf(attr_constr_coercion(attr)) + return RangeOf(attr_constr_coercion(attr), length=AnyInt()) def single_range_constr_coercion( diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index d6161c10a3..bbab85611d 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -231,6 +231,15 @@ def add_reserved_attrs_to_directive(self, elements: list[FormatDirective]): ) return + @dataclass(frozen=True) + class _OperandLengthResolver(VarExtractor[ParsingState]): + idx: int + inner: VarExtractor[int] + + def extract_var(self, a: ParsingState) -> ConstraintVariableType: + assert isinstance(ops := a.operands[self.idx], Sequence) + return self.inner.extract_var(len(ops)) + @dataclass(frozen=True) class _OperandResultExtractor(VarExtractor[ParsingState]): idx: int @@ -275,6 +284,12 @@ def extractors_by_name(self) -> dict[str, VarExtractor[ParsingState]]: """ extractor_dicts: list[dict[str, VarExtractor[ParsingState]]] = [] for i, (_, operand_def) in enumerate(self.op_def.operands): + extractor_dicts.append( + { + v: self._OperandLengthResolver(i, r) + for v, r in operand_def.constr.get_length_extractors().items() + } + ) if self.seen_operand_types[i]: extractor_dicts.append( { From 6135f83bdb5150760097a7fc3e361e3220bc81fb Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Mon, 6 Jan 2025 15:08:48 +0000 Subject: [PATCH 2/2] Formatting --- tests/irdl/test_declarative_assembly_format.py | 5 +++-- xdsl/irdl/constraints.py | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index cd6e33dd5a..8921cfb213 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -32,12 +32,14 @@ from xdsl.irdl import ( AllOf, AnyAttr, + AnyInt, AttrSizedOperandSegments, AttrSizedRegionSegments, AttrSizedResultSegments, BaseAttr, EqAttrConstraint, GenericAttrConstraint, + IntVarConstraint, IRDLOperation, ParamAttrConstraint, ParameterDef, @@ -49,6 +51,7 @@ VarOperand, VarOpResult, attr_def, + eq, irdl_attr_definition, irdl_op_definition, operand_def, @@ -67,8 +70,6 @@ var_result_def, var_successor_def, ) -from xdsl.irdl import eq -from xdsl.irdl.constraints import AnyInt, IntVarConstraint from xdsl.irdl.declarative_assembly_format import ( AttrDictDirective, FormatProgram, diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 861cafdaea..2bcec18196 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -892,7 +892,9 @@ def get_length_extractors(self) -> dict[str, VarExtractor[int]]: return self.length.get_length_extractors() def can_infer(self, var_constraint_names: Set[str], *, length_known: bool) -> bool: - return (length_known or self.length.can_infer(var_constraint_names)) and self.constr.can_infer(var_constraint_names) + return ( + length_known or self.length.can_infer(var_constraint_names) + ) and self.constr.can_infer(var_constraint_names) def infer( self,