diff --git a/orttraining/orttraining/python/training/ort_triton/_cache.py b/orttraining/orttraining/python/training/ort_triton/_cache.py
index ede9cd86a9da5..b70064377abfc 100644
--- a/orttraining/orttraining/python/training/ort_triton/_cache.py
+++ b/orttraining/orttraining/python/training/ort_triton/_cache.py
@@ -9,6 +9,7 @@
 import getpass
 import hashlib
 import os
+import sys
 import tempfile
 from types import ModuleType
 from typing import Tuple
@@ -61,6 +62,7 @@ def load(cls, source_code) -> ModuleType:
             mod.__file__ = path
             mod.key = key
             exec(code, mod.__dict__, mod.__dict__)
+            sys.modules[mod.__name__] = mod
             # another thread might set this first
             cls.cache.setdefault(key, mod)
         return cls.cache[key]
diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
index e104ea13c59a3..14bc2779aa05b 100644
--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
+++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py
@@ -6,11 +6,13 @@
 import functools
 import json
 import os
+import re
 import sys
 from types import ModuleType
 from typing import List, Tuple, Union
 
 import onnx
+from onnx import ModelProto
 from torch._C import _from_dlpack
 from torch.utils.dlpack import to_dlpack
 
@@ -41,18 +43,39 @@ class _ShapeCache:
     """
 
     cache = dict()  # noqa: RUF012
+    symbolic_shape_hint = None
+    min_symbolic_shape = 0
     clear = staticmethod(cache.clear)
 
     @classmethod
-    def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
+    def set_symbolic_shape_hint(cls, symbolic_shape_hint_config):
+        for k, v in symbolic_shape_hint_config.items():
+            if k == "*":
+                cls.min_symbolic_shape = v
+            else:
+                if cls.symbolic_shape_hint is None:
+                    cls.symbolic_shape_hint = dict()
+                cls.symbolic_shape_hint[k] = v
+
+    @classmethod
+    def get_shape(cls, onnx_key: int, model: ModelProto, shapes: List[List[int]]) -> List[List[Union[int, str]]]:
         if onnx_key not in cls.cache:
+            if cls.symbolic_shape_hint is not None:
+                for i, input in enumerate(model.graph.input):
+                    if input.type.tensor_type.HasField("shape"):
+                        for j, dim in enumerate(input.type.tensor_type.shape.dim):
+                            if dim.dim_param:
+                                for k, v in cls.symbolic_shape_hint.items():
+                                    if re.fullmatch(k, dim.dim_param):
+                                        shapes[i][j] = f"i{i}_dim{j}_{v}"
+                                        break
             cls.cache[onnx_key] = shapes
         else:
             changed = False
             for i, shape in enumerate(shapes):
                 for j, dim in enumerate(shape):
-                    if dim != cls.cache[onnx_key][i][j] and isinstance(cls.cache[onnx_key][i][j], int):
-                        max_dim = max(dim, cls.cache[onnx_key][i][j])
+                    if isinstance(cls.cache[onnx_key][i][j], int) and dim != cls.cache[onnx_key][i][j]:
+                        max_dim = max(dim, cls.cache[onnx_key][i][j], cls.min_symbolic_shape)
                         shape[j] = f"i{i}_dim{j}_{next_power_of_2(max_dim)}"
                         changed = True
                     elif isinstance(cls.cache[onnx_key][i][j], str):
@@ -67,13 +90,12 @@ def get_shape(cls, onnx_key: int, shapes: List[List[int]]) -> List[List[Union[in
 
         return cls.cache[onnx_key]
 
 
-def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int:
+def _gen_key(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> int:  # pylint: disable=unused-argument
     return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}")
 
 
-def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
-    model = onnx.load_model_from_string(onnx_str)
+def _gen_module(onnx_key: int, model: ModelProto, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]:
     sorted_graph = SortedGraph(model, [parse_shape(shape) for shape in shapes])
     if _DEBUG_MODE:
         os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True)
@@ -96,14 +118,28 @@ def get_config() -> str:
         "scalar": only related scalar initializers will be added to subgraphs.
         "all": all related initializers will be added to subgraphs.
     The min_nodes is used to control the minimum number of non-no-op nodes in a subgraph.
+    User can also specify symbolic_shape_hint in the config, which is a dict to control the symbolic shape hint.
+    Each entry is a regex pattern to match the dim_param in ONNX model and the value is the power of 2 for the symbolic
+    shape. Each dim_param will be replaced by i{input_index}_dim{dim_index}_{power_of_2} in the symbolic shape.
     """
 
+    config = dict()
     config_file = os.getenv("ORTMODULE_TRITON_CONFIG_FILE", "")
     if config_file and os.path.exists(config_file):
         with open(config_file, encoding="UTF-8") as f:
-            return f.read()
+            config = json.load(f)
+
+    if "ops" not in config:
+        config["ops"] = get_supported_ops()
+    if "initializer" not in config:
+        config["initializer"] = "scalar"
+    if "min_nodes" not in config:
+        config["min_nodes"] = 2
+
+    if "symbolic_shape_hint" in config and len(config["symbolic_shape_hint"]) > 0:
+        _ShapeCache.set_symbolic_shape_hint(config["symbolic_shape_hint"])
+        del config["symbolic_shape_hint"]
 
-    config = {"ops": get_supported_ops(), "initializer": "scalar", "min_nodes": 2}
     return json.dumps(config)
 
 
@@ -136,8 +172,9 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors):
     assert all(tensor is not None for tensor in tensors)
     torch_tensors = [_from_dlpack(tensor) for tensor in tensors]
     concrete_shapes = [list(tensor.size()) for tensor in torch_tensors]
-    shapes = _ShapeCache.get_shape(onnx_key, concrete_shapes)
-    func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, onnx_str, shapes)
+    model = onnx.load_model_from_string(onnx_str)
+    shapes = _ShapeCache.get_shape(onnx_key, model, concrete_shapes)
+    func_name, mod = ModuleCache.load(_gen_key, _gen_module, onnx_key, model, shapes)
     func = getattr(mod, func_name)
     output = func(*torch_tensors)
     if isinstance(output, tuple):
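
Note: after this patch, get_config() merges the JSON file named by ORTMODULE_TRITON_CONFIG_FILE with the previous defaults ("ops", "initializer", "min_nodes") and hands any "symbolic_shape_hint" entry to _ShapeCache before stripping it from the serialized config. A minimal sketch of driving the new path; the dim_param names ("batch", "seq.*"), hint values, and temp-file plumbing are illustrative assumptions, not part of this patch:

    import json
    import os
    import tempfile

    # Keys are regex patterns matched via re.fullmatch against each input's
    # dim_param; values are the power-of-2 hints substituted into the shape as
    # i{input_index}_dim{dim_index}_{value}. The special "*" key instead sets
    # _ShapeCache.min_symbolic_shape, the floor used in
    # max(dim, cached_dim, min_symbolic_shape) before next_power_of_2.
    config = {
        "symbolic_shape_hint": {
            "batch": 16,   # hypothetical dim_param name
            "seq.*": 256,  # regex: would match e.g. seq_len, seq_length
            "*": 8,
        }
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(config, f)
    # Must be set before get_config() is first called by the ORTModule Triton path.
    os.environ["ORTMODULE_TRITON_CONFIG_FILE"] = f.name

Keys omitted from the file fall back to the old hardcoded defaults: get_supported_ops() for "ops", "scalar" for "initializer", and 2 for "min_nodes".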