[Feature] param_count #1046

Merged (4 commits) on Oct 17, 2024
86 changes: 86 additions & 0 deletions tensordict/base.py
@@ -3142,6 +3142,92 @@ def _set_device(self, device: torch.device) -> T:
value._set_device(device=device)
return self

@cache # noqa: B019
def param_count(self, *, count_duplicates: bool = True) -> int:
"""Counts the number of parameters (total number of indexable items), accounting for tensors only.

Keyword Args:
count_duplicates (bool): Whether duplicated tensors should be counted independently.
If ``False``, only strictly identical tensors are discarded (distinct views sharing a
common base tensor have different ids and will be counted twice). Defaults to ``True``
(each tensor is assumed to be a single copy).

"""
vals = self._values_list(True, True)
total = 0
if not count_duplicates:
vals = set(vals)
for v in vals:
total += v.numel()
return total
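# A minimal illustrative sketch of the expected behaviour (not part of this
# diff), mirroring the test added in this PR: two float leaves, with "c"
# aliasing "a".
#   td = TensorDict(a=torch.randn(3), b=torch.randn(6)); td["c"] = td["a"]
#   td.param_count()                         # 12 elements (duplicates counted)
#   td.param_count(count_duplicates=False)   # 9 elements (the alias is discarded)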

@cache # noqa: B019
def bytes(self, *, count_duplicates: bool = True) -> int:
"""Counts the number of bytes of the contained tensors.

Keyword Args:
count_duplicates (bool): Whether duplicated tensors should be counted independently.
If ``False``, only strictly identical tensors are discarded (distinct views sharing a
common base tensor have different ids and will be counted twice). Defaults to ``True``
(each tensor is assumed to be a single copy).

"""
set_of_tensors = set() if not count_duplicates else []

def add(tensor):
if count_duplicates:
set_of_tensors.append(tensor)
else:
set_of_tensors.add(tensor)

def count_bytes(tensor):
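# Recursively collect the tensors backing the input: jagged nested tensors and
# tensor subclasses are flattened via __tensor_flatten__ and their components
# counted; gradients are counted alongside the tensor's data.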
if tensor.is_nested:
if not tensor.layout == torch.jagged:
raise RuntimeError(
"NTs that are not jagged are not supported by the bytes method. Please use the jagged layout instead "
"or raise and issue on https://github.com/pytorch/tensordict/issues instead."
)
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
if isinstance(tensor, torch.Tensor):
if isinstance(tensor, MemoryMappedTensor):
add(tensor)
return
if type(tensor) is not torch.Tensor:
try:
attrs, ctx = tensor.__tensor_flatten__()
for attr in attrs:
t = getattr(tensor, attr)
count_bytes(t)
return
except AttributeError:
warnings.warn(
"The sub-tensor doesn't ot have a __tensor_flatten__ attribute, making it "
"impossible to count the bytes it contains. Falling back on regular count.",
category=UserWarning,
)
count_bytes(torch.as_tensor(tensor))
return

grad = getattr(tensor, "grad", None)
if grad is not None:
count_bytes(grad)
count_bytes(tensor.data)
else:
add(tensor)
return

vals = self._values_list(True, True)
for v in vals:
count_bytes(v)
total = 0
for tensor in set_of_tensors:
total += tensor.numel() * tensor.dtype.itemsize
return total
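# A minimal illustrative sketch of the expected byte counts (not part of this
# diff), mirroring the test added in this PR: for a jagged nested tensor, the
# values, offsets and lengths buffers are all counted.
#   njt = torch.nested.nested_tensor_from_jagged(
#       torch.ones(3) * 2,             # 3 * 4 = 12 bytes (float32 values)
#       torch.tensor([0, 1, 3]),       # 3 * 8 = 24 bytes (int64 offsets)
#       lengths=torch.tensor([1, 2]),  # 2 * 8 = 16 bytes (int64 lengths)
#   )
#   TensorDict(njt=njt).bytes()        # 52 bytes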

def pin_memory(self, num_threads: int | None = None, inplace: bool = False) -> T:
"""Calls :meth:`~torch.Tensor.pin_memory` on the stored tensors.

48 changes: 47 additions & 1 deletion test/test_tensordict.py
@@ -132,6 +132,17 @@
mp_ctx = "fork" if (not torch.cuda.is_available() and not _IS_WINDOWS) else "spawn"


@pytest.fixture
def device_fixture():
device = torch.get_default_device()
if torch.cuda.is_available():
torch.set_default_device(torch.device("cuda:0"))
elif torch.backends.mps.is_available():
torch.set_default_device(torch.device("mps:0"))
yield
torch.set_default_device(device)


def _compare_tensors_identity(td0, td1):
if isinstance(td0, LazyStackedTensorDict):
if not isinstance(td1, LazyStackedTensorDict):
@@ -242,7 +253,32 @@ def test_batchsize_reset(self):
td_u.batch_size = [1]
td_u.to_tensordict().batch_size = [1]

def test_depth(ggself):
@pytest.mark.parametrize("count_duplicates", [False, True])
def test_bytes(self, count_duplicates, device_fixture):
tensor = torch.zeros(3)
tensor_with_grad = torch.ones(3, requires_grad=True)
(tensor_with_grad + 1).sum().backward()
v = torch.ones(3) * 2 # 12 bytes
offsets = torch.tensor([0, 1, 3]) # 24 bytes
lengths = torch.tensor([1, 2]) # 16 bytes
njt = torch.nested.nested_tensor_from_jagged(
v, offsets, lengths=lengths
) # 52 bytes
tricky = torch.nested.nested_tensor_from_jagged(
tensor, offsets, lengths=lengths
) # 52 bytes or 0
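# With count_duplicates=False, tricky contributes nothing: its values, offsets
# and lengths all alias tensors that are already counted elsewhere in this test.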
td = TensorDict(
tensor=tensor, # 3 * 4 = 12 bytes
tensor_with_grad=tensor_with_grad, # 3 * 4 * 2 = 24 bytes
njt=njt, # 52
tricky=tricky, # 52 or 0
)
if count_duplicates:
assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 52
else:
assert td.bytes(count_duplicates=count_duplicates) == 12 + 24 + 52 + 0

def test_depth(self):
td = TensorDict({"a": {"b": {"c": {"d": 0}, "e": 0}, "f": 0}, "g": 0}).lock_()
assert td.depth == 3
with td.unlock_():
@@ -1903,6 +1939,16 @@ def test_pad_sequence_pad_dim1(self, make_mask):
else:
assert "masks" not in padded_td.keys()

@pytest.mark.parametrize("count_duplicates", [False, True])
def test_param_count(self, count_duplicates):
td = TensorDict(a=torch.randn(3), b=torch.randn(6))
td["c"] = td["a"]
assert len(td._values_list(True, True)) == 3
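# a (3) + b (6) + c (3) = 12 elements in total; with count_duplicates=False
# the alias "c" is discarded, leaving 3 + 6 = 9.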
if count_duplicates:
assert td.param_count(count_duplicates=count_duplicates) == 12
else:
assert td.param_count(count_duplicates=count_duplicates) == 9

@pytest.mark.parametrize("device", get_available_devices())
def test_permute(self, device):
torch.manual_seed(1)