
v0.6.0: Export, streaming and `CudaGraphModule`

@vmoens vmoens released this 21 Oct 16:38
8c65dcb

What's Changed

TensorDict 0.6.0 makes the @dispatch decorator compatible with torch.export and related APIs,
allowing you to drop the tensordict dependency entirely when exporting your models:

import torch
from torch import distributions as dists
from torch import nn
from torch.export import export

from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule as Prob,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
)

# Sample input used both to trace the export and to run the model
x = torch.randn(3)

model = Seq(
    # 1. A small network for embedding
    Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
    Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
    Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
    # 2. Extracting params
    Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    # 3. Probabilistic module
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["sample"],
        distribution_class=dists.Normal,
    ),
)

model_export = export(model, args=(), kwargs={"x": x})

See our new tutorial to learn more about this feature.

The library's integration with the PT2 stack is further improved by the introduction of CudaGraphModule,
which can be used to speed up model execution under a specific set of assumptions: the inputs and outputs
must be non-differentiable, they must all be tensors or constants, and the whole graph must be executable on CUDA with
buffers of constant shape (i.e., dynamic shapes are not allowed).

We also introduce a new tutorial on streaming tensordicts.

Note: The aarch64 binaries are attached to these release notes and are not available on PyPI at the moment.

Deprecations

  • [Deprecate] Make calls to make_functional error #1034 by @vmoens
  • [Deprecation] Act warned deprecations for v0.6 #1001 by @vmoens
  • [Refactor] make TD.get default to None, like dict (#948) by @vmoens

Fixes

  • [BugFix] Add nullbyte in memmap files to make fbcode happy (#943) by @vmoens
  • [BugFix] Add sync to cudagraph module (#1026) by @vmoens
  • [BugFix] Another compiler fix for older pytorch #980 by @vmoens
  • [BugFix] Compatibility with non-tensor inputs in CudaGraphModule #1039 by @vmoens
  • [BugFix] Deserializing a consolidated TD reproduces a consolidated TD #1019 by @vmoens
  • [BugFix] Fix foreach_copy for older versions of PT #1035 by @vmoens
  • [BugFix] Fix buffer identity in Params._apply (#1027) by @vmoens
  • [BugFix] Fix key errors catch in del_ and related (#949) by @vmoens
  • [BugFix] Fix number check in array parsing (np>=2 compatibility) #999 by @vmoens
  • [BugFix] Fix pre 2.1 _apply compatibility #1050 by @vmoens
  • [BugFix] Fix select in tensorclass (#936) by @vmoens
  • [BugFix] Fix td device sync when error is raised #988 by @vmoens
  • [BugFix] Fix tree_leaves import for older versions of PT #995 by @vmoens
  • [BugFix] Fix vmap monkey patching #1009 by @vmoens
  • [BugFix] Make probabilistic sequential modules compatible with compile #1030 by @vmoens
  • [BugFix] Other dynamo fixes #977 by @vmoens
  • [BugFix] Propagate maybe_dense_stack in _stack #1036 by @vmoens
  • [BugFix] Regular swap_tensor for to_module in dynamo (#963) by @vmoens
  • [BugFix] Remove ForkingPickler to account for change of API in torch.mp #998 by @vmoens
  • [BugFix] Remove forkingpickler (#1049) by @bhack
  • [BugFix] Resilient deterministic_sample for CompositeDist #1000 by @vmoens
  • [BugFix] Simple syncs (#942) by @vmoens
  • [BugFix] Softly revert get changes (#950) by @vmoens
  • [BugFix] TDParams.to(device) works as nn.Module, not TDParams contained TD #1025 by @vmoens
  • [BugFix] Use separate streams for cudagraph warmup #1010 by @vmoens
  • [BugFix] dynamo compat refactors #975 by @vmoens
  • [BugFix] resilient _exclude_td_from_pytree #1038 by @vmoens
  • [BugFix] restrict usage of Buffers to non-batched, non-tracked tensors #979 by @vmoens

Full Changelog: v0.5.0...v0.6.0
