From 58c2b9123a81bf8e96304eaacc00921c57cb0e82 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 09:30:23 +0100 Subject: [PATCH] [Feature] param_count ghstack-source-id: af2b55a96550ef9f42dcf14fa5dbf4b62873f85c Pull Request resolved: https://github.com/pytorch/tensordict/pull/1046 --- tensordict/base.py | 19 +++++++++++++++++++ test/test_tensordict.py | 10 ++++++++++ 2 files changed, 29 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 48b306520..798f0073a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3142,6 +3142,25 @@ def _set_device(self, device: torch.device) -> T: value._set_device(device=device) return self + @cache + def param_count(self, *, count_duplicates: bool = True) -> int: + """Counts the number of parameters (total number of indexable items), accounting for tensors only. + + Args: + count_duplicates (bool): Whether to count duplicated tensors as independent or not. + If ``False``, only strictly identical tensors will be discarded (same views but different + ids from a common base tensor will be counted twice). Defaults to ``True`` (each tensor is assumed + to be a single copy). + + """ + vals = self._values_list(True, True) + total = 0 + if not count_duplicates: + vals = set(vals) + for v in vals: + total += v.numel() + return total + def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T: """Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors. 
diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e67b48416..4b11b405c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1903,6 +1903,16 @@ def test_pad_sequence_pad_dim1(self, make_mask): else: assert "masks" not in padded_td.keys() + @pytest.mark.parametrize("count_duplicates", [False, True]) + def test_param_count(self, count_duplicates): + td = TensorDict(a=torch.randn(3), b=torch.randn(6)) + td["c"] = td["a"] + assert len(td._values_list(True, True)) == 3 + if count_duplicates: + assert td.param_count(count_duplicates=count_duplicates) == 12 + else: + assert td.param_count(count_duplicates=count_duplicates) == 9 + @pytest.mark.parametrize("device", get_available_devices()) def test_permute(self, device): torch.manual_seed(1)