Skip to content

Commit

Permalink
Add ability to include a plugin when creating driver
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Liddell <[email protected]>
  • Loading branch information
daveliddell committed Jun 14, 2024
1 parent fdaec4b commit d39a12c
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions models/turbine_models/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
import argparse
import sys
from iree import runtime as ireert
from iree.runtime._binding import create_hal_driver


class vmfbRunner:
def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
flags = []
clean_driver = False

# If an extra plugin is requested, add a global flag to load the plugin
# and create the driver using the non-caching creation function, as
# the caching creation function may ignore the flag.
if extra_plugin:
ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
clean_driver = True
haldriver = ireert.get_driver(device, clean_driver)
haldriver = create_hal_driver(device)

# No plugin requested: create the driver with the caching create
# function.
else:
haldriver = ireert.get_driver(device)
if "://" in device:
try:
device_idx = int(device.split("://")[-1])
Expand Down

0 comments on commit d39a12c

Please sign in to comment.