From e328124f596da4757eee8c8a5b1af7831ec2648b Mon Sep 17 00:00:00 2001 From: Dave Liddell Date: Sun, 16 Jun 2024 23:05:46 -0600 Subject: [PATCH] Add ability to include a plugin when creating driver (#740) `vmfbRunner` now takes an optional `extra_plugin` argument to load an executable plugin while the driver is getting created. This option might be used, for example, when loading a vmfb that has an external dependency on a native shared library. The implementation of this new feature takes advantage of the pre-existing `iree.runtime.flags` feature and a new IREE python API function. Normally, drivers are managed in a cache. However, setting a flag to specify the plugin has no effect on existing drivers. The API now has a function for creating a driver independent of the cache, to guarantee that any flags are sure to take effect. This PR also includes a fix for the problem of the CI using old cached wheels for iree, as recommended by @monorimet. --------- Signed-off-by: Dave Liddell Signed-off-by: daveliddell --- .github/workflows/test_models.yml | 1 + models/turbine_models/model_runner.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 5e62f68e1..b7facb903 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -52,6 +52,7 @@ jobs: pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing] + pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html pip install --no-compile --pre --upgrade -e models -r models/requirements.txt - name: Show current free memory diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index bdc81bcf8..a173f3166 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -1,12 +1,24 @@ import argparse import sys from iree import runtime as ireert +from iree.runtime._binding import create_hal_driver class vmfbRunner: - def __init__(self, device, vmfb_path, external_weight_path=None): + def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None): flags = [] - haldriver = ireert.get_driver(device) + + # If an extra plugin is requested, add a global flag to load the plugin + # and create the driver using the non-caching creation function, as + # the caching creation function may ignore the flag. + if extra_plugin: + ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}") + haldriver = create_hal_driver(device) + + # No plugin requested: create the driver with the caching create + # function. + else: + haldriver = ireert.get_driver(device) if "://" in device: try: device_idx = int(device.split("://")[-1])