From 01c53b47b7c0aa418a64cbbca033f95bdb566fa4 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Tue, 9 Jan 2024 16:54:57 +0000 Subject: [PATCH] Paralellize table generation --- xdsl_pdl/tools/generate_table.py | 57 ++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/xdsl_pdl/tools/generate_table.py b/xdsl_pdl/tools/generate_table.py index 3942285..9173cf4 100644 --- a/xdsl_pdl/tools/generate_table.py +++ b/xdsl_pdl/tools/generate_table.py @@ -2,7 +2,9 @@ from __future__ import annotations +import concurrent.futures import argparse +from os import cpu_count from random import Random from tabulate import tabulate @@ -45,11 +47,17 @@ def fuzz_pdl_matches( mlir_analysis = analyze_with_mlir(module.ops.first, ctx, randgen, mlir_executable_path) return analysis_correct, mlir_analysis is None - class GenerateTableMain(xDSLOptMain): + num_tested: int + failed_analyses: int + values: list[list[int]] + def __init__(self): super().__init__() self.ctx.allow_unregistered = True + self.num_tested = 0 + self.failed_analyses = 0 + self.values = [[0, 0], [0, 0]] def register_all_dialects(self): super().register_all_dialects() @@ -58,31 +66,36 @@ def register_all_dialects(self): def register_all_arguments(self, arg_parser: argparse.ArgumentParser): super().register_all_arguments(arg_parser) arg_parser.add_argument("--mlir-executable", type=str, default="mlir-opt") - arg_parser.add_argument("--num-patterns", type=int, default=10000) - arg_parser.add_argument("-j", type=int, default=-1) + arg_parser.add_argument("-n", type=int, default=10000) + arg_parser.add_argument("-j", type=int, default=cpu_count()) + + def run_one_thread(self, seed: int): + randgen = Random() + randgen.seed(seed) + pattern = generate_random_pdl_rewrite(randgen) + module = ModuleOp([pattern]) + test_res = fuzz_pdl_matches(module, self.ctx, randgen, self.args.mlir_executable) + self.num_tested += 1 + print(f"Tested {self.num_tested} patterns", end="\r") + if test_res is None: + self.failed_analyses += 1 + return + self.values[int(test_res[0])][int(test_res[1])] += 1 def run(self): randgen = Random() randgen.seed(42) - values = [[0, 0], [0, 0]] - failed_analyses = 0 - for i in range(10000): - print(i) - pattern = generate_random_pdl_rewrite(randgen) - module = ModuleOp([pattern]) - test_res = fuzz_pdl_matches(module, self.ctx, randgen, self.args.mlir_executable) - if test_res is None: - failed_analyses += 1 - continue - values[int(test_res[0])][int(test_res[1])] += 1 - - print("Analysis failed, MLIR execution failed: ", values[0][0]) - print("Analysis failed, MLIR execution succeeded: ", values[0][1]) - print("Analysis succeeded, MLIR execution failed: ", values[1][0]) - print("Analysis succeeded, MLIR execution succeeded: ", values[1][1]) - print("PDL Analysis raised an exception: ", failed_analyses) - - print_results(values[0][0], values[0][1], values[1][0], values[1][1]) + seeds = [randgen.randint(0, 2 ** 30) for _ in range(self.args.n)] + with concurrent.futures.ThreadPoolExecutor(max_workers=self.args.j) as executor: + executor.map(self.run_one_thread, seeds) + + print("Analysis failed, MLIR execution failed: ", self.values[0][0]) + print("Analysis failed, MLIR execution succeeded: ", self.values[0][1]) + print("Analysis succeeded, MLIR execution failed: ", self.values[1][0]) + print("Analysis succeeded, MLIR execution succeeded: ", self.values[1][1]) + print("PDL Analysis raised an exception: ", self.failed_analyses) + + print_results(self.values[0][0], self.values[0][1], self.values[1][0], self.values[1][1]) def print_results(