From 6d654adf879669eb4887c0c091832f4de8c38ce7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 17 Oct 2024 09:38:56 +0100 Subject: [PATCH] [Feature] param_count ghstack-source-id: af38095667859b9bea8d2e1bbd9d482b88db8c62 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1046 --- tensordict/base.py | 49 +++++++++++++++++++++++++++++++++++++++++ test/test_tensordict.py | 10 +++++++++ 2 files changed, 59 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 48b306520..6e994b2c2 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3142,6 +3142,55 @@ def _set_device(self, device: torch.device) -> T: value._set_device(device=device) return self + @cache # noqa: B019 + def param_count(self, *, count_duplicates: bool = True) -> int: + """Counts the number of parameters (total number of indexable items), accounting for tensors only. + + Keyword Args: + count_duplicates (bool): Whether to count duplicated tensor as independent or not. + If ``False``, only strictly identical tensors will be discarded (same views but different + ids from a common base tensor will be counted twice). Defaults to `True` (each tensor is assumed + to be a single copy). + + """ + vals = self._values_list(True, True) + total = 0 + if not count_duplicates: + vals = set(vals) + for v in vals: + total += v.numel() + return total + + @cache # noqa: B019 + def bytes(self, *, count_duplicates: bool = True) -> int: + """Counts the number of bytes of the contained tensors. + + Keyword Args: + count_duplicates (bool): Whether to count duplicated tensor as independent or not. + If ``False``, only strictly identical tensors will be discarded (same views but different + ids from a common base tensor will be counted twice). Defaults to `True` (each tensor is assumed + to be a single copy). + + """ + vals = self._values_list(True, True) + total = 0 + if not count_duplicates: + vals = set(vals) + for v in vals: + if v.is_nested: + if not v.layout == torch.jagged: + raise RuntimeError( + "NTs that are not jagged are not supported by the bytes method. Please use the jagged layout instead " + "or raise and issue on https://github.com/pytorch/tensordict/issues instead." + ) + total += v._values.numel() * v._values.dtype.itemsize + total += v._offsets.numel() * v._offsets.dtype.itemsize + if v._lengths is not None: + total += v._lengths.numel() * v._lengths.dtype.itemsize + else: + total += v.numel() * v.dtype.itemsize + return total + def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T: """Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors. diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e67b48416..4b11b405c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1903,6 +1903,16 @@ def test_pad_sequence_pad_dim1(self, make_mask): else: assert "masks" not in padded_td.keys() + @pytest.mark.parametrize("count_duplicates", [False, True]) + def test_param_count(self, count_duplicates): + td = TensorDict(a=torch.randn(3), b=torch.randn(6)) + td["c"] = td["a"] + assert len(td._values_list(True, True)) == 3 + if count_duplicates: + assert td.param_count(count_duplicates=count_duplicates) == 12 + else: + assert td.param_count(count_duplicates=count_duplicates) == 9 + @pytest.mark.parametrize("device", get_available_devices()) def test_permute(self, device): torch.manual_seed(1)