Skip to content

Commit

Permalink
Allow to use a custom Random generator
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 9, 2024
1 parent da274c1 commit 174a7d6
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 35 deletions.
8 changes: 4 additions & 4 deletions xdsl_pdl/analysis/mlir_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import subprocess
from io import StringIO
from dataclasses import dataclass
from random import randrange
from random import Random

from xdsl.ir import MLContext, Operation, Region, Block
from xdsl.printer import Printer
Expand Down Expand Up @@ -75,7 +75,7 @@ def run_with_mlir(


def analyze_with_mlir(
pattern: PatternOp, ctx: MLContext, mlir_executable_path: str
pattern: PatternOp, ctx: MLContext, randgen: Random, mlir_executable_path: str
) -> MLIRFailure | MLIRInfiniteLoop | None:
"""
Run the pattern on multiple examples with MLIR.
Expand All @@ -85,8 +85,8 @@ def analyze_with_mlir(
all_dags = generate_all_dags(5)
try:
for _ in range(0, 10):
region, ops = pdl_to_operations(pattern, ctx)
dag = all_dags[randrange(0, len(all_dags))]
region, ops = pdl_to_operations(pattern, ctx, randgen)
dag = all_dags[randgen.randrange(0, len(all_dags))]
create_dag_in_region(region, dag, ctx)
for populated_region in put_operations_in_region(dag, region, ops):
cloned_region = Region()
Expand Down
8 changes: 4 additions & 4 deletions xdsl_pdl/fuzzing/generate_pdl_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from itertools import chain, combinations
from dataclasses import dataclass, field
from random import randrange
from random import Random, randrange
from typing import Generator, Iterable, cast

from xdsl.ir import Attribute, Block, MLContext, OpResult, Operation, Region, SSAValue
Expand Down Expand Up @@ -129,7 +129,7 @@ def possible_values_of_type(self, type: Attribute) -> list[SSAValue]:


def pdl_to_operations(
pattern: PatternOp, ctx: MLContext
pattern: PatternOp, ctx: MLContext, randgen: Random
) -> tuple[Region, list[Operation]]:
pattern_ops = pattern.body.ops
region = Region([Block()])
Expand Down Expand Up @@ -160,7 +160,7 @@ def pdl_to_operations(
possible_values.extend(
[arg for arg in region_args if arg.type == operand_type]
)
choice = randrange(0, len(possible_values) + 1)
choice = randgen.randrange(0, len(possible_values) + 1)
if choice == len(possible_values):
arg = region.blocks[0].insert_arg(operand_type, 0)
else:
Expand Down Expand Up @@ -260,7 +260,7 @@ def rec(i: int, ops: list[Operation]) -> Generator[Region, None, None]:
block = region.blocks[i + 1]
assert block.ops.last is not None
block.insert_op_before(ops[0], block.ops.last)
yield from rec(i + 1, ops[1:])
yield from rec(i, ops[1:])
ops[0].detach()

yield from rec(0, ops)
33 changes: 17 additions & 16 deletions xdsl_pdl/fuzzing/generate_pdl_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from random import randrange
from random import Random

from xdsl.ir import Block, Operation, Region, SSAValue

Expand Down Expand Up @@ -35,25 +35,26 @@ class _FuzzerOptions:

@dataclass
class _FuzzerContext:
randgen: Random
values: list[SSAValue] = field(default_factory=list)
operations: list[OperationOp] = field(default_factory=list)

def get_random_value(self) -> SSAValue:
assert len(self.values) != 0
return self.values[randrange(0, len(self.values))]
return self.values[self.randgen.randrange(0, len(self.values))]

def get_random_operation(self) -> OperationOp:
assert len(self.operations) != 0
return self.operations[randrange(0, len(self.operations))]
return self.operations[self.randgen.randrange(0, len(self.operations))]


def _generate_random_operand(ctx: _FuzzerContext) -> tuple[SSAValue, list[Operation]]:
"""
Generate a random operand.
It is either a new `pdl.operand`, or an existing one in the context.
"""
if len(ctx.values) != 0 and randrange(0, 2) == 0:
return ctx.values[randrange(0, len(ctx.values))], []
if len(ctx.values) != 0 and ctx.randgen.randrange(0, 2) == 0:
return ctx.values[ctx.randgen.randrange(0, len(ctx.values))], []
new_type = TypeOp.create(
result_types=[TypeType()], attributes={"constantType": i32}
)
Expand All @@ -68,10 +69,10 @@ def _generate_random_matched_operation(ctx: _FuzzerContext) -> list[Operation]:
Generate a random `pdl.operation`, along with new
`pdl.operand` and `pdl.type` if necessary.
"""
num_operands = randrange(
num_operands = ctx.randgen.randrange(
_FuzzerOptions.min_operands, _FuzzerOptions.max_operands + 1
)
num_results = randrange(_FuzzerOptions.min_results, _FuzzerOptions.max_results + 1)
num_results = ctx.randgen.randrange(_FuzzerOptions.min_results, _FuzzerOptions.max_results + 1)
new_ops: list[Operation] = []

operands: list[SSAValue] = []
Expand Down Expand Up @@ -104,7 +105,7 @@ def _generate_random_rewrite_operation(ctx: _FuzzerContext) -> list[Operation]:
Generate a random operation in the rewrite part of the pattern.
This can be either an `operation`, an `erase`, or a `replace`.
"""
operation_choice = randrange(0, 4)
operation_choice = ctx.randgen.randrange(0, 4)

# Erase operation
if operation_choice == 0:
Expand All @@ -128,10 +129,10 @@ def _generate_random_rewrite_operation(ctx: _FuzzerContext) -> list[Operation]:

# Create a new operation
assert operation_choice == 3
num_operands = randrange(
num_operands = ctx.randgen.randrange(
_FuzzerOptions.min_operands, _FuzzerOptions.max_operands + 1
)
num_results = randrange(_FuzzerOptions.min_results, _FuzzerOptions.max_results + 1)
num_results = ctx.randgen.randrange(_FuzzerOptions.min_results, _FuzzerOptions.max_results + 1)

# If we need values but we don't have, we restart
if num_operands != 0 and len(ctx.values) == 0:
Expand Down Expand Up @@ -159,15 +160,15 @@ def _generate_random_rewrite_operation(ctx: _FuzzerContext) -> list[Operation]:
return new_ops


def generate_unverified_random_pdl_rewrite() -> PatternOp:
def generate_unverified_random_pdl_rewrite(randgen: Random) -> PatternOp:
"""
Generate a random match part of a `pdl.rewrite`.
"""
ctx = _FuzzerContext()
num_matched_operations = randrange(
ctx = _FuzzerContext(randgen)
num_matched_operations = randgen.randrange(
_FuzzerOptions.min_match_operations, _FuzzerOptions.max_match_operations + 1
)
num_rewrite_operations = randrange(
num_rewrite_operations = randgen.randrange(
_FuzzerOptions.min_rewrite_operations, _FuzzerOptions.max_rewrite_operations + 1
)

Expand All @@ -192,9 +193,9 @@ def generate_unverified_random_pdl_rewrite() -> PatternOp:
return PatternOp(1, None, body)


def generate_random_pdl_rewrite() -> PatternOp:
def generate_random_pdl_rewrite(randgen: Random) -> PatternOp:
while True:
pattern = generate_unverified_random_pdl_rewrite()
pattern = generate_unverified_random_pdl_rewrite(randgen)
try:
pattern.verify()
except Exception:
Expand Down
3 changes: 2 additions & 1 deletion xdsl_pdl/tools/analyze_pdl_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import argparse
from random import Random
from xdsl.dialects.builtin import ModuleOp

from xdsl.printer import Printer
Expand Down Expand Up @@ -29,7 +30,7 @@ def register_all_dialects(self):

def run(self):
if self.args.input_file is None:
pattern = generate_random_pdl_rewrite()
pattern = generate_random_pdl_rewrite(Random())
module = ModuleOp([pattern])
else:
chunks, extension = self.prepare_input()
Expand Down
5 changes: 3 additions & 2 deletions xdsl_pdl/tools/generate_pdl_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import argparse
from random import Random

from xdsl.ir import MLContext
from xdsl.utils.diagnostic import Diagnostic
Expand Down Expand Up @@ -51,7 +52,7 @@ def fuzz_pdl_matches(module: ModuleOp, ctx: MLContext, mlir_executable_path: str
printer = Printer(diagnostic=diagnostic)
printer.print_op(module)

mlir_analysis = analyze_with_mlir(module.ops.first, ctx, mlir_executable_path)
mlir_analysis = analyze_with_mlir(module.ops.first, ctx, Random(), mlir_executable_path)
if mlir_analysis is None:
print("MLIR analysis succeeded")
else:
Expand Down Expand Up @@ -91,7 +92,7 @@ def register_all_dialects(self):

def run(self):
if self.args.input_file is None:
pattern = generate_random_pdl_rewrite()
pattern = generate_random_pdl_rewrite(Random())
module = ModuleOp([pattern])
else:
chunks, extension = self.prepare_input()
Expand Down
3 changes: 2 additions & 1 deletion xdsl_pdl/tools/generate_pdl_rewrite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from random import Random
from xdsl.xdsl_opt_main import xDSLOptMain

from xdsl_pdl.fuzzing.generate_pdl_rewrite import generate_random_pdl_rewrite
Expand All @@ -21,7 +22,7 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser):
pass

def run(self):
pattern = generate_random_pdl_rewrite()
pattern = generate_random_pdl_rewrite(Random())
module = ModuleOp([pattern])
output_stream = self.prepare_output()
output_stream.write(self.output_resulting_program(module))
Expand Down
17 changes: 10 additions & 7 deletions xdsl_pdl/tools/generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import argparse
import random
from random import Random
from tabulate import tabulate

from xdsl.ir import MLContext
Expand All @@ -24,7 +24,7 @@


def fuzz_pdl_matches(
module: ModuleOp, ctx: MLContext, mlir_executable_path: str
module: ModuleOp, ctx: MLContext, randgen: Random, mlir_executable_path: str
) -> tuple[bool, bool] | None:
"""
Returns the result of the PDL analysis, and the result of the analysis using
Expand All @@ -42,7 +42,7 @@ def fuzz_pdl_matches(
except Exception:
return None

mlir_analysis = analyze_with_mlir(module.ops.first, ctx, mlir_executable_path)
mlir_analysis = analyze_with_mlir(module.ops.first, ctx, randgen, mlir_executable_path)
return analysis_correct, mlir_analysis is None


Expand All @@ -57,17 +57,20 @@ def register_all_dialects(self):

def register_all_arguments(self, arg_parser: argparse.ArgumentParser):
super().register_all_arguments(arg_parser)
arg_parser.add_argument("--mlir-executable", type=str, required=True)
arg_parser.add_argument("--mlir-executable", type=str, default="mlir-opt")
arg_parser.add_argument("--num-patterns", type=int, default=10000)
arg_parser.add_argument("-j", type=int, default=-1)

def run(self):
random.seed(42)
randgen = Random()
randgen.seed(42)
values = [[0, 0], [0, 0]]
failed_analyses = 0
for i in range(10000):
print(i)
pattern = generate_random_pdl_rewrite()
pattern = generate_random_pdl_rewrite(randgen)
module = ModuleOp([pattern])
test_res = fuzz_pdl_matches(module, self.ctx, self.args.mlir_executable)
test_res = fuzz_pdl_matches(module, self.ctx, randgen, self.args.mlir_executable)
if test_res is None:
failed_analyses += 1
continue
Expand Down

0 comments on commit 174a7d6

Please sign in to comment.