From b5597fd196cebe151c71738babed2fccb804f98c Mon Sep 17 00:00:00 2001
From: vmoens
Date: Tue, 11 Jul 2023 12:35:05 -0400
Subject: [PATCH 1/2] init

---
 tensordict/nn/common.py  | 6 ++----
 tensordict/tensordict.py | 2 +-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py
index 07ad01bc5..38eb07c9c 100644
--- a/tensordict/nn/common.py
+++ b/tensordict/nn/common.py
@@ -13,7 +13,7 @@
 import torch
 
 from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
 
-from tensordict._tensordict import unravel_key_list
+from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
 from tensordict.nn.functional_modules import make_functional
 
@@ -248,9 +248,7 @@ def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
             if isinstance(dest, str):
                 dest = getattr(_self, dest)
             for key in source:
-                expected_key = (
-                    self.separator.join(key) if isinstance(key, tuple) else key
-                )
+                expected_key = self.separator.join(_unravel_key_to_tuple(key))
                 if len(args):
                     tensordict_values[key] = args[0]
                     args = args[1:]
diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index f389c7109..7f43283b8 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -3872,7 +3872,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
                 "Cannot pass both batch_size and batch_dims to `from_dict`."
             )
 
-        batch_size_set = [] if batch_size is None else batch_size
+        batch_size_set = torch.Size([]) if batch_size is None else batch_size
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
                 # we don't know if another tensor of smaller size is coming

From a151d41ef2dcd9e991b05f40f749a847a44c2cda Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 4 Dec 2023 11:46:40 +0000
Subject: [PATCH 2/2] amend

---
 tensordict/_td.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index df6804077..bd2d3142a 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -1062,7 +1062,7 @@ def from_dict(cls, input_dict, batch_size=None, device=None, batch_dims=None):
                 "Cannot pass both batch_size and batch_dims to `from_dict`."
            )
 
-        batch_size_set = [] if batch_size is None else batch_size
+        batch_size_set = torch.Size(()) if batch_size is None else batch_size
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
                 # we don't know if another tensor of smaller size is coming
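
Illustration (not part of the patch series): a minimal sketch of the intent of the
two changes, assuming the tensordict API at this revision. The separator value and
the example keys below are hypothetical, chosen only for demonstration.

    # Illustration only -- not part of the patches above.
    import torch
    from tensordict import TensorDict
    from tensordict._tensordict import _unravel_key_to_tuple

    # nn/common.py: expected_key is now built by unravelling the key first,
    # so plain string keys and nested tuple keys are handled uniformly.
    separator = "_"  # hypothetical separator value
    for key in ("obs", ("agent", "obs")):
        print(separator.join(_unravel_key_to_tuple(key)))  # "obs", then "agent_obs"

    # from_dict: the default batch-size placeholder is an empty torch.Size
    # rather than a plain Python list, so the constructed TensorDict always
    # carries a torch.Size-typed batch size.
    td = TensorDict.from_dict({"a": torch.randn(3, 4), "nested": {"b": torch.zeros(3, 4)}})
    assert isinstance(td.batch_size, torch.Size)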