From d39a12c118ba7900c5d785d71d3d292eec38f5f7 Mon Sep 17 00:00:00 2001 From: Dave Liddell Date: Fri, 14 Jun 2024 11:18:56 -0700 Subject: [PATCH] Add ability to include a plugin when creating driver Signed-off-by: Dave Liddell --- models/turbine_models/model_runner.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index d49fa8362..a173f3166 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -1,16 +1,24 @@ import argparse import sys from iree import runtime as ireert +from iree.runtime._binding import create_hal_driver class vmfbRunner: def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None): flags = [] - clean_driver = False + + # If an extra plugin is requested, add a global flag to load the plugin + # and create the driver using the non-caching creation function, as + # the caching creation function may ignore the flag. if extra_plugin: ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") - clean_driver = True - haldriver = ireert.get_driver(device, clean_driver) + haldriver = create_hal_driver(device) + + # No plugin requested: create the driver with the caching create + # function. + else: + haldriver = ireert.get_driver(device) if "://" in device: try: device_idx = int(device.split("://")[-1])