Source code for torch_tensorrt._Device
from __future__ import annotations
import logging
import sys
from typing import Any, Optional, Tuple
if sys.version_info >= (3, 11):
    from typing import Self
else:
    from typing_extensions import Self
import torch
from torch_tensorrt._enums import DeviceType
from torch_tensorrt._features import needs_torch_tensorrt_runtime
import tensorrt as trt
[docs]class Device(object):
"""
Defines a device that can be used to specify target devices for engines
@@ -505,7 +505,7 @@ Source code for torch_tensorrt._Device
False #: Whether falling back to GPU if DLA cannot support an op should be allowed
)
[docs] def __init__(self, *args: Any, **kwargs: Any):
"""__init__ Method for torch_tensorrt.Device
Device accepts one of a few construction patterns
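For illustration, a minimal construction sketch (the device ids are arbitrary; this is a hedged example, not part of the module source):

    import torch_tensorrt

    d = torch_tensorrt.Device("cuda:0")                              # from a device string
    d = torch_tensorrt.Device(gpu_id=0)                              # from explicit keyword arguments
    d = torch_tensorrt.Device(dla_core=0, allow_gpu_fallback=True)   # DLA core with GPU fallback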
@@ -577,7 +577,7 @@ Source code for torch_tensorrt._Device
if isinstance(kwargs["device_type"], trt.DeviceType):
self.device_type = DeviceType._from(kwargs["device_type"])
def __str__(self) -> str:
suffix = (
")"
if self.device_type == DeviceType.GPU
@@ -586,11 +586,11 @@ Source code for torch_tensorrt._Device
dev_str: str = f"Device(type={self.device_type}, gpu_id={self.gpu_id}{suffix}"
return dev_str
def __repr__(self) -> str:
return self.__str__()
@classmethod
def _from(cls, d: Optional[Self | torch.device | str]) -> Device:
"""Cast a device-type to torch_tensorrt.Device
Returns the corresponding torch_tensorrt.Device
@@ -610,16 +610,16 @@ Source code for torch_tensorrt._Device
return cls(d)
@classmethod
def _from_torch_device(cls, torch_dev: torch.device) -> Device:
return cls._from(torch_dev)
@classmethod
def _current_device(cls) -> Device:
dev_id = torch.cuda.current_device()
return cls(gpu_id=dev_id)
@staticmethod
def _parse_device_str(s: str) -> Tuple[trt.DeviceType, int]:
s = s.lower()
spec = s.split(":")
if spec[0] == "gpu" or spec[0] == "cuda":
@@ -629,7 +629,7 @@ Source code for torch_tensorrt._Device
else:
raise ValueError(f"Unknown device type {spec[0]}")
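Illustrative behavior of the parser above (the DLA branch is elided by this diff, so that case is an assumption):

    Device._parse_device_str("cuda:1")   # -> (trt.DeviceType.GPU, 1)
    Device._parse_device_str("gpu:0")    # -> (trt.DeviceType.GPU, 0)
    Device._parse_device_str("dla:0")    # -> (trt.DeviceType.DLA, 0), assuming a "dla" branch exists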
def to(self, t: type) -> torch.device:
if t == torch.device:
if self.gpu_id != -1:
return torch.device(self.gpu_id)
@@ -639,7 +639,7 @@ Source code for torch_tensorrt._Device
raise TypeError("Unsupported target type for device conversion")
@needs_torch_tensorrt_runtime
def _to_serialized_rt_device(self) -> str:
delim = torch.ops.tensorrt.SERIALIZED_RT_DEVICE_DELIM()[0]
dev_info = torch.cuda.get_device_properties(self.gpu_id)
rt_info = [
diff --git a/docs/_modules/torch_tensorrt/_Input.html b/docs/_modules/torch_tensorrt/_Input.html
index 55ee388d6e..1280bf4851 100644
--- a/docs/_modules/torch_tensorrt/_Input.html
+++ b/docs/_modules/torch_tensorrt/_Input.html
@@ -9,7 +9,7 @@
- torch_tensorrt._Input — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt._Input — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,16 +467,16 @@
Source code for torch_tensorrt._Input
from __future__ import annotations
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple
import torch
from torch_tensorrt._enums import dtype, memory_format
[docs]class Input(object):
"""
Defines an input to a module in terms of expected shape, data type and tensor format.
@@ -493,7 +493,7 @@ Source code for torch_tensorrt._Input
format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
"""
class _ShapeMode(Enum):
STATIC = 0
DYNAMIC = 1
@@ -518,7 +518,7 @@ Source code for torch_tensorrt._Input
name: str = ""
is_shape_tensor: bool = False
[docs] def __init__(self, *args: Any, **kwargs: Any) -> None:
"""__init__ Method for torch_tensorrt.Input
Input accepts one of a few construction patterns
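A minimal sketch of the common patterns (shapes and dtypes are illustrative):

    import torch
    import torch_tensorrt

    # Static shape
    inp = torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)

    # Dynamic shape described by a min/opt/max optimization profile
    dyn = torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.half,
    )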
@@ -659,7 +659,7 @@ Source code for torch_tensorrt._Input
if "name" in kwargs:
self.name = kwargs["name"]
def __str__(self) -> str:
if self.shape_mode == Input._ShapeMode.STATIC:
return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format(
self.shape,
@@ -686,11 +686,11 @@ Source code for torch_tensorrt._Input
else:
raise RuntimeError("Unknown input shape mode")
def __repr__(self) -> str:
return self.__str__()
@staticmethod
def equivalent_spec(a: Input, b: Input) -> bool:
if a.shape_mode != b.shape_mode:
return False
@@ -718,7 +718,7 @@ Source code for torch_tensorrt._Input
return all(checks)
@staticmethod
def _supported_input_size_type(input_size: Any) -> bool:
if isinstance(input_size, torch.Size):
return True
elif isinstance(input_size, tuple):
@@ -729,7 +729,7 @@ Source code for torch_tensorrt._Input
return False
@staticmethod
def _parse_tensor_domain(
domain: Optional[Tuple[float, float]]
) -> Tuple[float, float]:
"""
@@ -777,7 +777,7 @@ Source code for torch_tensorrt._Input
return result_domain
[docs] @classmethod
def from_tensor(
cls, t: torch.Tensor, disable_memory_format_check: bool = False
) -> "Input":
"""
@@ -809,7 +809,7 @@ Source code for torch_tensorrt._Input
return cls(shape=t.shape, dtype=t.dtype, format=frmt, torch_tensor=t)
[docs] @classmethod
def from_tensors(
cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
) -> List["Input"]:
"""
@@ -830,7 +830,7 @@ Source code for torch_tensorrt._Input
for t in ts
]
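A short usage sketch for the two classmethods above (the tensors are placeholders):

    import torch
    from torch_tensorrt import Input

    specs = Input.from_tensors([torch.randn(1, 3, 224, 224), torch.randn(1, 10)])
    # each element is an Input whose shape and dtype mirror the corresponding tensor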
[docs] def example_tensor(
self, optimization_profile_field: Optional[str] = None
) -> torch.Tensor:
"""
diff --git a/docs/_modules/torch_tensorrt/_compile.html b/docs/_modules/torch_tensorrt/_compile.html
index 410711569a..492132c1ef 100644
--- a/docs/_modules/torch_tensorrt/_compile.html
+++ b/docs/_modules/torch_tensorrt/_compile.html
@@ -9,7 +9,7 @@
- torch_tensorrt._compile — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt._compile — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,51 +467,51 @@
Source code for torch_tensorrt._compile
from __future__ import annotations

import collections.abc
import logging
import platform
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set

import torch
import torch.fx
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
CudaGraphsTorchTensorRTModule,
)
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
from typing_extensions import TypeGuard
if ENABLED_FEATURES.torchscript_frontend:
import torch_tensorrt.ts
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from torch_tensorrt.ts._compiler import (
convert_method_to_trt_engine as ts_convert_method_to_trt_engine,
)
if ENABLED_FEATURES.dynamo_frontend:
from torch.export import ExportedProgram
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
from torch_tensorrt.dynamo._compiler import (
convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
)
from torch_tensorrt.dynamo._compiler import (
cross_compile_for_windows as dynamo_cross_compile_for_windows,
)
from torch_tensorrt.dynamo._compiler import (
load_cross_compiled_exported_program as dynamo_load_cross_compiled_exported_program,
)
from torch_tensorrt.dynamo._compiler import (
save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program,
)
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
logger = logging.getLogger(__name__)
@@ -525,19 +525,19 @@ Source code for torch_tensorrt._compile
]
def _non_fx_input_interface(
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[Input | torch.Tensor]]:
return all(isinstance(i, (torch.Tensor, Input)) for i in inputs)
def _fx_input_interface(
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
return all(isinstance(i, (torch.Tensor, InputTensorSpec)) for i in inputs)
class _IRType(Enum):
"""Enum to determine the type of IR selected for model compilation"""
ts = 0
@@ -547,7 +547,7 @@ Source code for torch_tensorrt._compile
exported_program = 4
class _ModuleType(Enum):
"""Enum to determine the type of model provided as input"""
nn = 0
@@ -556,7 +556,7 @@ Source code for torch_tensorrt._compile
ep = 3
def _parse_module_type(module: Any) -> _ModuleType:
if any(
isinstance(module, t)
for t in [torch.jit.ScriptModule, torch.jit.ScriptFunction]
@@ -572,7 +572,7 @@ Source code for torch_tensorrt._compile
raise RuntimeError("Module is an unknown format")
def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType:
module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])
module_is_exportable = module_type == _ModuleType.ep
@@ -633,7 +633,7 @@ Source code for torch_tensorrt._compile
raise ValueError("Unknown ir was requested")
[docs]def compile(
module: Any,
ir: str = "default",
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
@@ -743,7 +743,7 @@ Source code for torch_tensorrt._compile
if kwarg_inputs is None:
kwarg_inputs = {}
from torch_tensorrt.dynamo.utils import prepare_inputs
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore
@@ -770,7 +770,7 @@ Source code for torch_tensorrt._compile
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
def cross_compile_for_windows(
module: torch.nn.Module,
file_path: str,
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
@@ -843,7 +843,7 @@ Source code for torch_tensorrt._compile
if kwarg_inputs is None:
kwarg_inputs = {}
from torch_tensorrt.dynamo.utils import prepare_inputs
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore
@@ -869,13 +869,13 @@ Source code for torch_tensorrt._compile
logger.debug("successfully compiled and saved the module for windows")
def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
"""
Returns a boxed model which is the output of torch.compile.
This does not compile the model to TRT. Execute this model on
sample inputs to compile the model to TRT.
"""
from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
# TODO: Remove dynamic=False when SymInt Dynamic shape support is ready
boxed_fn = torch.compile(
@@ -885,7 +885,7 @@ Source code for torch_tensorrt._compile
return boxed_fn
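A hedged usage sketch (the model and input are placeholders; torch_compile is imported from this module):

    import torch
    from torch_tensorrt._compile import torch_compile

    model = torch.nn.Linear(8, 4).cuda().eval()
    boxed = torch_compile(model)
    out = boxed(torch.randn(2, 8).cuda())  # first execution triggers TRT compilation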
[docs]def convert_method_to_trt_engine(
module: Any,
method_name: str = "forward",
inputs: Optional[Sequence[Input | torch.Tensor | InputTensorSpec]] = None,
@@ -966,7 +966,7 @@ Source code for torch_tensorrt._compile
if kwarg_inputs is None:
kwarg_inputs = {}
from torch_tensorrt.dynamo.utils import prepare_inputs
if not isinstance(arg_inputs, collections.abc.Sequence):
arg_inputs = [arg_inputs] # type: ignore
@@ -994,7 +994,7 @@ Source code for torch_tensorrt._compile
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
def load_cross_compiled_exported_program(file_path: str = "") -> Any:
"""
Load an ExportedProgram file in Windows which was previously cross compiled in Linux
@@ -1007,7 +1007,7 @@ Source code for torch_tensorrt._compile
return dynamo_load_cross_compiled_exported_program(file_path)
[docs]def load(file_path: str = "") -> Any:
"""
Load either a Torchscript model or ExportedProgram.
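A minimal round-trip sketch (the path, module, and input are placeholders; see save() below):

    import torch_tensorrt

    torch_tensorrt.save(trt_gm, "trt.ep", arg_inputs=[example_tensor])  # ExportedProgram format by default
    prog = torch_tensorrt.load("trt.ep")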
@@ -1044,7 +1044,7 @@ Source code for torch_tensorrt._compile
)
[docs]def save(
module: Any,
file_path: str = "",
*,
@@ -1131,7 +1131,7 @@ Source code for torch_tensorrt._compile
torch.jit.save(module_ts, file_path)
else:
if not retrace:
from torch_tensorrt.dynamo._exporter import export
if arg_inputs is not None:
logger.warning(
diff --git a/docs/_modules/torch_tensorrt/_enums.html b/docs/_modules/torch_tensorrt/_enums.html
index aa5c83ac26..fa5fceffeb 100644
--- a/docs/_modules/torch_tensorrt/_enums.html
+++ b/docs/_modules/torch_tensorrt/_enums.html
@@ -9,7 +9,7 @@
- torch_tensorrt._enums — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt._enums — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,19 +467,19 @@
Source code for torch_tensorrt._enums
from __future__ import annotations
import logging
from enum import Enum, auto
from typing import Any, Optional, Type, Union
import numpy as np
import tensorrt as trt
import torch
from torch_tensorrt._features import ENABLED_FEATURES, needs_torch_tensorrt_runtime
[docs]class dtype(Enum):
"""Enum to describe data types to Torch-TensorRT, has compatibility with torch, tensorrt and numpy dtypes"""
# Supported types in Torch-TensorRT
@@ -575,7 +575,7 @@ Source code for torch_tensorrt._enums
bfloat16 = bf16
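A short conversion sketch using the helpers defined below (_from/try_from/to); the results assume the mappings in this module:

    import numpy as np
    import tensorrt as trt
    import torch
    from torch_tensorrt import dtype

    dtype._from(torch.float32)   # -> dtype.f32
    dtype.f16.to(trt.DataType)   # -> trt.DataType.HALF
    dtype.try_from(np.float32)   # -> dtype.f32, or None for unsupported types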
@staticmethod
def _is_np_obj(t: Any) -> bool:
if isinstance(t, np.dtype):
return True
elif isinstance(t, type):
@@ -584,7 +584,7 @@ Source code for torch_tensorrt._enums
return False
@classmethod
def _from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype, type],
use_default: bool = False,
@@ -710,7 +710,7 @@ Source code for torch_tensorrt._enums
return t
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(t, _C.dtype):
if t == _C.dtype.long:
@@ -739,7 +739,7 @@ Source code for torch_tensorrt._enums
)
[docs] @classmethod
def try_from(
cls,
t: Union[torch.dtype, trt.DataType, np.dtype, dtype],
use_default: bool = False,
@@ -779,7 +779,7 @@ Source code for torch_tensorrt._enums
)
return None
[docs] def to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool = False,
@@ -898,7 +898,7 @@ Source code for torch_tensorrt._enums
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.dtype:
if self == dtype.i64:
@@ -926,7 +926,7 @@ Source code for torch_tensorrt._enums
f"Provided unsupported destination type for dtype conversion {t}"
)
[docs] def try_to(
self,
t: Union[Type[torch.dtype], Type[trt.DataType], Type[np.dtype], Type[dtype]],
use_default: bool,
@@ -965,11 +965,11 @@ Source code for torch_tensorrt._enums
)
return None
def __eq__(self, other: Union[torch.dtype, trt.DataType, np.dtype, dtype]) -> bool:
other_ = dtype._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
# Putting aliases here that mess with mypy
@@ -977,7 +977,7 @@ Source code for torch_tensorrt._enums
int = i32
[docs]class memory_format(Enum):
""""""
# TensorRT supported memory layouts
@@ -1109,7 +1109,7 @@ Source code for torch_tensorrt._enums
channels_last_3d = dhwc
@classmethod
def _from(
cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> memory_format:
"""Create a Torch-TensorRT memory format enum from another library memory format enum.
@@ -1185,7 +1185,7 @@ Source code for torch_tensorrt._enums
return f
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(f, _C.TensorFormat):
if f == _C.TensorFormat.contiguous:
@@ -1200,7 +1200,7 @@ Source code for torch_tensorrt._enums
raise TypeError("Provided unsupported source type for memory_format conversion")
[docs] @classmethod
def try_from(
cls, f: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> Optional[memory_format]:
"""Create a Torch-TensorRT memory format enum from another library memory format enum.
@@ -1233,7 +1233,7 @@ Source code for torch_tensorrt._enums
)
return None
[docs] def to(
self,
t: Union[
Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format]
@@ -1308,7 +1308,7 @@ Source code for torch_tensorrt._enums
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.TensorFormat:
if self == memory_format.contiguous:
@@ -1324,7 +1324,7 @@ Source code for torch_tensorrt._enums
"Provided unsupported destination type for memory format conversion"
)
[docs] def try_to(
self,
t: Union[
Type[torch.memory_format], Type[trt.TensorFormat], Type[memory_format]
@@ -1359,17 +1359,17 @@ Source code for torch_tensorrt._enums
)
return None
def __eq__(
self, other: Union[torch.memory_format, trt.TensorFormat, memory_format]
) -> bool:
other_ = memory_format._from(other)
return self.value == other_.value
def __hash__(self) -> int:
return hash(self.value)
[docs]class DeviceType(Enum):
"""Type of device TensorRT will target"""
UNKNOWN = auto()
@@ -1394,7 +1394,7 @@ Source code for torch_tensorrt._enums
"""
@classmethod
def _from(cls, d: Union[trt.DeviceType, DeviceType]) -> DeviceType:
"""Create a Torch-TensorRT device type enum from a TensorRT device type enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
@@ -1433,7 +1433,7 @@ Source code for torch_tensorrt._enums
return d
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(d, _C.DeviceType):
if d == _C.DeviceType.GPU:
@@ -1448,7 +1448,7 @@ Source code for torch_tensorrt._enums
raise TypeError("Provided unsupported source type for DeviceType conversion")
[docs] @classmethod
def try_from(cls, d: Union[trt.DeviceType, DeviceType]) -> Optional[DeviceType]:
"""Create a Torch-TensorRT device type enum from a TensorRT device type enum.
Takes a device type enum from tensorrt and create a ``torch_tensorrt.DeviceType``.
@@ -1480,7 +1480,7 @@ Source code for torch_tensorrt._enums
)
return None
[docs] def to(
self,
t: Union[Type[trt.DeviceType], Type[DeviceType]],
use_default: bool = False,
@@ -1526,7 +1526,7 @@ Source code for torch_tensorrt._enums
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.DeviceType:
if self == DeviceType.GPU:
@@ -1542,7 +1542,7 @@ Source code for torch_tensorrt._enums
"Provided unsupported destination type for device type conversion"
)
[docs] def try_to(
self,
t: Union[Type[trt.DeviceType], Type[DeviceType]],
use_default: bool = False,
@@ -1575,15 +1575,15 @@ Source code for torch_tensorrt._enums
)
return None
def __eq__(self, other: Union[trt.DeviceType, DeviceType]) -> bool:
other_ = DeviceType._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
[docs]class EngineCapability(Enum):
"""
EngineCapability determines the restrictions of a network during build time and what runtime it targets.
"""
@@ -1610,7 +1610,7 @@ Source code for torch_tensorrt._enums
"""
@classmethod
def _from(
cls, c: Union[trt.EngineCapability, EngineCapability]
) -> EngineCapability:
"""Create a Torch-TensorRT Engine capability enum from a TensorRT Engine capability enum.
@@ -1651,7 +1651,7 @@ Source code for torch_tensorrt._enums
return c
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if isinstance(c, _C.EngineCapability):
if c == _C.EngineCapability.STANDARD:
@@ -1668,7 +1668,7 @@ Source code for torch_tensorrt._enums
)
[docs] @classmethod
def try_from(
c: Union[trt.EngineCapability, EngineCapability]
) -> Optional[EngineCapability]:
"""Create a Torch-TensorRT engine capability enum from a TensorRT engine capability enum.
@@ -1702,7 +1702,7 @@ Source code for torch_tensorrt._enums
)
return None
[docs] def to(
self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
) -> Union[trt.EngineCapability, EngineCapability]:
"""Convert ``EngineCapability`` into the equivalent type in tensorrt
@@ -1743,7 +1743,7 @@ Source code for torch_tensorrt._enums
return self
elif ENABLED_FEATURES.torchscript_frontend:
from torch_tensorrt import _C
if t == _C.EngineCapability:
if self == EngineCapability.STANDARD:
@@ -1759,7 +1759,7 @@ Source code for torch_tensorrt._enums
"Provided unsupported destination type for engine capability type conversion"
)
[docs] def try_to(
self, t: Union[Type[trt.EngineCapability], Type[EngineCapability]]
) -> Optional[Union[trt.EngineCapability, EngineCapability]]:
"""Convert ``EngineCapability`` into the equivalent type in tensorrt
@@ -1790,15 +1790,15 @@ Source code for torch_tensorrt._enums
)
return None
def __eq__(self, other: Union[trt.EngineCapability, EngineCapability]) -> bool:
other_ = EngineCapability._from(other)
return bool(self.value == other_.value)
def __hash__(self) -> int:
return hash(self.value)
class Platform(Enum):
"""
Specifies a target OS and CPU architecture that a Torch-TensorRT program targets
"""
@@ -1827,14 +1827,14 @@ Source code for torch_tensorrt._enums
UNKNOWN = auto()
@classmethod
def current_platform(cls) -> Platform:
"""
Returns an enum for the current platform Torch-TensorRT is running on
Returns:
Platform: Current platform
"""
import platform
if platform.system().lower().startswith("linux"):
# linux
@@ -1850,11 +1850,11 @@ Source code for torch_tensorrt._enums
return Platform.UNKNOWN
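Illustrative use (the return value depends on the host):

    Platform.current_platform()   # e.g. Platform.LINUX_X86_64 on a typical Linux build machine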
def __str__(self) -> str:
return str(self.name)
@needs_torch_tensorrt_runtime # type: ignore
def _to_serialized_rt_platform(self) -> str:
val: str = torch.ops.tensorrt._platform_unknown()
if self == Platform.LINUX_X86_64:
diff --git a/docs/_modules/torch_tensorrt/dynamo/_compiler.html b/docs/_modules/torch_tensorrt/dynamo/_compiler.html
index f8686f6ce4..396e717d3a 100644
--- a/docs/_modules/torch_tensorrt/dynamo/_compiler.html
+++ b/docs/_modules/torch_tensorrt/dynamo/_compiler.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo._compiler — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo._compiler — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,45 +467,45 @@
Source code for torch_tensorrt.dynamo._compiler
from __future__ import annotations

import collections.abc
import logging
import platform
import warnings
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

import torch
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
DryRunTracker,
PerSubgraphData,
dryrun_stats_display,
parse_non_trt_nodes,
)
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache, DiskEngineCache
from torch_tensorrt.dynamo._exporter import replace_execute_engine_no_op_node
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
UnsupportedOperatorException,
convert_module,
interpret_module_to_result,
repair_double_inputs,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
parse_graph_io,
@@ -518,7 +518,7 @@ Source code for torch_tensorrt.dynamo._compiler
logger = logging.getLogger(__name__)
def cross_compile_for_windows(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
*,
@@ -835,7 +835,7 @@ Source code for torch_tensorrt.dynamo._compiler
return trt_gm
[docs]def compile(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
*,
@@ -1153,7 +1153,7 @@ Source code for torch_tensorrt.dynamo._compiler
return trt_gm
def compile_module(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
@@ -1210,7 +1210,7 @@ Source code for torch_tensorrt.dynamo._compiler
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)
def contains_metadata(gm: torch.fx.GraphModule) -> bool:
for node in gm.graph.nodes:
if node.op != "output" and (not node.meta) and "val" not in node.meta:
logger.warning(
@@ -1379,7 +1379,7 @@ Source code for torch_tensorrt.dynamo._compiler
return partitioned_module
def convert_exported_program_to_serialized_trt_engine(
exported_program: ExportedProgram,
inputs: Optional[Sequence[Sequence[Any]]] = None,
*,
@@ -1641,7 +1641,7 @@ Source code for torch_tensorrt.dynamo._compiler
return serialized_engine
def save_cross_compiled_exported_program(
gm: torch.fx.GraphModule,
file_path: str,
) -> None:
@@ -1655,14 +1655,14 @@ Source code for torch_tensorrt.dynamo._compiler
if not file_path:
raise ValueError("File path cannot be empty. Please provide a valid file path")
from torch_tensorrt.dynamo._exporter import export
exp_program = export(gm, cross_compile_flag=True)
torch.export.save(exp_program, file_path)
logger.debug(f"successfully saved the module for windows at {file_path}")
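A hedged end-to-end sketch of the cross-compile workflow (the program, inputs, and paths are placeholders):

    # On Linux: build and serialize an engine targeting Windows
    trt_gm = cross_compile_for_windows(exported_program, inputs=[example_input])
    save_cross_compiled_exported_program(trt_gm, "trt_windows.ep")

    # On Windows: restore the program
    prog = load_cross_compiled_exported_program("trt_windows.ep")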
def load_cross_compiled_exported_program(file_path: str = "") -> Any:
"""
Load an ExportedProgram file in Windows which was previously cross compiled in Linux
diff --git a/docs/_modules/torch_tensorrt/dynamo/_exporter.html b/docs/_modules/torch_tensorrt/dynamo/_exporter.html
index 2a1eb9dfa9..d3540b7fe2 100644
--- a/docs/_modules/torch_tensorrt/dynamo/_exporter.html
+++ b/docs/_modules/torch_tensorrt/dynamo/_exporter.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo._exporter — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo._exporter — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,16 +467,16 @@
Source code for torch_tensorrt.dynamo._exporter
import base64
import copy
import operator
from typing import Any, Dict, Optional, Sequence, Tuple, cast

import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram, ExportGraphSignature
from torch.export.exported_program import (
CustomObjArgument,
InputKind,
InputSpec,
@@ -486,10 +486,10 @@ Source code for torch_tensorrt.dynamo._exporter
OutputSpec,
TensorArgument,
)
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ENGINE_IDX, NAME_IDX
[docs]def export(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
) -> ExportedProgram:
@@ -505,7 +505,7 @@ Source code for torch_tensorrt.dynamo._exporter
return exp_program
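A short usage sketch (trt_gm is a placeholder for a compiled GraphModule):

    exp_program = export(trt_gm)
    torch.export.save(exp_program, "trt.ep")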
def transform(
gm: torch.fx.GraphModule,
cross_compile_flag: Optional[bool] = False,
) -> torch.fx.GraphModule:
@@ -539,7 +539,7 @@ Source code for torch_tensorrt.dynamo._exporter
return gm
def lift(
gm: torch.fx.GraphModule, graph_signature: Any
) -> Tuple[torch.fx.GraphModule, ExportGraphSignature, Dict[str, Any], Dict[str, Any]]:
"""
@@ -661,7 +661,7 @@ Source code for torch_tensorrt.dynamo._exporter
return gm, graph_signature, state_dict, constants
def get_duplicate_nodes(
gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule
) -> Tuple[Sequence[Any], Sequence[Any]]:
"""
@@ -684,7 +684,7 @@ Source code for torch_tensorrt.dynamo._exporter
return submodule_duplicate_inputs, gm_duplicate_inputs
def inline_torch_modules(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Inline a submodule within the parent graph (gm). All `call_module` nodes
should be replaced by their nodes in the submodule.
@@ -751,7 +751,7 @@ Source code for torch_tensorrt.dynamo._exporter
return gm
def copy_submodule_attributes(
gm: torch.fx.GraphModule, submodule: torch.fx.GraphModule, submodule_name: str
) -> None:
"""
@@ -762,7 +762,7 @@ Source code for torch_tensorrt.dynamo._exporter
_assign_attr does exactly that. It creates a module (e.g. conv), adds an attribute weight
to it, and adds this conv module as an attribute to the parent gm.
"""
from torch.export.unflatten import _assign_attr, _AttrKind
for key, value in submodule.named_parameters():
_assign_attr(value, gm, key, _AttrKind.PARAMETER)
@@ -771,7 +771,7 @@ Source code for torch_tensorrt.dynamo._exporter
_assign_attr(value, gm, key, _AttrKind.BUFFER)
def create_trt_exp_program(
gm: torch.fx.GraphModule,
) -> ExportedProgram:
"""Creates a new Exported Program. This function takes an torch.fx.GraphModule which has TRT engines
@@ -825,7 +825,7 @@ Source code for torch_tensorrt.dynamo._exporter
return trt_exp_program
def inline_trt_modules(
gm: torch.fx.GraphModule, cross_compile_flag: Optional[bool] = False
) -> torch.fx.GraphModule:
"""
@@ -901,7 +901,7 @@ Source code for torch_tensorrt.dynamo._exporter
return gm
def replace_execute_engine_no_op_node(
exp_program: ExportedProgram,
) -> ExportedProgram:
gm = exp_program.graph_module
diff --git a/docs/_modules/torch_tensorrt/dynamo/_refit.html b/docs/_modules/torch_tensorrt/dynamo/_refit.html
index baf31324f9..30b33c225a 100644
--- a/docs/_modules/torch_tensorrt/dynamo/_refit.html
+++ b/docs/_modules/torch_tensorrt/dynamo/_refit.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo._refit — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo._refit — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,42 +467,42 @@
Source code for torch_tensorrt.dynamo._refit
from __future__ import annotations

import collections.abc
import copy
import logging
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import tensorrt as trt
import torch
from torch.export import ExportedProgram
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo._exporter import inline_torch_modules
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
PythonTorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
ENGINE_IDX,
SERIALIZED_METADATA_IDX,
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.utils import (
check_module_output,
get_model_device,
get_torch_inputs,
@@ -510,12 +510,12 @@ Source code for torch_tensorrt.dynamo._refit
to_torch_device,
to_torch_tensorrt_device,
)
from torch_tensorrt.logging import TRT_LOGGER
logger = logging.getLogger(__name__)
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
settings: CompilationSettings = CompilationSettings(),
@@ -576,7 +576,7 @@ Source code for torch_tensorrt.dynamo._refit
return weight_map
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
) -> dict[Any, Any]:
engine_weight_map = {}
@@ -602,7 +602,7 @@ Source code for torch_tensorrt.dynamo._refit
return engine_weight_map
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
input_list: Sequence[Any],
@@ -680,7 +680,7 @@ Source code for torch_tensorrt.dynamo._refit
raise AssertionError("Refitting failed.")
[docs]def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
arg_inputs: Optional[Tuple[Any, ...]] = None,
@@ -979,10 +979,10 @@ Source code for torch_tensorrt.dynamo._refit
# Util functions -----------
import base64
def get_engine_from_encoded_engine(
encoded_engine: str, runtime: trt.Runtime
) -> trt.ICudaEngine:
serialized_engine = base64.b64decode(encoded_engine)
diff --git a/docs/_modules/torch_tensorrt/dynamo/_settings.html b/docs/_modules/torch_tensorrt/dynamo/_settings.html
index 6e06960379..ec40196697 100644
--- a/docs/_modules/torch_tensorrt/dynamo/_settings.html
+++ b/docs/_modules/torch_tensorrt/dynamo/_settings.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo._settings — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo._settings — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,13 +467,13 @@
Source code for torch_tensorrt.dynamo._settings
from dataclasses import dataclass, field
from typing import Collection, Optional, Set, Tuple, Union
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo._defaults import (
ASSUME_DYNAMIC_SHAPE_SUPPORT,
CACHE_BUILT_ENGINES,
DEBUG,
@@ -502,6 +502,7 @@ Source code for torch_tensorrt.dynamo._settings
STRIP_ENGINE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
+ USE_AOT_JOINT_EXPORT,
USE_EXPLICIT_TYPING,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
@@ -513,7 +514,7 @@ Source code for torch_tensorrt.dynamo._settings
[docs]@dataclass
class CompilationSettings:
"""Compilation settings for Torch-TensorRT Dynamo Paths
Args:
@@ -560,6 +561,7 @@ Source code for torch_tensorrt.dynamo._settings
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False, meaning TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
+ use_aot_joint_export (bool): Use aot_export_joint_simple; otherwise wrap the backend with AOT autograd, which is required for distributed tensors
"""
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -599,7 +601,8 @@ Source code for torch_tensorrt.dynamo._settings
strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS
immutable_weights: bool = IMMUTABLE_WEIGHTS
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
+ use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -618,7 +621,7 @@ Source code for torch_tensorrt.dynamo._settings
)
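A hedged construction sketch showing the new flag alongside an existing field (values are illustrative):

    settings = CompilationSettings(
        truncate_double=True,
        use_aot_joint_export=False,  # wrap the backend with AOT autograd, e.g. for distributed tensors
    )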
def settings_are_compatible(
set_a: CompilationSettings, set_b: CompilationSettings
) -> Tuple[bool, Set[str]]:
incompatible_settings: Set[str] = set()
diff --git a/docs/_modules/torch_tensorrt/dynamo/_tracer.html b/docs/_modules/torch_tensorrt/dynamo/_tracer.html
index a95673ad79..001d5fd78a 100644
--- a/docs/_modules/torch_tensorrt/dynamo/_tracer.html
+++ b/docs/_modules/torch_tensorrt/dynamo/_tracer.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo._tracer — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo._tracer — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,22 +467,22 @@
Source code for torch_tensorrt.dynamo._tracer
from __future__ import annotations
import logging
from inspect import signature
from typing import Any, Optional, Tuple, Union
import torch
from torch.export import Dim, export
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._defaults import DEBUG, default_device
from torch_tensorrt.dynamo.utils import get_torch_inputs, set_log_level, to_torch_device
logger = logging.getLogger(__name__)
[docs]def trace(
mod: torch.nn.Module | torch.fx.GraphModule,
inputs: Optional[Tuple[Any, ...]] = None,
*,
@@ -559,7 +559,7 @@ Source code for torch_tensorrt.dynamo._tracer
return exp_program
def get_dynamic_shapes_kwargs(inputs: Any) -> Union[dict[str, Any], list[Any]]:
if isinstance(inputs, dict):
dynamic_shapes_kwarg = {}
for k, v in inputs.items():
@@ -578,7 +578,7 @@ Source code for torch_tensorrt.dynamo._tracer
raise TypeError(f"Unknown type {type(inputs)}.")
def get_dynamic_shapes_args(mod: torch.nn.Module, inputs: Any) -> dict[str, Any]:
# dynamic_shape is a dict and cannot work without keys. Here we use position argument name
# in forward function as the name
args = list(signature(mod.forward).parameters.keys())
@@ -588,7 +588,7 @@ Source code for torch_tensorrt.dynamo._tracer
return dynamic_shapes
def get_dynamic_shapes(input: Input) -> dict[Any, Any]:
if not isinstance(input, Input):
# If the input is torch.Tensor, no dynamic is needed. Return empty dict
return {}
diff --git a/docs/_modules/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.html b/docs/_modules/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.html
index 0dc38fbcc2..9e23e3bbc9 100644
--- a/docs/_modules/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.html
+++ b/docs/_modules/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,21 +467,21 @@
Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
import logging
from copy import deepcopy
from enum import Enum, auto
from typing import Any, Collection, Dict, Iterator, List, Optional, Set, Union

import numpy as np
import torch
from torch.fx.node import Target
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._compiler import compile as dynamo_compile
from torch_tensorrt.dynamo._refit import refit_module_weights
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import (
check_output_equal,
to_torch_device,
to_torch_tensorrt_device,
@@ -490,27 +490,27 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
logger = logging.getLogger(__name__)
class RefitFlag(Enum):
UNKNOWN = auto()
NEEDS_REFIT = auto()
NEEDS_RECOMPILE = auto()
LIVE = auto()
class RefitState:
_state: RefitFlag = RefitFlag.NEEDS_RECOMPILE
def set_state(self, state: RefitFlag) -> None:
if isinstance(state, RefitFlag):
self._state = state
else:
raise ValueError(f"Invalid state: {state}")
def get_state(self) -> RefitFlag:
return self._state
[docs]class MutableTorchTensorRTModule(object):
"""
Initialize a MutableTorchTensorRTModule to seamlessly manipulate it like a regular PyTorch module.
All TensorRT compilation and refitting processes are handled automatically as you work with the module.
@@ -522,7 +522,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
Any modifications made to the MutableTorchTensorRTModule will be reflected in both the TensorRT graph module and the original PyTorch module.
"""
[docs] def __init__(
self,
pytorch_model: torch.nn.Module,
*,
@@ -672,22 +672,22 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
)
self.init_finished = True
def store_state_dict_metadata(self) -> None:
for k, v in self.original_model.state_dict().items():
self.state_dict_metadata[k] = v.shape
def load_state_dict(
self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False
) -> None:
self.refit_state.set_state(RefitFlag.NEEDS_REFIT)
self.original_model.load_state_dict(state_dict, strict=strict, assign=assign)
@staticmethod
def _transform_state_dict(sd: Dict[str, Any]) -> Dict[str, torch.nn.Parameter]:
return {k: torch.nn.Parameter(v, requires_grad=False) for k, v in sd.items()}
def update_refit_condition(self) -> None:
# 2-stage check to determine whether the module should be intact, refitted, or recompiled.
# Default refit
@@ -721,7 +721,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return
[docs] def refit_gm(self) -> None:
"""
Refit the TRT graph module with any updates.
This function should be called whenever the weight values get changed but the weight structure remains
@@ -752,7 +752,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
self.original_model.cpu()
torch.cuda.empty_cache()
[docs] def compile(self) -> None:
"""
(Re)compile the TRT graph module using the PyTorch module.
This function should be called whenever the weight structure gets changed (shape, more layers...)
@@ -775,7 +775,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
self.original_model.cpu()
torch.cuda.empty_cache()
def _validate_inputs(self, *args: Any, **kwargs: Any) -> None:
if (
not self.arg_inputs
or not MutableTorchTensorRTModule.check_inputs_equal(self.arg_inputs, args)
@@ -787,12 +787,12 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
self.refit_state.set_state(RefitFlag.NEEDS_RECOMPILE)
self.store_inputs(args, kwargs)
def store_inputs(self, arg_inputs: Any, kwarg_inputs: Any) -> None:
self.arg_inputs = arg_inputs
self.kwarg_inputs = kwarg_inputs
@staticmethod
def process_kwarg_inputs(inputs: Any) -> Any:
# Process kwarg inputs to be acceptable for Torch-TensorRT
if isinstance(inputs, dict):
# None should be excluded. AOT compile does not allow dynamic control flow, so bool is also excluded.
@@ -816,7 +816,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
+ "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}"
)
def forward(self, *args: Any, **kwargs: Any) -> Any:
# Step 1: Check whether the input shape has changed
kwargs = MutableTorchTensorRTModule.process_kwarg_inputs(kwargs)
self._validate_inputs(*args, **kwargs)
@@ -849,11 +849,11 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
self.run_info = (args, kwargs, result)
return result
def to(self, device: str) -> None:
logger.warning("Original PyTorch model is moved. CPU offload may fail.")
self.original_model.to(device)
def __deepcopy__(self, memo: Any) -> Any:
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
@@ -865,10 +865,10 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
)
return result
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.forward(*args, **kwargs)
def __getattr__(self, name: str) -> Any:
if name in self.__dict__:
# this object has it
@@ -881,7 +881,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return getattr(self.pytorch_model, name)
def __delattr__(self, name: str) -> Any:
if name in self.__dict__:
# this object has it
@@ -889,7 +889,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return self.pytorch_model.__delattr__(name)
def __setattr__(self, name: str, value: Any) -> None:
# Once the module has finished initialization, any modification to attributes that do not exist
# in __dict__ is handled by the PyTorch module.
if self.init_finished:
@@ -905,7 +905,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
object.__setattr__(self, name, value)
@staticmethod
def check_inputs_equal(
input1: Any,
input2: Any,
) -> bool:
@@ -938,7 +938,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return True
@staticmethod
def save(module: Any, path: str) -> None:
# Cast the object back to MutableTorchTensorRTModule to save
assert (
not module.settings.use_python_runtime
@@ -964,7 +964,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
module.init_finished = True
@staticmethod
def load(path: str) -> Any:
# When the model gets saved, init_finished is set to False.
# The class is restored to MutableTorchTensorRTModule, and some attributes are deleted
module = torch.load(path, weights_only=False)
@@ -986,7 +986,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return module
def recursively_remove_trigger(obj: Any) -> Any:
# Not safe: if the object has a circular reference (such as a doubly linked list), this will cause infinite recursion
if obj.__class__.__name__ == "ChangeTriggerWrapper":
obj = obj.instance
@@ -1008,18 +1008,18 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return obj
def _make_refit_change_trigger(obj: object, refit_state: RefitState) -> Any:
subclass: type = obj.__class__
class ChangeTriggerWrapper(subclass): # type: ignore
# The reason we inherit from the subclass is that we want the ChangeTriggerWrapper to share all functions
# that an ordinary object has. This way, attributes accessed inside a function go through the __getattr__
# of ChangeTriggerWrapper instead of the object itself, and are thus recursively wrapped by ChangeTriggerWrapper.
def __init__(self, obj: Any):
object.__setattr__(self, "instance", obj)
def __getattr__(
self, name: str
) -> Any: # Called when the attribute does not exist
obj = getattr(self.instance, name)
@@ -1034,7 +1034,7 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
return _make_refit_change_trigger(obj, refit_state)
return obj
def __setattr__(self, name: str, value: Any) -> None:
# If we need to set __dict__ or instance, we directly set it on the trigger wrapper.
# Setting __dict__ is enabled because the PyTorch proxy uses __new__ to initialize a shallow copy
# of a module and explicitly sets the __dict__. If we don't set __dict__ it will get infinite recursion.
@@ -1046,44 +1046,44 @@ Source code for torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule
value = recursively_remove_trigger(value)
setattr(self.instance, name, value)
def __delattr__(self, name: str) -> None:
self._on_change()
delattr(
self.instance,
name,
)
def _on_change(self) -> None:
refit_state.set_state(RefitFlag.UNKNOWN)
logger.info(
"Attribute modification detected. The module will be refitted later."
)
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
+ def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.instance(*args, **kwargs)
- def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
+ def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
return self.instance._call_impl(*args, **kwargs)
- def forward(self, *args: Any, **kwargs: Any) -> Any:
+ def forward(self, *args: Any, **kwargs: Any) -> Any:
return self.instance.forward(*args, **kwargs)
- def __setitem__(self, item: str, value: Any) -> None:
+ def __setitem__(self, item: str, value: Any) -> None:
self._on_change()
# We want to make sure the original PyTorch model does not have a trigger wrapper
value = recursively_remove_trigger(value)
self.instance.__setitem__(item, value)
- def __getitem__(self, items: str) -> Any:
+ def __getitem__(self, items: str) -> Any:
obj = self.instance.__getitem__(items)
if isinstance(obj, ChangeTriggerWrapper):
return obj
return _make_refit_change_trigger(obj, refit_state)
- def __len__(self) -> int:
+ def __len__(self) -> int:
return len(self.instance)
- def __iter__(self) -> Iterator[Any]:
+ def __iter__(self) -> Iterator[Any]:
return iter(self.instance)
return ChangeTriggerWrapper(obj)
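A minimal, self-contained sketch of the change-trigger pattern implemented above: a wrapper class that inherits from the wrapped object's class, forwards attribute reads, and flags any mutation so a later stage knows a refit is needed. DirtyFlag, make_change_trigger, and Config are illustrative names, not part of torch_tensorrt.

from typing import Any


class DirtyFlag:
    def __init__(self) -> None:
        self.dirty = False

    def mark(self) -> None:
        self.dirty = True


def make_change_trigger(obj: object, flag: DirtyFlag) -> object:
    subclass = obj.__class__

    class Wrapper(subclass):  # inherit so isinstance() checks on the wrapper still pass
        def __init__(self, wrapped: object) -> None:
            # Bypass our own __setattr__ so "instance" lands on the wrapper itself
            object.__setattr__(self, "instance", wrapped)

        def __getattr__(self, name: str) -> Any:  # only called for missing attributes
            return getattr(self.instance, name)

        def __setattr__(self, name: str, value: Any) -> None:
            flag.mark()  # any attribute write marks the object as needing a refit
            setattr(self.instance, name, value)

    return Wrapper(obj)


class Config:
    def __init__(self) -> None:
        self.lr = 0.1


flag = DirtyFlag()
cfg = make_change_trigger(Config(), flag)
assert cfg.lr == 0.1 and not flag.dirty
cfg.lr = 0.01      # the write is intercepted and forwarded...
assert flag.dirty  # ...and the flag records that a refit is needed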
diff --git a/docs/_modules/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.html b/docs/_modules/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.html
index 2ebac528ca..14c1e4d907 100644
--- a/docs/_modules/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.html
+++ b/docs/_modules/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,23 +467,23 @@
Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
-from __future__ import annotations
-
-import logging
-from contextlib import nullcontext
-from tempfile import tempdir
-from typing import Any, Dict, List, Optional, Sequence, Tuple
-
-import tensorrt as trt
-import torch
-import torch_tensorrt
-from torch.nn import Module
-from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import Platform, dtype
-from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.logging import TRT_LOGGER
-from torch_tensorrt.runtime._utils import (
+from __future__ import annotations
+
+import logging
+from contextlib import nullcontext
+from tempfile import tempdir
+from typing import Any, Dict, List, Optional, Sequence, Tuple
+
+import tensorrt as trt
+import torch
+import torch_tensorrt
+from torch.nn import Module
+from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import Platform, dtype
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.logging import TRT_LOGGER
+from torch_tensorrt.runtime._utils import (
_is_switch_required,
_select_rt_device,
multi_gpu_device_check,
@@ -492,8 +492,8 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
logger = logging.getLogger(__name__)
-class TorchTRTRuntimeStates:
- def __init__(self, new_cudagraphs: bool):
+class TorchTRTRuntimeStates:
+ def __init__(self, new_cudagraphs: bool):
# Indicates whether CUDAGraphs were enabled in the previous execute_engine
self.old_cudagraphs = new_cudagraphs
# Indicates whether pre-allocated output was enabled in the previous execute_engine
@@ -501,7 +501,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
# Indicates whether context has changed
self.context_changed = False
- def set_runtime_states(
+ def set_runtime_states(
self,
new_cudagraphs: bool,
new_pre_allocated_output: bool,
@@ -545,14 +545,14 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
)
-[docs]class PythonTorchTensorRTModule(Module): # type: ignore[misc]
+[docs]class PythonTorchTensorRTModule(Module): # type: ignore[misc]
"""PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is only compatible with
FX / Dynamo / Python deployments. This module cannot be serialized to torchscript via torch.jit.trace for C++ deployment.
"""
-[docs] def __init__(
+[docs] def __init__(
self,
serialized_engine: Optional[bytes] = None,
input_binding_names: Optional[List[str]] = None,
@@ -639,16 +639,16 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
- def get_streamable_device_memory_budget(self) -> Any:
+ def get_streamable_device_memory_budget(self) -> Any:
return self.engine.streamable_weights_size
- def get_automatic_device_memory_budget(self) -> Any:
+ def get_automatic_device_memory_budget(self) -> Any:
return self.engine.get_weight_streaming_automatic_budget()
- def get_device_memory_budget(self) -> Any:
+ def get_device_memory_budget(self) -> Any:
return self.engine.weight_streaming_budget_v2
- def set_device_memory_budget(self, budget_bytes: int) -> int:
+ def set_device_memory_budget(self, budget_bytes: int) -> int:
# Recreate the context because the weight streaming budget cannot be modified while there is an active context.
if self.context is not None:
del self.context
@@ -657,7 +657,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
self.runtime_states.context_changed = True
return budget_bytes
- def _set_device_memory_budget(self, budget_bytes: int) -> int:
+ def _set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.get_streamable_device_memory_budget()
@@ -670,13 +670,13 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
return budget_bytes
- def set_default_device_memory_budget(self) -> int:
+ def set_default_device_memory_budget(self) -> int:
budget_bytes = self.get_automatic_device_memory_budget()
# Set the automatic weight streaming budget as the default when the context is created
logger.debug(f"Weight streaming budget set to {budget_bytes}B")
return self._set_device_memory_budget(budget_bytes)
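A hedged usage sketch of the weight-streaming budget API above. It assumes trt_module is an already-initialized PythonTorchTensorRTModule compiled with weight streaming enabled; only the method names come from the listing.

def shrink_budget_to_half(trt_module) -> int:
    # Total streamable weight bytes: the upper bound for any budget we can request.
    streamable = trt_module.get_streamable_device_memory_budget()
    # TensorRT's own recommended budget for the current device.
    automatic = trt_module.get_automatic_device_memory_budget()
    # Request half of the automatic budget; the setter returns the value that
    # actually took effect, which may differ from what was requested.
    applied = trt_module.set_device_memory_budget(automatic // 2)
    assert applied <= streamable
    return applied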
- def setup_engine(self) -> None:
+ def setup_engine(self) -> None:
assert (
self.target_platform == Platform.current_platform()
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
@@ -710,17 +710,17 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
if torch_tensorrt.runtime.get_cudagraphs_mode():
self.cudagraph = torch.cuda.CUDAGraph()
- def _check_initialized(self) -> None:
+ def _check_initialized(self) -> None:
if not self.initialized:
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
- def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
+ def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
state_dict[prefix + "engine"] = self.serialized_engine
state_dict[prefix + "input_names"] = self.input_names
state_dict[prefix + "output_names"] = self.output_names
state_dict[prefix + "platform"] = self.target_platform
- def _load_from_state_dict(
+ def _load_from_state_dict(
self,
state_dict: Dict[str, Any],
prefix: str,
@@ -739,28 +739,28 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
multi_gpu_device_check()
self.setup_engine()
- def __getstate__(self) -> Dict[str, Any]:
+ def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
state.pop("engine", None)
state.pop("context", None)
return state
- def __setstate__(self, state: Dict[str, Any]) -> None:
+ def __setstate__(self, state: Dict[str, Any]) -> None:
self.__dict__.update(state)
self.setup_engine()
- def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
+ def __deepcopy__(self, memo: Any) -> PythonTorchTensorRTModule:
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
result.__setstate__(self.__getstate__())
return result
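The three methods above implement the standard pickling pattern for objects holding live, unpicklable handles: drop the handle from the state and rebuild it on load. A standalone sketch of the same pattern, with EngineHolder and handle as illustrative stand-ins for the TensorRT engine and context:

import copy
import pickle


class EngineHolder:
    def __init__(self, blob: bytes) -> None:
        self.blob = blob        # picklable description of the engine
        self.handle = object()  # stand-in for a live, unpicklable runtime handle

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        state.pop("handle", None)  # never serialize the live handle
        return state

    def __setstate__(self, state: dict) -> None:
        self.__dict__.update(state)
        self.handle = object()  # rebuild the handle from the picklable state


original = EngineHolder(b"serialized-engine-bytes")
clone = copy.deepcopy(original)  # deepcopy also routes through __getstate__/__setstate__
restored = pickle.loads(pickle.dumps(original))
assert clone.blob == restored.blob == original.blob
assert clone.handle is not original.handle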
- def __del__(self) -> None:
+ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()
- def setup_input_tensors(
+ def setup_input_tensors(
self,
contiguous_inputs: List[torch.Tensor],
cudagraphs_enabled: bool,
@@ -811,7 +811,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
input_name, contiguous_inputs[i].data_ptr()
)
- def create_output_tensors(self) -> List[torch.Tensor]:
+ def create_output_tensors(self) -> List[torch.Tensor]:
# create output tensors
outputs: List[torch.Tensor] = []
@@ -824,10 +824,10 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
outputs.append(output)
return outputs
- def set_pre_allocated_outputs(self, enable: bool) -> None:
+ def set_pre_allocated_outputs(self, enable: bool) -> None:
self.use_pre_allocated_outputs = enable
-[docs] def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+[docs] def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# Ensure inputs are available in all scopes and cast symbolic integers to Tensors
contiguous_inputs: List[torch.Tensor] = [
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -981,7 +981,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
)
if self.profiling_enabled:
- import tempfile
+ import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
self.cudagraph.debug_dump(
@@ -1007,7 +1007,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
return outputs
-[docs] def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
+[docs] def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
"""
Enable TensorRT profiling. After calling this function, TensorRT will report the
time spent on each layer to stdout for each forward run.
@@ -1019,7 +1019,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
self.profiling_enabled = True
-[docs] def disable_profiling(self) -> None:
+[docs] def disable_profiling(self) -> None:
"""
Disable TensorRT profiling.
"""
@@ -1029,7 +1029,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
self.context = self.engine.create_execution_context()
self.profiling_enabled = False
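A short usage sketch for the two profiling hooks above, assuming trt_module is an initialized PythonTorchTensorRTModule and example_input is a CUDA tensor matching the shapes the engine was built for (both names are illustrative):

import torch


def profile_once(trt_module, example_input: torch.Tensor) -> torch.Tensor:
    trt_module.enable_profiling()       # installs a TRT profiler on the context
    try:
        # TensorRT reports per-layer timings for this run.
        return trt_module(example_input)
    finally:
        trt_module.disable_profiling()  # restores a fresh, unprofiled context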
-[docs] def get_layer_info(self) -> str:
+[docs] def get_layer_info(self) -> str:
"""
Get the layer info of the engine. Only supported for TRT > 8.2.
"""
@@ -1039,7 +1039,7 @@ Source code for torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule
)
return engine_json
-[docs] def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+[docs] def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
"""
Validates whether the input shapes of the forward function have changed
"""
diff --git a/docs/_modules/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.html b/docs/_modules/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.html
index d4539bebff..0840471562 100644
--- a/docs/_modules/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.html
+++ b/docs/_modules/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.html
@@ -9,7 +9,7 @@
- torch_tensorrt.dynamo.runtime._TorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.dynamo.runtime._TorchTensorRTModule — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,23 +467,23 @@
Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
-from __future__ import annotations
-
-import base64
-import copy
-import logging
-import pickle
-from typing import Any, List, Optional, Tuple, Union
-
-import torch
-from torch_tensorrt._Device import Device
-from torch_tensorrt._enums import Platform
-from torch_tensorrt._features import (
+from __future__ import annotations
+
+import base64
+import copy
+import logging
+import pickle
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import Platform
+from torch_tensorrt._features import (
ENABLED_FEATURES,
for_all_methods,
needs_torch_tensorrt_runtime,
)
-from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings
logger = logging.getLogger(__name__)
@@ -519,7 +519,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
[docs]@for_all_methods(needs_torch_tensorrt_runtime)
-class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
+class TorchTensorRTModule(torch.nn.Module): # type: ignore[misc]
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is fully compatible with both
@@ -539,7 +539,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
"""
- def __init__(
+ def __init__(
self,
serialized_engine: Optional[bytes] = None,
input_binding_names: Optional[List[str]] = None,
@@ -609,7 +609,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
):
self.setup_engine()
- def _pack_engine_info(self) -> List[str | bytes]:
+ def _pack_engine_info(self) -> List[str | bytes]:
target_device = (
self.settings.device
if self.settings.device is not None
@@ -643,16 +643,16 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
return engine_info
- def get_streamable_device_memory_budget(self) -> Any:
+ def get_streamable_device_memory_budget(self) -> Any:
return self.engine.streamable_device_memory_budget
- def get_automatic_device_memory_budget(self) -> Any:
+ def get_automatic_device_memory_budget(self) -> Any:
return self.engine.automatic_device_memory_budget
- def get_device_memory_budget(self) -> Any:
+ def get_device_memory_budget(self) -> Any:
return self.engine.device_memory_budget
- def set_device_memory_budget(self, budget_bytes: int) -> int:
+ def set_device_memory_budget(self, budget_bytes: int) -> int:
# Disable weight streaming for invalid budget size
if budget_bytes < 0:
budget_bytes = self.get_streamable_device_memory_budget()
@@ -665,7 +665,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
return budget_bytes
- def setup_engine(self) -> None:
+ def setup_engine(self) -> None:
"""
Set up the engine for a module which has deferred engine setup.
@@ -678,19 +678,19 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
return
self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
- def encode_metadata(self, metadata: Any) -> str:
+ def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
dumped_metadata = pickle.dumps(metadata)
encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
return encoded_metadata
@staticmethod
- def decode_metadata(encoded_metadata: bytes) -> Any:
+ def decode_metadata(encoded_metadata: bytes) -> Any:
dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
metadata = pickle.loads(dumped_metadata)
return metadata
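A standalone round-trip sketch of the two helpers above: pickle the metadata, then base64-encode it so it can travel as text inside the packed engine info. Everything here mirrors the listing except that the encoded value is kept as str end to end.

import base64
import pickle
from typing import Any


def encode_metadata(metadata: Any) -> str:
    return base64.b64encode(pickle.dumps(metadata)).decode("utf-8")


def decode_metadata(encoded: str) -> Any:
    return pickle.loads(base64.b64decode(encoded.encode("utf-8")))


meta = {"input_shapes": [(1, 3, 224, 224)], "fp16": True}
assert decode_metadata(encode_metadata(meta)) == meta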
- def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
+ def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
if self.engine:
return (
self.name,
@@ -716,7 +716,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
self.output_binding_names,
)
- def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
+ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
self.name = state[0]
if state[1] is not None:
@@ -741,10 +741,10 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
self.input_binding_names = state[2]
self.output_binding_names = state[3]
- def set_pre_allocated_outputs(self, enable: bool) -> None:
+ def set_pre_allocated_outputs(self, enable: bool) -> None:
self.engine.use_pre_allocated_outputs = enable
- def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
"""Implementation of the forward pass for a TensorRT engine
Args:
@@ -779,7 +779,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
return tuple(outputs)
- def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
+ def enable_profiling(self, profiling_results_dir: Optional[str] = None) -> None:
"""Enable the profiler to collect latency information about the execution of the engine
Traces can be visualized using https://ui.perfetto.dev/ or compatible alternatives
@@ -794,14 +794,14 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
self.engine.profile_path_prefix = profiling_results_dir
self.engine.enable_profiling()
- def disable_profiling(self) -> None:
+ def disable_profiling(self) -> None:
"""Disable the profiler"""
if self.engine is None:
raise RuntimeError("Engine has not been initialized yet.")
self.engine.disable_profiling()
- def get_layer_info(self) -> str:
+ def get_layer_info(self) -> str:
"""Get a JSON string containing the layer information encoded by the TensorRT engine in this module
Returns:
@@ -814,7 +814,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
layer_info: str = self.engine.get_engine_layer_info()
return layer_info
- def dump_layer_info(self) -> None:
+ def dump_layer_info(self) -> None:
"""Dump layer information encoded by the TensorRT engine in this module to STDOUT"""
if self.engine is None:
raise RuntimeError("Engine has not been initialized yet.")
@@ -822,7 +822,7 @@ Source code for torch_tensorrt.dynamo.runtime._TorchTensorRTModule
self.engine.dump_engine_layer_info()
@staticmethod
- def _pack_binding_names(binding_names: List[str]) -> str:
+ def _pack_binding_names(binding_names: List[str]) -> str:
delim = torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM()[0]
packed_bindings: str = delim.join(binding_names)
return packed_bindings
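A hedged sketch of the packing scheme above: binding names are joined with a runtime-defined delimiter so they can be stored as one string and split back later. "%" is an illustrative delimiter only; the real one comes from torch.ops.tensorrt.SERIALIZED_ENGINE_BINDING_DELIM().

from typing import List

DELIM = "%"  # assumption for illustration; not the actual runtime delimiter


def pack_binding_names(binding_names: List[str]) -> str:
    return DELIM.join(binding_names)


def unpack_binding_names(packed: str) -> List[str]:
    return packed.split(DELIM)


names = ["input_0", "input_1", "output_0"]
assert unpack_binding_names(pack_binding_names(names)) == names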
diff --git a/docs/_modules/torch_tensorrt/fx/fx2trt.html b/docs/_modules/torch_tensorrt/fx/fx2trt.html
index e657aedecf..fae3065cfc 100644
--- a/docs/_modules/torch_tensorrt/fx/fx2trt.html
+++ b/docs/_modules/torch_tensorrt/fx/fx2trt.html
@@ -9,7 +9,7 @@
- torch_tensorrt.fx.fx2trt — Torch-TensorRT v2.6.0.dev0+50f29cb documentation
+ torch_tensorrt.fx.fx2trt — Torch-TensorRT v2.6.0.dev0+69c83d4 documentation
@@ -272,7 +272,7 @@
- v2.6.0.dev0+50f29cb
+ v2.6.0.dev0+69c83d4
@@ -467,26 +467,26 @@
Source code for torch_tensorrt.fx.fx2trt
-import logging
-import os
-import warnings
-from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
+import logging
+import os
+import warnings
+from datetime import datetime
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
-import numpy
+import numpy
# @manual=//deeplearning/trt/python:py_tensorrt
-import tensorrt as trt
-import torch
-import torch.fx
-from torch._ops import OpOverload
-from torch.fx.node import _get_qualified_name
-from torch.fx.passes.shape_prop import TensorMetadata
-
-from .converter_registry import CONVERTERS
-from .input_tensor_spec import InputTensorSpec
-from .observer import Observer
-from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
+import tensorrt as trt
+import torch
+import torch.fx
+from torch._ops import OpOverload
+from torch.fx.node import _get_qualified_name
+from torch.fx.passes.shape_prop import TensorMetadata
+
+from .converter_registry import CONVERTERS
+from .input_tensor_spec import InputTensorSpec
+from .observer import Observer
+from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -495,15 +495,15 @@ Source code for torch_tensorrt.fx.fx2trt
)
-[docs]class TRTInterpreterResult(NamedTuple):
+[docs]class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
output_names: Sequence[str]
serialized_cache: bytearray
-[docs]class TRTInterpreter(torch.fx.Interpreter):
- def __init__(
+[docs]class TRTInterpreter(torch.fx.Interpreter):
+ def __init__(
self,
module: torch.fx.GraphModule,
input_specs: List[InputTensorSpec],
@@ -548,7 +548,7 @@ Source code for torch_tensorrt.fx.fx2trt
dict()
)
- def validate_input_specs(self):
+ def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
assert (
@@ -605,7 +605,7 @@ Source code for torch_tensorrt.fx.fx2trt
len(shape_ranges) == 0
), "shape_ranges are provided for input that doesn't have dynamic dim."
- def validate_conversion(self):
+ def validate_conversion(self):
missing_converter = set()
for node in self.module.graph.nodes:
@@ -621,7 +621,7 @@ Source code for torch_tensorrt.fx.fx2trt
return missing_converter
- def run(
+ def run(
self,
max_batch_size=64,
max_workspace_size=1 << 25,
@@ -739,7 +739,7 @@ Source code for torch_tensorrt.fx.fx2trt
engine, self._input_names, self._output_names, serialized_cache
)
- def run_node(self, n):
+ def run_node(self, n):
self._cur_node_name = str(n)
# add "_itensor_to_tensor_meta"
kwargs = dict(n.kwargs)
@@ -759,7 +759,7 @@ Source code for torch_tensorrt.fx.fx2trt
return trt_node
- def placeholder(self, target, args, kwargs):
+ def placeholder(self, target, args, kwargs):
self._input_names.append(target)
shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[
self.input_specs_iter
@@ -780,7 +780,7 @@ Source code for torch_tensorrt.fx.fx2trt