Skip to content

Commit

Permalink
[Feature] param_count
Browse files Browse the repository at this point in the history
ghstack-source-id: af2b55a96550ef9f42dcf14fa5dbf4b62873f85c
Pull Request resolved: #1046
  • Loading branch information
vmoens committed Oct 17, 2024
1 parent ee49fc7 commit 58c2b91
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,25 @@ def _set_device(self, device: torch.device) -> T:
value._set_device(device=device)
return self

@cache
def param_count(self, *, count_duplicates: bool = True) -> int:
"""Counts the number of parameters (total number of indexable items), accounting for tensors only.
Args:
count_duplicates (bool): Whether to count duplicated tensor as independent or not.
If ``False``, only strictly identical tensors will be discarded (same views but different
ids from a common base tensor will be counted twice). Defaults to `True` (each tensor is assumed
to be a single copy).
"""
vals = self._values_list(True, True)
total = 0
if not count_duplicates:
vals = set(vals)
for v in vals:
total += v.numel()
return total

def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T:
"""Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors.
Expand Down
10 changes: 10 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,16 @@ def test_pad_sequence_pad_dim1(self, make_mask):
else:
assert "masks" not in padded_td.keys()

@pytest.mark.parametrize("count_duplicates", [False, True])
def test_param_count(self, count_duplicates):
td = TensorDict(a=torch.randn(3), b=torch.randn(6))
td["c"] = td["a"]
assert len(td._values_list(True, True)) == 3
if count_duplicates:
assert td.param_count(count_duplicates=count_duplicates) == 12
else:
assert td.param_count(count_duplicates=count_duplicates) == 9

@pytest.mark.parametrize("device", get_available_devices())
def test_permute(self, device):
torch.manual_seed(1)
Expand Down

0 comments on commit 58c2b91

Please sign in to comment.