[Doc] Doc revamp (#593)
vmoens authored Jan 4, 2024
1 parent 750a114 commit 3262b39
Showing 7 changed files with 322 additions and 132 deletions.
9 changes: 6 additions & 3 deletions docs/source/index.rst
Welcome to the TensorDict Documentation!
========================================

`TensorDict` is a dictionary-like class that inherits properties from tensors,
such as indexing, shape operations, casting to device etc.

The main purpose of TensorDict is to make code-bases more *readable* and *modular*
by abstracting away tailored operations:

>>> for i, tensordict in enumerate(dataset):
...     # the model reads and writes tensordicts
...     tensordict = model(tensordict)
...     loss = loss_module(tensordict)
...     loss.backward()
...     optimizer.step()
...     optimizer.zero_grad()

With this level of abstraction, one can recycle a training loop for highly heterogeneous tasks.
Each individual step of the training loop (data collection and transform, model
prediction, loss computation etc.)
can be tailored to the use case at hand without impacting the others.
For instance, the above example can be easily used across classification and segmentation tasks, among many others.

2 changes: 1 addition & 1 deletion docs/source/reference/index.rst
API Reference
=============

.. toctree::

    tensordict
    nn
    tensorclass
.. currentmodule:: tensordict

tensorclass
===========

The ``@tensorclass`` decorator helps you build custom classes that inherit the
behaviour from :class:`~tensordict.TensorDict` while being able to restrict
the possible entries to a predefined set or implement custom methods for your class.
Like :class:`~tensordict.TensorDict`, ``@tensorclass`` supports nesting, indexing, reshaping
and item assignment. It also supports tensor operations like clone, squeeze, cat,
split and many more. ``@tensorclass`` allows non-tensor entries;
however, all tensor operations are strictly restricted to tensor attributes, so
custom methods need to be implemented for non-tensor data.
It is important to note that ``@tensorclass`` does not enforce strict type matching.

.. code-block::

    >>> import torch
    >>> from typing import Optional
    >>> from tensordict import tensorclass
    >>>
    >>> @tensorclass
    ... class MyData:
    ...     floatdata: torch.Tensor
    ...     intdata: torch.Tensor
    ...     non_tensordata: str
    ...     nested: Optional["MyData"] = None
    ...
    ...     def check_nested(self):
    ...         assert self.nested is not None

``@tensorclass`` supports indexing. Internally, the tensor objects get indexed;
however, the non-tensor data remains the same.

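For instance (a minimal sketch building on the ``MyData`` class above; the
shapes and values are purely illustrative):

.. code-block::

    >>> data = MyData(
    ...     floatdata=torch.randn(3, 4),
    ...     intdata=torch.randint(10, (3, 4)),
    ...     non_tensordata="test",
    ...     batch_size=[3],
    ... )
    >>> indexed = data[:2]  # tensor entries are indexed along the batch dimension
    >>> assert indexed.batch_size == torch.Size([2])
    >>> assert indexed.non_tensordata == "test"  # non-tensor data is left untouched
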
``@tensorclass`` also supports setting and resetting attributes, even for nested objects.

.. code-block::

    >>> data.non_tensordata = "test_changed"
    >>> data.nested = data.clone()
    >>> data.nested.non_tensordata = "nested_test_changed"
    >>> print("data.nested.non_tensordata:", repr(data.nested.non_tensordata))
    data.nested.non_tensordata: 'nested_test_changed'

``@tensorclass`` supports multiple torch operations over the shape and device
of its content, such as `stack`, `cat`, `reshape` or `to(device)`. To get
a full list of the supported operations, check the tensordict documentation.

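Here is an example (a minimal sketch; the instances reuse the illustrative
``MyData`` class defined above):

.. code-block::

    >>> data1 = MyData(
    ...     floatdata=torch.randn(3, 4),
    ...     intdata=torch.randint(10, (3, 4)),
    ...     non_tensordata="test",
    ...     batch_size=[3],
    ... )
    >>> data2 = data1.clone()
    >>> stacked = torch.stack([data1, data2], dim=0)
    >>> assert stacked.batch_size == torch.Size([2, 3])
    >>> reshaped = stacked.reshape(6)
    >>> assert reshaped.batch_size == torch.Size([6])
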
Serialization
~~~~~~~~~~~~~

Saving a tensorclass instance can be achieved with the `memmap` method.
The saving strategy is as follows: tensor data will be saved using memory-mapped
tensors, and non-tensor data that can be serialized to JSON will
be saved as such. Other data types will be saved using :func:`~torch.save`, which
relies on `pickle`.

Deserializing a `tensorclass` can be done via :meth:`~tensordict.TensorDict.load_memmap`.
The instance created will have the same type as the one saved provided that
the `tensorclass` is available in the working environment:

>>> data.memmap("path/to/saved/directory")
>>> data_loaded = TensorDict.load_memmap("path/to/saved/directory")
>>> assert isinstance(data_loaded, type(data))


Edge cases
~~~~~~~~~~

``@tensorclass`` supports equality and inequality operators, even for
nested objects. Note that the non-tensor / meta data is not validated.
These operators return a tensor class object with boolean values for
tensor attributes and ``None`` for non-tensor attributes.
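
Here is an example (a minimal sketch reusing the illustrative ``data1`` and
``data2`` from above):

.. code-block::

    >>> eq = data1 == data2
    >>> assert eq.floatdata.dtype is torch.bool  # boolean mask for tensor entries
    >>> assert eq.non_tensordata is None  # non-tensor data is not compared
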
``@tensorclass`` supports setting an item. However, while setting an item,
the identity of the non-tensor / meta data is checked instead of its equality to
avoid performance issues. The user needs to make sure that the non-tensor data
of an item matches that of the object to avoid discrepancies.

If the non-tensor / meta data of the items do not match, a warning will be
thrown:

.. code-block::

    >>> data[0] = data2[0]
    UserWarning: Meta data at 'non_tensordata' may or may not be equal, this may result in undefined behaviours

Even though ``@tensorclass`` supports torch functions like cat and stack, the
non-tensor / meta data is not validated. The torch operation is performed on the
tensor data and, when returning the output, the non-tensor / meta data of the first
tensor class object is considered. The user needs to make sure that all the
tensor class objects involved in the operation carry matching non-tensor / meta data.
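
Here is an example (a minimal sketch; note the deliberately different
``non_tensordata`` values, which are purely illustrative):

.. code-block::

    >>> data3 = MyData(
    ...     floatdata=torch.randn(3, 4),
    ...     intdata=torch.randint(10, (3, 4)),
    ...     non_tensordata="other",
    ...     batch_size=[3],
    ... )
    >>> catted = torch.cat([data1, data3], dim=0)
    >>> assert catted.batch_size == torch.Size([6])
    >>> # the non-tensor data of the first object is carried over, unvalidated
    >>> assert catted.non_tensordata == "test"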

``@tensorclass`` also supports pre-allocation: you can initialize
the object with attributes set to ``None`` and set them later. Note that while
initializing, internally the ``None`` attributes will be saved as non-tensor / meta data,
and while resetting, based on the type of the value of the attribute,
it will be saved as either tensor data or non-tensor / meta data.
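
A minimal sketch of this pattern (again using the illustrative ``MyData``
class):

.. code-block::

    >>> data = MyData(
    ...     floatdata=None,
    ...     intdata=None,
    ...     non_tensordata=None,
    ...     batch_size=[3],
    ... )
    >>> # at this point the None attributes are stored as non-tensor / meta data
    >>> data.floatdata = torch.randn(3, 4)  # re-saved as tensor data
    >>> data.non_tensordata = "test"        # re-saved as non-tensor / meta data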

51 changes: 41 additions & 10 deletions docs/source/reference/tensordict.rst
tensordict package
==================

The :class:`~tensordict.TensorDict` class simplifies the process of passing multiple tensors
from module to module by packing them in a dictionary-like object that inherits features from
regular PyTorch tensors.

.. autosummary::
    :toctree: generated/
    :template: td_template.rst

    TensorDictBase
    TensorDict
    SubTensorDict
    LazyStackedTensorDict
    PersistentTensorDict
    TensorDictParams

TensorDict as a context manager
-------------------------------

:class:`~tensordict.TensorDict` can be used as a context manager in situations
where an action has to be done and then undone. This includes temporarily
locking/unlocking a tensordict:

>>> data.lock_() # data.set will result in an exception
>>> with data.unlock_():
... data.set("key", value)
>>> assert data.is_locked()

or executing functional calls with a TensorDict instance containing the
parameters and buffers of a model:

>>> params = TensorDict.from_module(module).clone()
>>> params.zero_()
>>> with params.to_module(module):
... y = module(x)

In the first example, we can modify the tensordict `data` because we have
temporarily unlocked it. In the second example, we populate the module with the
parameters and buffers contained in the `params` tensordict instance, and reset
the original parameters after this call is completed.

Memory-mapped tensors
---------------------

`tensordict` offers the :class:`~tensordict.MemoryMappedTensor` primitive which
allows you to work with tensors stored in physical memory in a handy way.
The main advantages of :class:`~tensordict.MemoryMappedTensor`
are its ease of construction (no need to handle the storage of a tensor),
the possibility to work with big contiguous data that would not fit in memory,
efficient (de)serialization across processes, and efficient indexing of
stored tensors.

If all workers have access to the same storage (both in multiprocess and distributed
settings), passing a :class:`~tensordict.MemoryMappedTensor`
will just consist of passing a reference to a file on disk plus some
extra metadata for reconstructing it. The same goes for indexed memory-mapped
tensors, as long as the data pointer of their storage is the same as the original
one.
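
As an illustration, here is a sketch of sending a :class:`~tensordict.MemoryMappedTensor`
through a multiprocessing queue; the worker function and tensor shape are
hypothetical, and a spawn-safe setup (e.g., a ``if __name__ == "__main__"``
guard in a script) is assumed:

.. code-block::

    >>> import torch
    >>> import torch.multiprocessing as mp
    >>> from tensordict import MemoryMappedTensor
    >>>
    >>> def worker(queue):
    ...     # only a file reference plus some metadata crosses the process boundary
    ...     mmap_tensor = queue.get()
    ...     mmap_tensor[0] += 1  # writes land in the shared on-disk storage
    ...
    >>> tensor = MemoryMappedTensor.from_tensor(torch.zeros(10, 3))
    >>> queue = mp.Queue()
    >>> proc = mp.Process(target=worker, args=(queue,))
    >>> proc.start()
    >>> queue.put(tensor)
    >>> proc.join()
    >>> assert (tensor[0] == 1).all()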

Indexing memory-mapped tensors is much faster than loading several independent files from
disk and does not require loading the full content of the array in memory.
However, physical storage of PyTorch tensors should not be any different:
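
For instance (a minimal sketch; the sizes and dtype are illustrative):

.. code-block::

    >>> import torch
    >>> from tensordict import MemoryMappedTensor
    >>>
    >>> # store a big tensor on disk while keeping a tensor-like interface
    >>> images = MemoryMappedTensor.from_tensor(
    ...     torch.zeros(1_000, 3, 64, 64, dtype=torch.uint8)
    ... )
    >>> mini_batch = images[:16]  # fast indexing; the full array is never loaded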

.. autosummary::
    :toctree: generated/
    :template: td_template.rst

    MemoryMappedTensor

Utils
-----
