Skip to content

Commit

Permalink
Add ability to include a plugin when creating driver (#740)
Browse files Browse the repository at this point in the history
`vmfbRunner` now takes an optional `extra_plugin` argument to load an
executable plugin while the driver is getting created.

This option might be used, for example, when loading a vmfb that has an
external dependency on a native shared library.

The implementation of this new feature takes advantage of the
pre-existing `iree.runtime.flags` feature and a new IREE python API
function. Normally, drivers are managed in a cache. However, setting a
flag to specify the plugin has no effect on existing drivers. The API
now has a function for creating a driver independent of the cache, to
guarantee that any flags are sure to take effect.

This PR also includes a fix for the problem of the CI using old cached
wheels for iree, as recommended by @monorimet.

---------

Signed-off-by: Dave Liddell <[email protected]>
Signed-off-by: daveliddell <[email protected]>
  • Loading branch information
daveliddell authored Jun 17, 2024
1 parent 815c857 commit e328124
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
pip install --no-compile -r ${{ github.workspace }}/iree-turbine/pytorch-cpu-requirements.txt
pip install --no-compile --pre --upgrade -r ${{ github.workspace }}/iree-turbine/requirements.txt
pip install --no-compile --pre -e ${{ github.workspace }}/iree-turbine[testing]
pip install --upgrade --pre --no-cache-dir iree-compiler iree-runtime -f https://iree.dev/pip-release-links.html
pip install --no-compile --pre --upgrade -e models -r models/requirements.txt
- name: Show current free memory
Expand Down
16 changes: 14 additions & 2 deletions models/turbine_models/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import argparse
import sys
from iree import runtime as ireert
from iree.runtime._binding import create_hal_driver


class vmfbRunner:
def __init__(self, device, vmfb_path, external_weight_path=None):
def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=None):
flags = []
haldriver = ireert.get_driver(device)

# If an extra plugin is requested, add a global flag to load the plugin
# and create the driver using the non-caching creation function, as
# the caching creation function may ignore the flag.
if extra_plugin:
ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
haldriver = create_hal_driver(device)

# No plugin requested: create the driver with the caching create
# function.
else:
haldriver = ireert.get_driver(device)
if "://" in device:
try:
device_idx = int(device.split("://")[-1])
Expand Down

0 comments on commit e328124

Please sign in to comment.