Skip to content

Commit

Permalink
Paralellize table generation
Browse files Browse the repository at this point in the history
  • Loading branch information
math-fehr committed Jan 9, 2024
1 parent 174a7d6 commit 01c53b4
Showing 1 changed file with 35 additions and 22 deletions.
57 changes: 35 additions & 22 deletions xdsl_pdl/tools/generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import concurrent.futures
import argparse
from os import cpu_count
from random import Random
from tabulate import tabulate

Check warning on line 9 in xdsl_pdl/tools/generate_table.py

View workflow job for this annotation

GitHub Actions / build (3.10)

Import "tabulate" could not be resolved from source (reportMissingModuleSource)

Expand Down Expand Up @@ -45,11 +47,17 @@ def fuzz_pdl_matches(
mlir_analysis = analyze_with_mlir(module.ops.first, ctx, randgen, mlir_executable_path)
return analysis_correct, mlir_analysis is None


class GenerateTableMain(xDSLOptMain):
num_tested: int
failed_analyses: int
values: list[list[int]]

def __init__(self):
super().__init__()
self.ctx.allow_unregistered = True
self.num_tested = 0
self.failed_analyses = 0
self.values = [[0, 0], [0, 0]]

def register_all_dialects(self):
super().register_all_dialects()
Expand All @@ -58,31 +66,36 @@ def register_all_dialects(self):
def register_all_arguments(self, arg_parser: argparse.ArgumentParser):
super().register_all_arguments(arg_parser)
arg_parser.add_argument("--mlir-executable", type=str, default="mlir-opt")
arg_parser.add_argument("--num-patterns", type=int, default=10000)
arg_parser.add_argument("-j", type=int, default=-1)
arg_parser.add_argument("-n", type=int, default=10000)
arg_parser.add_argument("-j", type=int, default=cpu_count())

def run_one_thread(self, seed: int):
randgen = Random()
randgen.seed(seed)
pattern = generate_random_pdl_rewrite(randgen)
module = ModuleOp([pattern])
test_res = fuzz_pdl_matches(module, self.ctx, randgen, self.args.mlir_executable)
self.num_tested += 1
print(f"Tested {self.num_tested} patterns", end="\r")
if test_res is None:
self.failed_analyses += 1
return
self.values[int(test_res[0])][int(test_res[1])] += 1

def run(self):
randgen = Random()
randgen.seed(42)
values = [[0, 0], [0, 0]]
failed_analyses = 0
for i in range(10000):
print(i)
pattern = generate_random_pdl_rewrite(randgen)
module = ModuleOp([pattern])
test_res = fuzz_pdl_matches(module, self.ctx, randgen, self.args.mlir_executable)
if test_res is None:
failed_analyses += 1
continue
values[int(test_res[0])][int(test_res[1])] += 1

print("Analysis failed, MLIR execution failed: ", values[0][0])
print("Analysis failed, MLIR execution succeeded: ", values[0][1])
print("Analysis succeeded, MLIR execution failed: ", values[1][0])
print("Analysis succeeded, MLIR execution succeeded: ", values[1][1])
print("PDL Analysis raised an exception: ", failed_analyses)

print_results(values[0][0], values[0][1], values[1][0], values[1][1])
seeds = [randgen.randint(0, 2 ** 30) for _ in range(self.args.n)]
with concurrent.futures.ThreadPoolExecutor(max_workers=self.args.j) as executor:
executor.map(self.run_one_thread, seeds)

print("Analysis failed, MLIR execution failed: ", self.values[0][0])
print("Analysis failed, MLIR execution succeeded: ", self.values[0][1])
print("Analysis succeeded, MLIR execution failed: ", self.values[1][0])
print("Analysis succeeded, MLIR execution succeeded: ", self.values[1][1])
print("PDL Analysis raised an exception: ", self.failed_analyses)

print_results(self.values[0][0], self.values[0][1], self.values[1][0], self.values[1][1])


def print_results(
Expand Down

0 comments on commit 01c53b4

Please sign in to comment.