Skip to content

Commit

Permalink
core: Implement IRDL SameSize options (#3067)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
PapyChacal and superlopuh authored Aug 20, 2024
1 parent 77ae3a5 commit 8fc70cd
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 6 deletions.
40 changes: 40 additions & 0 deletions tests/irdl/test_operation_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import re

import pytest

from xdsl.dialects.arith import Constant
Expand All @@ -17,6 +19,7 @@
OptOpResult,
OptRegion,
OptSuccessor,
SameVariadicOperandSize,
Successor,
VarOperand,
VarOpResult,
Expand All @@ -41,6 +44,7 @@
var_successor_def,
)
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException

################################################################################
# Results #
Expand Down Expand Up @@ -312,6 +316,42 @@ def test_two_var_operand_prop_builder2():
] == DenseArrayBase.from_list(i32, [1, 3])


@irdl_op_definition
class SameSizeVarOperandOp(IRDLOperation):
name = "test.same_size_var_operand_op"

var1 = var_operand_def()
op1 = operand_def()
var2 = var_operand_def()
irdl_options = [SameVariadicOperandSize()]


def test_same_size_operand_builder():
op1 = ResultOp.build(result_types=[StringAttr("0")]).res
op2 = SameSizeVarOperandOp.build(operands=[[op1, op1], op1, [op1, op1]])
op2.verify()
assert tuple(op2.operands) == (op1, op1, op1, op1, op1)
op2 = SameSizeVarOperandOp.create(operands=[op1, op1, op1, op1, op1])
op2.verify()
assert (op2.var1, op2.op1, op2.var2) == ((op1, op1), op1, (op1, op1))


def test_same_size_operand_builder2():
op1 = ResultOp.build(result_types=[StringAttr("0")]).res
with pytest.raises(
ValueError, match=re.escape("Variadic operands have different sizes: [1, 3]")
):
SameSizeVarOperandOp.build(operands=[[op1], op1, [op1, op1, op1]])
op2 = SameSizeVarOperandOp.create(operands=[op1, op1, op1, op1])
with pytest.raises(
VerifyException,
match=re.escape(
"Operation does not verify: Operation has 3 operands for 2 variadic operands marked as having the same size."
),
):
op2.verify()


################################################################################
# Attribute #
################################################################################
Expand Down
15 changes: 15 additions & 0 deletions tests/irdl/test_operation_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,18 @@ def test_entry_args_op():
Expected attribute i32 but got i64""",
):
op.verify()


class OptionlessMultipleVarOp(IRDLOperation):
name = "test.multiple_var_op"

optional = opt_operand_def()
variadic = var_operand_def()


def test_no_multiple_var_option():
with pytest.raises(
PyRDLOpDefinitionError,
match="Operation test.multiple_var_op defines more than two variadic operands, but do not define any of SameVariadicOperandSize or AttrSizedOperandSegments PyRDL options.",
):
irdl_op_definition(OptionlessMultipleVarOp)
115 changes: 109 additions & 6 deletions xdsl/irdl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,36 @@ class AttrSizedSuccessorSegments(AttrSizedSegments):
"""Name of the attribute containing the variadic successor sizes."""


class SameVariadicSize(IRDLOption):
"""
All variadic definitions should have the same size.
"""


class SameVariadicResultSize(SameVariadicSize):
"""
All variadic results should have the same size.
"""


class SameVariadicOperandSize(SameVariadicSize):
"""
All variadic operands should have the same size.
"""


class SameVariadicRegionSize(SameVariadicSize):
"""
All variadic regions should have the same size.
"""


class SameVariadicSuccessorSize(SameVariadicSize):
"""
All variadic successors should have the same size.
"""


@dataclass
class ParsePropInAttrDict(IRDLOption):
"""
Expand Down Expand Up @@ -1222,6 +1252,32 @@ def get_attr_size_option(
assert False, "Unknown VarIRConstruct value"


def get_same_variadic_size_option(
construct: VarIRConstruct,
) -> type[
SameVariadicOperandSize
| SameVariadicResultSize
| SameVariadicRegionSize
| SameVariadicSuccessorSize
]:
"""Get the AttrSized option for this type."""
if construct == VarIRConstruct.OPERAND:
return SameVariadicOperandSize
if construct == VarIRConstruct.RESULT:
return SameVariadicResultSize
if construct == VarIRConstruct.REGION:
return SameVariadicRegionSize
if construct == VarIRConstruct.SUCCESSOR:
return SameVariadicSuccessorSize
assert False, "Unknown VarIRConstruct value"


def get_multiple_variadic_options(
construct: VarIRConstruct,
) -> list[type[IRDLOption]]:
return [get_same_variadic_size_option(construct), get_attr_size_option(construct)]


def get_variadic_sizes_from_attr(
op: Operation,
defs: Sequence[tuple[str, OperandDef | ResultDef | RegionDef | SuccessorDef]],
Expand Down Expand Up @@ -1294,6 +1350,7 @@ def get_variadic_sizes(
args = get_op_constructs(op, construct)
def_type_name = get_construct_name(construct)
attribute_option = get_attr_size_option(construct)
same_size_option = get_same_variadic_size_option(construct)

variadic_defs = [
(arg_name, arg_def)
Expand Down Expand Up @@ -1331,6 +1388,19 @@ def get_variadic_sizes(
)
return [len(args) - len(defs) + 1]

# If the operation has to related SameSize option, equally distribute the
# variadic arguments between the variadic definitions.
option = next((o for o in op_def.options if isinstance(o, same_size_option)), None)
if option is not None:
non_variadic_defs = len(defs) - len(variadic_defs)
variadic_args = len(args) - non_variadic_defs
if variadic_args % len(variadic_defs):
name = get_construct_name(construct)
raise VerifyException(
f"Operation has {variadic_args} {name}s for {len(variadic_defs)} variadic {name}s marked as having the same size."
)
return [variadic_args // len(variadic_defs)] * len(variadic_defs)

# Unreachable, all cases should have been handled.
# Additional cases should raise an exception upon
# definition of the irdl operation.
Expand Down Expand Up @@ -1529,6 +1599,7 @@ def irdl_build_arg_list(
error_prefix
+ f"passed None to a non-optional {construct} {arg_idx} '{arg_name}'"
)
arg_sizes.append(0)
elif isinstance(arg, Sequence):
if not isinstance(arg_def, VariadicDef):
raise ValueError(
Expand Down Expand Up @@ -1698,6 +1769,35 @@ def irdl_op_init(
raise ValueError(
f"Unexpected option {option} in operation definition {op_def}."
)
case SameVariadicSize():
match option:
case SameVariadicOperandSize():
sizes = operand_sizes
construct = VarIRConstruct.OPERAND
case SameVariadicResultSize():
sizes = result_sizes
construct = VarIRConstruct.RESULT
case SameVariadicRegionSize():
sizes = region_sizes
construct = VarIRConstruct.REGION
case SameVariadicSuccessorSize():
sizes = successor_sizes
construct = VarIRConstruct.SUCCESSOR
case _:
raise ValueError(
f"Unexpected option {option} in operation definition {op_def}."
)
variadic_sizes = [
size
for (size, def_) in zip(
sizes, get_construct_defs(op_def, construct)
)
if isinstance(def_[1], VariadicDef)
]
if any(size != variadic_sizes[0] for size in variadic_sizes[1:]):
raise ValueError(
f"Variadic {get_construct_name(construct)}s have different sizes: {variadic_sizes}"
)
case _:
pass

Expand Down Expand Up @@ -1730,15 +1830,18 @@ def fun(self: Any, idx: int = arg_idx, previous_vars: int = previous_variadics):

# If we have multiple variadics, check that we have an
# attribute that holds the variadic sizes.
arg_size_option = get_attr_size_option(construct)
variadics_option = get_multiple_variadic_options(construct)
if previous_variadics > 1 and (
not any(isinstance(o, arg_size_option) for o in op_def.options)
not any(
isinstance(o, option) for o in op_def.options for option in variadics_option
)
):
arg_size_option_name = type(arg_size_option).__name__
raise Exception(
names = list(option.__name__ for option in variadics_option)
names, last_name = names[:-1], names[-1]
raise PyRDLOpDefinitionError(
f"Operation {op_def.name} defines more than two variadic "
f"{get_construct_name(construct)}s, but do not define the "
f"{arg_size_option_name} PyRDL option."
f"{get_construct_name(construct)}s, but do not define any of "
f"{', '.join(names)} or {last_name} PyRDL options."
)


Expand Down

0 comments on commit 8fc70cd

Please sign in to comment.