Commit
basics
ghstack-source-id: 8c34373f4fcd788636be7b87e8a554017df0c746
Pull Request resolved: #873
vmoens committed Jul 11, 2024
1 parent 02ff686 commit df9c196
Showing 14 changed files with 834 additions and 343 deletions.
96 changes: 0 additions & 96 deletions benchmarks/tensorclass/test_torch_functions.py

This file was deleted.

2 changes: 2 additions & 0 deletions tensordict/__init__.py
@@ -19,6 +19,7 @@
from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
from tensordict.utils import (
assert_allclose_td,
assert_close,
is_batchedtensor,
is_tensorclass,
lazy_legacy,
@@ -43,6 +44,7 @@
"TensorDict",
"TensorDictBase",
"assert_allclose_td",
"assert_close",
"dense_stack_tds",
"is_batchedtensor",
"is_tensor_collection",
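The hunk above exports assert_close from the package root, alongside the older assert_allclose_td. A minimal usage sketch — only the two positional tensordict arguments are taken from this diff, any extra tolerance keywords would be an assumption:

import torch
from tensordict import TensorDict, assert_close

td_a = TensorDict({"obs": torch.zeros(3), "reward": torch.ones(3)}, batch_size=[3])
td_b = td_a.clone()
assert_close(td_a, td_b)  # passes: every entry of td_b matches td_a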
15 changes: 8 additions & 7 deletions tensordict/_contextlib.py
@@ -91,13 +91,14 @@ def context_decorator(ctx, func):
be a multi-shot context manager that can be directly invoked multiple times)
or a callable that produces a context manager.
"""
assert not (callable(ctx) and hasattr(ctx, "__enter__")), (
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)
if callable(ctx) and hasattr(ctx, "__enter__"):
raise RuntimeError(
f"Passed in {ctx} is both callable and also a valid context manager "
"(has __enter__), making it ambiguous which interface to use. If you "
"intended to pass a context manager factory, rewrite your call as "
"context_decorator(lambda: ctx()); if you intended to pass a context "
"manager directly, rewrite your call as context_decorator(lambda: ctx)"
)

if not callable(ctx):

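The _contextlib.py hunk replaces an assert with an explicit RuntimeError, so the ambiguity check keeps running when Python is launched with -O (which strips assert statements). A small, hypothetical illustration of the kind of object the check rejects — the Ambiguous class below is invented for this example and is not part of the library:

class Ambiguous:
    # Callable (a context-manager factory) and a context manager itself.
    def __call__(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

# context_decorator(Ambiguous(), func) cannot tell whether to call the object
# or enter it directly, so after this change it raises RuntimeError rather
# than failing an assert.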
28 changes: 10 additions & 18 deletions tensordict/_lazy.py
@@ -45,10 +45,6 @@
from tensordict.utils import _ftdim_mock as ftdim

_has_funcdim = False
from tensordict._C import ( # @manual=//tensordict:_C
_unravel_key_to_tuple,
unravel_key_list,
)
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
from tensordict.base import (
_is_leaf_nontensor,
@@ -72,6 +68,7 @@
_renamed_inplace_method,
_shape,
_td_fields,
_unravel_key_to_tuple,
as_decorator,
cache,
convert_ellipsis_to_idx,
@@ -85,6 +82,7 @@
KeyedJaggedTensor,
lock_blocked,
NestedKey,
unravel_key_list,
)
from torch import Tensor

@@ -1168,7 +1166,9 @@ def _add_batch_dim(self, *, in_dim, vmap_level):
td._fast_apply(
lambda _arg: _add_batch_dim(_arg, in_dim, vmap_level),
batch_size=[b for i, b in enumerate(td.batch_size) if i != in_dim],
names=[name for i, name in enumerate(td.names) if i != in_dim],
names=[name for i, name in enumerate(td.names) if i != in_dim]
if self._has_names()
else None,
)
for td in td.tensordicts
]
@@ -1313,7 +1313,7 @@ def contiguous(self) -> T:
source=source,
batch_size=batch_size,
device=device,
names=self.names,
names=self.names if self._has_names() else None,
lock=self.is_locked,
)
return out
@@ -1377,10 +1377,6 @@ def _check_new_batch_size(self, new_size: torch.Size) -> None:
super()._check_new_batch_size(new_size)

def _change_batch_size(self, new_size: torch.Size) -> None:
if not hasattr(self, "_orig_batch_size"):
self._orig_batch_size = self.batch_size
elif self._orig_batch_size == new_size:
del self._orig_batch_size
self._batch_size = new_size

def keys(
@@ -1552,7 +1548,7 @@ def _multithread_rebuild(
# We know batch_size is None, this has been checked earlier
batch_size: Sequence[int] | None = None,
device: torch.device | None = NO_DEFAULT,
names: Sequence[str] | None = None,
names: Sequence[str] | None = NO_DEFAULT,
inplace: bool = False,
checked: bool = False,
out: TensorDictBase | None = None,
@@ -1603,7 +1599,7 @@ def _multithread_rebuild(
)
else:
out = self
if names is not None:
if names is not NO_DEFAULT:
out.names = names
return out

@@ -1613,7 +1609,7 @@ def _apply_nest(
*others: T,
batch_size: Sequence[int] | None = None,
device: torch.device | None = NO_DEFAULT,
names: Sequence[str] | None = None,
names: Sequence[str] | None = NO_DEFAULT,
inplace: bool = False,
checked: bool = False,
call_on_nested: bool = False,
@@ -1719,7 +1715,7 @@ def _apply_nest(
)
else:
out = self
if names is not None:
if names is not NO_DEFAULT:
out.names = names
return out

@@ -2980,10 +2976,6 @@ def _rename_subtds(self, names):
)

def _change_batch_size(self, new_size: torch.Size) -> None:
if not hasattr(self, "_orig_batch_size"):
self._orig_batch_size = self.batch_size
elif self._orig_batch_size == new_size:
del self._orig_batch_size
self._batch_size = new_size

def _get_str(self, key, default):
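Several _lazy.py hunks change the names default from None to the NO_DEFAULT sentinel and only forward self.names when self._has_names() is true, so a tensordict that never received dimension names does not have placeholder names materialized during apply, contiguous or vmap. A rough sketch of the observable effect through the public API — the exact propagation behaviour after this commit is an assumption:

import torch
from tensordict import TensorDict

named = TensorDict({"x": torch.zeros(4, 2)}, batch_size=[4], names=["time"])
unnamed = TensorDict({"x": torch.zeros(4, 2)}, batch_size=[4])

print(named.apply(lambda t: t + 1).names)    # ['time']: explicit names are preserved
print(unnamed.apply(lambda t: t + 1).names)  # [None]: nothing is invented for unnamed dims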
10 changes: 6 additions & 4 deletions tensordict/_pytree.py
@@ -29,8 +29,10 @@


def _str_to_dict(str_spec: str) -> Tuple[List[str], str]:
assert str_spec[1] == "("
assert str_spec[-1] == ")"
if str_spec[1] != "(" or str_spec[-1] != ")":
raise ValueError(
f"string must have '(' as a second character and ')' in last position. Got {str_spec}."
)
context_and_child_strings = str_spec[2:-1]

child_strings = []
@@ -92,7 +94,7 @@ def _tensordict_flatten(d: TensorDict) -> Tuple[List[Any], Context]:
return values, {
"keys": keys,
"batch_size": d.batch_size,
"names": d.names,
"names": d.names if d._has_names() else None,
"device": d.device,
"constructor": _constructor(type(d)),
"non_tensor_data": d.non_tensor_items(),
@@ -159,7 +161,7 @@ def _td_flatten_with_keys(
return [(MappingKey(k), v) for k, v in zip(keys, values)], {
"keys": keys,
"batch_size": d.batch_size,
"names": d.names,
"names": d.names if d._has_names() else None,
"device": d.device,
"constructor": _constructor(type(d)),
"non_tensor_data": d.non_tensor_items(),
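The _pytree.py hunks apply the same names-only-when-present rule to the flatten context and turn the _str_to_dict asserts into a ValueError with a message. A minimal round trip through PyTorch's pytree utilities, assuming the registration that tensordict performs on import:

import torch
from torch.utils import _pytree as pytree
from tensordict import TensorDict  # importing registers TensorDict with pytree

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
leaves, spec = pytree.tree_flatten(td)              # leaves are the stored tensors
td_plus_one = pytree.tree_unflatten([t + 1 for t in leaves], spec)
print(td_plus_one["b"])                             # tensor([2., 2., 2.])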
(The remaining changed files were not loaded and are not shown here.)
