Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent 93e83ea commit 1f927fa
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 1 deletion.
53 changes: 52 additions & 1 deletion tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,15 @@ class TensorClass:
return_early: bool = False,
share_non_tensor: bool = False,
) -> T: ...
dumps = save
def dumps(
self,
prefix: str | None = None,
copy_existing: bool = False,
*,
num_threads: int = 0,
return_early: bool = False,
share_non_tensor: bool = False,
) -> T: ...
def memmap(
self,
prefix: str | None = None,
Expand Down Expand Up @@ -892,6 +900,14 @@ class TensorClass:
*,
default: str | CompatibleType | None = None,
) -> T: ...
def clamp(
self,
min: TensorDictBase | torch.Tensor = None,
max: TensorDictBase | torch.Tensor = None,
*,
out=None,
): ...
def logsumexp(self, dim=None, keepdim=False, *, out=None): ...
def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: ...
def clamp_max(
self,
Expand Down Expand Up @@ -944,6 +960,27 @@ class TensorClass:
def to_namedtuple(self, dest_cls: type | None = None): ...
@classmethod
def from_namedtuple(cls, named_tuple, *, auto_batch_size: bool = False): ...
def from_tuple(
cls,
obj,
*,
auto_batch_size: bool = False,
batch_dims: int | None = None,
device: torch.device | None = None,
batch_size: torch.Size | None = None,
): ...
def logical_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: ...
def bitwise_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: ...
@classmethod
def from_struct_array(
cls, struct_array: np.ndarray, device: torch.device | None = None
Expand Down Expand Up @@ -987,6 +1024,20 @@ class TensorClass:
strict: bool = True,
reproduce_struct: bool = False,
): ...
def separates(
self,
*keys: NestedKey,
default: Any = NO_DEFAULT,
strict: bool = True,
filter_empty: bool = True,
) -> T: ...
def norm(
self,
*,
out=None,
dtype: torch.dtype | None = None,
): ...
def softmax(self, dim: int, dtype: torch.dtype | None = None): ...
@property
def is_locked(self) -> bool: ...
@is_locked.setter
Expand Down
63 changes: 63 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from __future__ import annotations

import argparse
import ast
import contextlib
import dataclasses
import inspect
import os
import pathlib
import pickle
import re
import sys
Expand Down Expand Up @@ -61,6 +63,67 @@
]


def _get_methods_from_pyi(file_path):
"""
Reads a .pyi file and returns a set of method names.
Args:
file_path (str): Path to the .pyi file.
Returns:
set: A set of method names.
"""
with open(file_path, "r") as f:
tree = ast.parse(f.read())

methods = set()
for node in tree.body:
if isinstance(node, ast.ClassDef):
for child_node in node.body:
if isinstance(child_node, ast.FunctionDef):
methods.add(child_node.name)

return methods


def _get_methods_from_class(cls):
"""
Returns a set of method names from a given class.
Args:
cls (class): The class to get methods from.
Returns:
set: A set of method names.
"""
methods = set()
for name in dir(cls):
attr = getattr(cls, name)
if inspect.isfunction(attr) or inspect.ismethod(attr):
methods.add(name)

return methods


def test_tensorclass_stub_methods():
tensorclass_pyi_path = (
pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi"
)
tensorclass_methods = _get_methods_from_pyi(str(tensorclass_pyi_path))

from tensordict import TensorDict

tensordict_methods = _get_methods_from_class(TensorDict)

missing_methods = tensordict_methods - tensorclass_methods
missing_methods = [
method for method in missing_methods if (not method.startswith("_"))
]

if missing_methods:
raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")


def _make_data(shape):
return MyData(
X=torch.rand(*shape),
Expand Down

0 comments on commit 1f927fa

Please sign in to comment.