Skip to content

Commit

Permalink
scripts/vsmlrt.py: add MIGraphX support; reorder prologue for plugin …
Browse files Browse the repository at this point in the history
…path
  • Loading branch information
WolframRhodium committed Feb 29, 2024
1 parent 00a2186 commit 4962ff6
Showing 1 changed file with 198 additions and 5 deletions.
203 changes: 198 additions & 5 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.18.23"
__version__ = "3.19.0"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -35,22 +35,26 @@ def get_plugins_path() -> str:
path = b""

try:
path = core.trt.Version()["path"]
path = core.ov.Version()["path"]
except AttributeError:
try:
path = core.ort.Version()["path"]
except AttributeError:
try:
path = core.ov.Version()["path"]
except AttributeError:
path = core.ncnn.Version()["path"]
except AttributeError:
try:
path = core.trt.Version()["path"]
except AttributeError:
path = core.migraphx.Version()["path"]

assert path != b""

return os.path.dirname(path).decode()

plugins_path: str = get_plugins_path()
trtexec_path: str = os.path.join(plugins_path, "vsmlrt-cuda", "trtexec")
migraphx_driver_path: str = os.path.join(plugins_path, "vsmlrt-hip", "migraphx-driver")
models_path: str = os.path.join(plugins_path, "models")


Expand Down Expand Up @@ -194,6 +198,27 @@ class ORT_DML:
# internal backend attributes
supports_onnx_serialization: bool = True

@dataclass(frozen=False)
class MIGraphX:
""" backend for amd gpus
basic performance tuning:
set fp16 = True
"""

device_id: int = 0
fp16: bool = False
opt_shapes: typing.Optional[typing.Tuple[int, int]] = None
fast_math: bool = True
exhaustive_tune: bool = False

short_path: typing.Optional[bool] = None # True on Windows by default, False otherwise
custom_env: typing.Dict[str, str] = field(default_factory=lambda: {})
custom_args: typing.List[str] = field(default_factory=lambda: [])

# internal backend attributes
supports_onnx_serialization: bool = False


backendT = typing.Union[
Backend.OV_CPU,
Expand All @@ -203,6 +228,7 @@ class ORT_DML:
Backend.OV_GPU,
Backend.NCNN_VK,
Backend.ORT_DML,
Backend.MIGraphX
]


Expand Down Expand Up @@ -1664,11 +1690,132 @@ def trtexec(
else:
env = {"CUDA_MODULE_LOADING": "LAZY"}
env.update(**custom_env)
subprocess.run(args, env=custom_env, check=True, stdout=sys.stderr)
subprocess.run(args, env=env, check=True, stdout=sys.stderr)

return engine_path


def get_program_path(
network_path: str,
opt_shapes: typing.Tuple[int, int],
fp16: bool,
fast_math: bool,
exhaustive_tune: bool,
device_id: int,
short_path: typing.Optional[bool]
) -> str:

with open(network_path, "rb") as file:
checksum = zlib.adler32(file.read())

migraphx_version = core.migraphx.Version()["migraphx_version_build"].decode()

try:
device_name = core.migraphx.DeviceProperties(device_id)["name"].decode()
device_name = device_name.replace(' ', '-')
except AttributeError:
device_name = f"device{device_id}"

shape_str = f"{opt_shapes[0]}x{opt_shapes[1]}"

identity = (
shape_str +
("_fp16" if fp16 else "") +
("_fast" if fast_math else "") +
("_exhaustive" if exhaustive_tune else "") +
f"_migraphx-{migraphx_version}" +
f"_{device_name}" +
f"_{checksum:x}"
)

if short_path or (short_path is None and platform.system() == "Windows"):
dirname, basename = os.path.split(network_path)
return os.path.join(dirname, f"{zlib.crc32((basename + identity).encode()):x}.program")
else:
return f"{network_path}.{identity}.program"


def migraphx_driver(
network_path: str,
channels: int,
opt_shapes: typing.Tuple[int, int],
fp16: bool,
fast_math: bool,
exhaustive_tune: bool,
device_id: int,
input_name: str = "input",
short_path: typing.Optional[bool] = None,
custom_env: typing.Dict[str, str] = {},
custom_args: typing.List[str] = []
) -> str:

if isinstance(opt_shapes, int):
opt_shapes = (opt_shapes, opt_shapes)

program_path = get_program_path(
network_path=network_path,
opt_shapes=opt_shapes,
fp16=fp16,
fast_math=fast_math,
exhaustive_tune=exhaustive_tune,
device_id=device_id,
short_path=short_path
)

if os.access(program_path, mode=os.R_OK):
return program_path

alter_program_path = os.path.join(
tempfile.gettempdir(),
os.path.splitdrive(program_path)[1][1:]
)

if os.access(alter_program_path, mode=os.R_OK):
return alter_program_path

try:
# test writability
with open(program_path, "w") as f:
pass
os.remove(program_path)
except PermissionError:
print(f"{program_path} not writable", file=sys.stderr)
program_path = alter_program_path
dirname = os.path.dirname(program_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
print(f"change program path to {program_path}", file=sys.stderr)

if device_id != 0:
raise ValueError('"device_id" must be 0')

args = [
migraphx_driver_path,
"compile",
"--onnx", f"{network_path}",
"--gpu",
# f"--device={device_id}",
"--output", f"{program_path}"
]

args.extend(["--input-dim", f"@{input_name}", "1", f"{channels}", f"{opt_shapes[1]}", f"{opt_shapes[0]}"])

if fp16:
args.append("--fp16")

if not fast_math:
args.append("--disable-fast-math")

if exhaustive_tune:
args.append("--exhaustive-tune")

args.extend(custom_args)

subprocess.run(args, env=custom_env, check=True, stdout=sys.stderr)

return program_path


def calc_size(width: int, tiles: int, overlap: int, multiple: int = 1) -> int:
return math.ceil((width + 2 * overlap * (tiles - 1)) / (tiles * multiple)) * multiple

Expand Down Expand Up @@ -1723,6 +1870,8 @@ def init_backend(
backend = Backend.NCNN_VK()
elif backend is Backend.ORT_DML: # type: ignore
backend = Backend.ORT_DML()
elif backend is Backend.MIGraphX: # type: ignore
backend = Backend.MIGraphX()

backend = copy.deepcopy(backend)

Expand All @@ -1732,6 +1881,9 @@ def init_backend(

if backend.max_shapes is None:
backend.max_shapes = backend.opt_shapes
elif isinstance(backend, Backend.MIGraphX):
if backend.opt_shapes is None:
backend.opt_shapes = trt_opt_shapes

return backend

Expand Down Expand Up @@ -1885,6 +2037,35 @@ def _inference(
fp16=backend.fp16,
path_is_serialization=path_is_serialization,
)
elif isinstance(backend, Backend.MIGraphX):
if path_is_serialization:
raise ValueError('"path_is_serialization" must be False for migraphx backend')

network_path = typing.cast(str, network_path)

channels = sum(clip.format.num_planes for clip in clips)

opt_shapes = backend.opt_shapes if backend.opt_shapes is not None else tilesize

program_path = migraphx_driver(
network_path,
channels=channels,
opt_shapes=opt_shapes,
fp16=backend.fp16,
fast_math=backend.fast_math,
exhaustive_tune=backend.exhaustive_tune,
device_id=backend.device_id,
input_name=input_name,
short_path=backend.short_path,
custom_env=backend.custom_env,
custom_args=backend.custom_args
)
clip = core.migraphx.Model(
clips, program_path,
overlap=overlap,
tilesize=tilesize,
device_id=backend.device_id
)
else:
raise TypeError(f'unknown backend {backend}')

Expand Down Expand Up @@ -2101,6 +2282,18 @@ def ORT_DML(*,
**kwargs
)

@staticmethod
def MIGraphX(*,
fp16: bool = False,
opt_shapes: typing.Optional[typing.Tuple[int, int]] = None,
**kwargs
) -> Backend.MIGraphX:

return Backend.MIGraphX(
fp16=fp16,
opt_shapes=opt_shapes
**kwargs
)

def fmtc_resample(clip: vs.VideoNode, **kwargs) -> vs.VideoNode:
clip_org = clip
Expand Down

0 comments on commit 4962ff6

Please sign in to comment.