Skip to content

Commit

Permalink
[Feature] param_count
Browse files Browse the repository at this point in the history
ghstack-source-id: 69b76ae5dfd1ee743b0592fc2608a2bbafc945d6
Pull Request resolved: #1046
  • Loading branch information
vmoens committed Oct 17, 2024
1 parent ee49fc7 commit 726c09b
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 1 deletion.
86 changes: 86 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3142,6 +3142,92 @@ def _set_device(self, device: torch.device) -> T:
value._set_device(device=device)
return self

@cache # noqa: B019
def param_count(self, *, count_duplicates: bool = True) -> int:
"""Counts the number of parameters (total number of indexable items), accounting for tensors only.
Keyword Args:
count_duplicates (bool): Whether to count duplicated tensor as independent or not.
If ``False``, only strictly identical tensors will be discarded (same views but different
ids from a common base tensor will be counted twice). Defaults to `True` (each tensor is assumed
to be a single copy).
"""
vals = self._values_list(True, True)
total = 0
if not count_duplicates:
vals = set(vals)
for v in vals:
total += v.numel()
return total

@cache # noqa: B019
def bytes(self, *, count_duplicates: bool = True) -> int:
"""Counts the number of bytes of the contained tensors.
Keyword Args:
count_duplicates (bool): Whether to count duplicated tensor as independent or not.
If ``False``, only strictly identical tensors will be discarded (same views but different
ids from a common base tensor will be counted twice). Defaults to `True` (each tensor is assumed
to be a single copy).
"""
set_of_tensors = set() if not count_duplicates else []

def add(tensor):
if count_duplicates:
set_of_tensors.append(tensor)
else:
set_of_tensors.add(tensor)

def count_bytes(tensor):
if tensor.is_nested:
if not tensor.layout == torch.jagged:
raise RuntimeError(
"NTs that are not jagged are not supported by the bytes method. Please use the jagged layout instead "
"or raise and issue on https://github.com/pytorch/tensordict/issues instead."
)
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, MemoryMappedTensor):
add(tensor)
return
if type(tensor) is not torch.Tensor:
try:
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
except AttributeError:
warnings.warn(
"The sub-tensor doesn't ot have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return

grad = getattr(tensor, "grad", None)
if grad is not None:
count_bytes(grad)
count_bytes(tensor.data)
else:
add(tensor)
return

vals = self._values_list(True, True)
for v in vals:
count_bytes(v)
total = 0
for tensor in set_of_tensors:
total += tensor.numel() * tensor.dtype.itemsize
return total

def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T:
"""Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors.
Expand Down
48 changes: 47 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,17 @@
mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"


@pytest.fixture
def device_fixture():
device = torch.get_default_device()
if torch.cuda.is_available():
torch.set_default_device(torch.device("cuda:0"))
elif torch.backends.mps.is_available():
torch.set_default_device(torch.device("mps:0"))
yield
torch.set_default_device(device)


def _compare_tensors_identity(td0, td1):
if isinstance(td0, LazyStackedTensorDict):
if not isinstance(td1, LazyStackedTensorDict):
Expand Down Expand Up @@ -242,7 +253,32 @@ def test_batchsize_reset(self):
td_u.batch_size = [1]
td_u.to_tensordict().batch_size = [1]

def test_depth(ggself):
@pytest.mark.parametrize("count_duplicates", [False, True])
def test_bytes(self, count_duplicates, device_fixture):
tensor = torch.zeros(3)
tensor_with_grad = torch.ones(3, requires_grad=True)
(tensor_with_grad + 1).sum().backward()
v = torch.ones(3) * 2 # 12 bytes
offsets = torch.tensor([0, 1, 3]) # 24 bytes
lengths = torch.tensor([1, 2]) # 16 bytes
njt = torch.nested.nested_tensor_from_jagged(
v, offsets, lengths=lengths
) # 52 bytes
tricky = torch.nested.nested_tensor_from_jagged(
tensor, offsets, lengths=lengths
) # 52 bytes or 0
td = TensorDict(
tensor=tensor, # 3 * 4 = 12 bytes
tensor_with_grad=tensor_with_grad, # 3 * 4 * 2 = 24 bytes
njt=njt, # 32
tricky=tricky, # 32 or 0
)
if count_duplicates:
assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 52
else:
assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 0

def test_depth(self):
td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_()
assert td.depth == 3
with td.unlock_():
Expand Down Expand Up @@ -1903,6 +1939,16 @@ def test_pad_sequence_pad_dim1(self, make_mask):
else:
assert "masks" not in padded_td.keys()

@pytest.mark.parametrize("count_duplicates", [False, True])
def test_param_count(self, count_duplicates):
td = TensorDict(a=torch.randn(3), b=torch.randn(6))
td["c"] = td["a"]
assert len(td._values_list(True, True)) == 3
if count_duplicates:
assert td.param_count(count_duplicates=count_duplicates) == 12
else:
assert td.param_count(count_duplicates=count_duplicates) == 9

@pytest.mark.parametrize("device", get_available_devices())
def test_permute(self, device):
torch.manual_seed(1)
Expand Down

0 comments on commit 726c09b

Please sign in to comment.