v0.6.0: Export, streaming and `CudaGraphModule`
What's Changed
TensorDict 0.6.0 makes the `@dispatch` decorator compatible with `torch.export` and related APIs, allowing you to get rid of tensordict altogether when exporting your models:
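import torch
from torch import distributions as dists, nn
from torch.export import export
# Aliases below are assumed from the short names used in this snippet
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)

# Example input (assumed shape) matching the 3 input features of the first layer
x = torch.randn(1, 3)

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

model_export = export(model, args=(), kwargs={"x": x})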
See our new tutorial to learn more about this feature.
The library's integration with the PT2 stack is further improved by the introduction of `CudaGraphModule`, which can be used to speed up model execution under a certain set of assumptions: mainly that the inputs and outputs are non-differentiable, that they are all tensors or constants, and that the whole graph can be executed on CUDA with buffers of constant shape (i.e., dynamic shapes are not allowed).
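A minimal usage sketch, assuming `CudaGraphModule` is importable from `tensordict.nn` and can wrap a plain function; the `add_one` and `run` helpers are hypothetical and exist only for illustration. The input is a non-differentiable CUDA tensor of fixed shape, in line with the assumptions listed above, and the example is skipped when no CUDA device is present.

```python
import torch


def add_one(x: torch.Tensor) -> torch.Tensor:
    # Pure tensor-in, tensor-out function with a fixed input shape
    return x + 1


def run(n_calls: int = 4):
    # CUDA graphs need a CUDA device; return None when there is none
    if not torch.cuda.is_available():
        return None
    from tensordict.nn import CudaGraphModule

    graphed = CudaGraphModule(add_one)
    # Input does not require grad and keeps the same shape on every call
    x = torch.zeros(8, device="cuda")
    out = None
    for _ in range(n_calls):
        # Call several times so that calls past the warm-up phase are exercised
        out = graphed(x)
    return out


result = run()
```

Calling the wrapped function repeatedly with same-shaped inputs is the intended pattern; changing the input shape between calls falls outside the constant-shape assumption.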
We also introduce a new tutorial on streaming tensordicts.
Note: The `aarch64` binaries are attached to these release notes and are not available on PyPI at the moment.
Deprecations
- [Deprecate] Make calls to make_functional error #1034 by @vmoens
- [Deprecation] Act warned deprecations for v0.6 #1001 by @vmoens
- [Refactor] make TD.get default to None, like dict (#948) by @vmoens
Features
- [Feature] Allow to specify log_prob_key in CompositeDistribution (#961) by @albertbou92
- [Feature] Better typing for tensorclass #983 by @vmoens
- [Feature] Cudagraphs (#986) by @vmoens
- [Feature] Densify lazy tensordicts #955 by @vmoens
- [Feature] Frozen tensorclass #984 by @vmoens
- [Feature] Make NonTensorData a callable (#939) by @vmoens
- [Feature] NJT with lengths #1021 by @vmoens
- [Feature] Non-blocking for consolidated TD #1020 by @vmoens
- [Feature] Propagate `existsok` in memmap* methods #990 by @vmoens
- [Feature] TD+NJT to(device) support #1022 by @vmoens
- [Feature] TensorDict.record_stream #1016 by @vmoens
- [Feature] Unify composite dist method signatures with other dists (#981) by @albertbou92
- [Feature] foreach_copy for update_ #1032 by @vmoens
- [Feature] `data_ptr()` method #1024 by @vmoens
- [Feature] `inplace` arg in TDM constructor #992 by @vmoens
- [Feature] `selected_out_keys` arg in TDS constructor #993 by @vmoens
- [Feature] better sync and instantiation of cudagraphs (#1013) by @vmoens
- [Feature] callables for merge_tensordicts #1033 by @vmoens
- [Feature] cat and stack_from_tensordict #1018 by @vmoens
- [Feature] cat_tensors and stack_tensors #1017 by @vmoens
- [Feature] from_struct_array and to_struct_array (#938) by @vmoens
- [Feature] give a `__name__` to TDModules #1045 by @vmoens
- [Feature] param_count #1046 by @vmoens
- [Feature] sorted keys, values and items #965 by @vmoens
- [Feature] str2td #953 by @vmoens
- [Feature] torch.export and onnx compatibility #991 by @vmoens
Code improvements
- [Quality] Better error for mismatching TDs (#964) by @vmoens
- [Quality] Better type hints for `__init__` (#1014) by @vmoens
- [Quality] Expose private classmethods (#982) by @vmoens
- [Quality] Fewer recompiles with tensordict (#1015) by @vmoens
- [Quality] Type checks #976 by @vmoens
- [Refactor, Tests] Move TestCudagraphs by @vmoens
- [Refactor, Tests] Move TestCudagraphs #1007 by @vmoens
- [Refactor] Make `@tensorclass` work properly with pyright (#1042) by @maxim
- [Refactor] Update nn inline_inbuilt check #1029 by @vmoens
- [Refactor] Use IntEnum for interaction types (#989) by @vmoens
- [Refactor] better AddStateIndependentNormalScale #1028 by @vmoens
Fixes
- [BugFix] Add nullbyte in memmap files to make fbcode happy (#943) by @vmoens
- [BugFix] Add sync to cudagraph module (#1026) by @vmoens
- [BugFix] Another compiler fix for older pytorch #980 by @vmoens
- [BugFix] Compatibility with non-tensor inputs in CudaGraphModule #1039 by @vmoens
- [BugFix] Deserializing a consolidated TD reproduces a consolidated TD #1019 by @vmoens
- [BugFix] Fix foreach_copy for older versions of PT #1035 by @vmoens
- [BugFix] Fix buffer identity in Params._apply (#1027) by @vmoens
- [BugFix] Fix key errors catch in del_ and related (#949) by @vmoens
- [BugFix] Fix number check in array parsing (np>=2 compatibility) #999 by @vmoens
- [BugFix] Fix pre 2.1 _apply compatibility #1050 by @vmoens
- [BugFix] Fix select in tensorclass (#936) by @vmoens
- [BugFix] Fix td device sync when error is raised #988 by @vmoens
- [BugFix] Fix tree_leaves import for older versions of PT #995 by @vmoens
- [BugFix] Fix vmap monkey patching #1009 by @vmoens
- [BugFix] Make probabilistic sequential modules compatible with compile #1030 by @vmoens
- [BugFix] Other dynamo fixes #977 by @vmoens
- [BugFix] Propagate maybe_dense_stack in _stack #1036 by @vmoens
- [BugFix] Regular swap_tensor for to_module in dynamo (#963) by @vmoens
- [BugFix] Remove ForkingPickler to account for change of API in torch.mp #998 by @vmoens
- [BugFix] Remove forkingpickler (#1049) by @bhack
- [BugFix] Resilient deterministic_sample for CompositeDist #1000 by @vmoens
- [BugFix] Simple syncs (#942) by @vmoens
- [BugFix] Softly revert get changes (#950) by @vmoens
- [BugFix] TDParams.to(device) works as nn.Module, not TDParams contained TD #1025 by @vmoens
- [BugFix] Use separate streams for cudagraph warmup #1010 by @vmoens
- [BugFix] dynamo compat refactors #975 by @vmoens
- [BugFix] resilient _exclude_td_from_pytree #1038 by @vmoens
- [BugFix] restrict usage of Buffers to non-batched, non-tracked tensors #979 by @vmoens
Doc
- [Doc] Broken links in GETTING_STARTED.md (#945) by @vmoens
- [Doc] Fail-on-warning in sphinx #1005 by @vmoens
- [Doc] Fix tutorials #1002 by @vmoens
- [Doc] Refactor README and add GETTING_STARTED.md (#944) by @vmoens
- [Doc] Streaming tensordicts #956 by @vmoens
- [Doc] export tutorial, TDM tuto refactoring #994 by @vmoens
Performance
Not user facing
- [Benchmark] Benchmark H2D transfer #1044 by @vmoens
- [CI, BugFix] Fix nightly build (#941) by @vmoens
- [CI] Add aarch64-linux wheels (#987) by @vmoens
- [CI] Fix versioning of h2d tests #1053 by @vmoens
- [CI] Fix windows wheels #1006 by @vmoens
- [CI] Upgrade 3.8 workflows (#967) by @vmoens
- [Minor, Format] Fix fbcode lint (#940) by @vmoens
- [Minor] Refactor is_dynamo_compiling for older torch versions (#978) by @vmoens
- [Setup] Correct read_file encoding in setup (#962) by @vmoens
- [Test] Keep a tight control over warnings (#951) by @vmoens
- [Test] Make h5py tests optional if no h5py installed (#947) by @vmoens
- [Test] Mark MP tests as slow (#946) by @vmoens
- [Test] Rename duplicated test #997 by @vmoens
- [Test] Skip compile tests that require 2.5 for stable #996 by @vmoens
- [Versioning] Versions for 0.6 (#1052) by @vmoens
New Contributors
Full Changelog: v0.5.0...v0.6.0