Skip to content

Commit

Permalink
[BugFix] Read-only compatibility in MemoryMappedTensor (#780)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored May 16, 2024
1 parent 1df73e7 commit a088b87
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 20 deletions.
63 changes: 51 additions & 12 deletions tensordict/memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import mmap
import os
import stat

import sys
import tempfile
Expand Down Expand Up @@ -266,7 +267,7 @@ def from_tensor(
result = result.view(shape)
result = cls(result)
result._handler = handler
result._filename = filename
result.filename = filename
result.index = None
result.parent_shape = shape
if copy_data:
Expand Down Expand Up @@ -321,7 +322,7 @@ def from_storage(

tensor = cls(tensor)
if filename is not None:
tensor._filename = filename
tensor.filename = filename
elif handler is not None:
tensor._handler = handler
if index is not None:
Expand All @@ -339,6 +340,17 @@ def filename(self):
raise RuntimeError("The MemoryMappedTensor has no file associated.")
return filename

@filename.setter
def filename(self, value):
    """Record the file backing this tensor as an absolute path string.

    Args:
        value: a path-like object, or ``None``. ``None`` is accepted only
            when no filename is recorded yet, in which case the call is a
            no-op.

    Raises:
        RuntimeError: if a filename is already recorded and ``value`` is
            ``None`` or resolves to a different absolute path.
    """
    # Read through getattr so an instance on which ``_filename`` was never
    # assigned is treated the same as one holding ``None``.
    current = getattr(self, "_filename", None)
    if value is None:
        if current is None:
            # Nothing recorded and nothing to record.
            return
        # Without this guard ``Path(None)`` below fails with an opaque
        # TypeError; report the real problem instead: the filename of a
        # tensor cannot be changed (or cleared) once it has been set.
        raise RuntimeError(
            "the MemoryMappedTensor has already a filename associated."
        )
    # Normalise to an absolute string so that equal paths compare equal
    # regardless of how the caller spelled them.
    value = str(Path(value).absolute())
    if current is not None and value != current:
        raise RuntimeError(
            "the MemoryMappedTensor has already a filename associated."
        )
    self._filename = value

@classmethod
def empty_like(cls, input, *, filename=None):
# noqa: D417
Expand Down Expand Up @@ -596,7 +608,7 @@ def empty(cls, *args, **kwargs):
*offsets_strides,
)
result = cls(result)
result._filename = filename
result.filename = filename
return result
return result

Expand Down Expand Up @@ -712,6 +724,8 @@ def from_filename(cls, filename, dtype, shape, index=None):
tensor.
"""
writable = _is_writable(filename)

if isinstance(shape, torch.Tensor):
func_offset_stride = getattr(
torch, "_nested_compute_contiguous_strides_offsets", None
Expand All @@ -724,24 +738,40 @@ def from_filename(cls, filename, dtype, shape, index=None):
"nested tensors. Please upgrade to a more recent "
"version."
)
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.prod(-1).sum().int()
)
if writable:
tensor = torch.from_file(
str(filename),
shared=True,
dtype=dtype,
size=shape.prod(-1).sum().int(),
)
else:
with open(str(filename), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
tensor = torch.frombuffer(mm, dtype=dtype)
# mm.close()
tensor = torch._nested_view_from_buffer(
tensor,
shape,
*offsets_strides,
)
else:
shape = torch.Size(shape)
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.numel()
).view(shape)
# whether the file already existed
if writable:
tensor = torch.from_file(
str(filename), shared=True, dtype=dtype, size=shape.numel()
)
else:
with open(str(filename), "rb") as f:
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
tensor = torch.frombuffer(mm, dtype=dtype)
tensor = tensor.view(shape)

if index is not None:
tensor = tensor[index]
out = cls(tensor)
out._filename = filename
out.filename = filename
out._handler = None
out.index = index
out.parent_shape = shape
Expand Down Expand Up @@ -787,7 +817,7 @@ def from_handler(cls, handler, dtype, shape, index=None):
if index is not None:
out = out[index]
out = cls(out)
out._filename = None
out.filename = None
out._handler = handler
out.index = index
out.parent_shape = shape
Expand Down Expand Up @@ -880,7 +910,7 @@ def _index_wrap(self, tensor, item, check=False):
return tensor
tensor = MemoryMappedTensor(tensor)
tensor._handler = getattr(self, "_handler", None)
tensor._filename = getattr(self, "_filename", None)
tensor.filename = getattr(self, "_filename", None)
tensor.index = item
tensor.parent_shape = getattr(self, "parent_shape", None)
return tensor
Expand Down Expand Up @@ -1038,3 +1068,12 @@ def _unbind(tensor, dim):
@implements_for_memmap(torch.chunk)
def _chunk(input, chunks, dim=0):
return input.chunk(chunks, dim=dim)


def _is_writable(file_path):
    """Tell whether ``file_path`` carries any write permission bit.

    Only the mode bits reported by ``os.stat`` are inspected (user, group
    or other write); this does not check whether the current process is
    actually allowed to write to the file.

    Args:
        file_path: path-like object or string pointing to the file.

    Returns:
        bool: ``True`` if any write bit is set, or if the file cannot be
        stat-ed (e.g. it does not exist yet); ``False`` otherwise.
    """
    try:
        # A single stat call replaces the exists()-then-stat() pair, so the
        # file cannot vanish between the check and the read of its mode.
        st = os.stat(str(file_path))
    except (OSError, ValueError):
        # These are the errors ``os.path.exists`` turns into ``False``.
        # Assume that the file can be written in the directory
        return True
    return bool(st.st_mode & (stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH))
92 changes: 84 additions & 8 deletions test/test_memmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
import argparse
import gc
import os
import stat
from contextlib import nullcontext
from pathlib import Path

Expand All @@ -12,7 +14,7 @@
from _utils_internal import get_available_devices
from tensordict import TensorDict

from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap import _is_writable, MemoryMappedTensor
from torch import multiprocessing as mp

TIMEOUT = 100
Expand Down Expand Up @@ -157,7 +159,7 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 0).all()

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
Expand Down Expand Up @@ -191,7 +193,7 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 1).all()

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
Expand Down Expand Up @@ -225,7 +227,7 @@ def test_empty(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())

@pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"])
def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
Expand Down Expand Up @@ -258,7 +260,7 @@ def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 2).all()

def test_zeros_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -272,7 +274,7 @@ def test_zeros_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 0).all()

def test_ones_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -286,7 +288,7 @@ def test_ones_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 1).all()

def test_full_like(self, shape, dtype, device, tmp_path, from_path):
Expand All @@ -300,7 +302,7 @@ def test_full_like(self, shape, dtype, device, tmp_path, from_path):
if dtype is not None:
assert t.dtype is dtype
if filename is not None:
assert t.filename == filename
assert t.filename == str(Path(filename).absolute())
assert (t == 2).all()

def test_from_filename(self, shape, dtype, device, tmp_path, from_path):
Expand Down Expand Up @@ -715,6 +717,80 @@ def test_save_td_with_nested(self, tmpdir):
assert (td[i, j] == tdsave[i, j]).all()


class TestReadWrite:
    """Tests loading a MemoryMappedTensor from a file made read-only."""

    def test_read_only(self, tmpdir):
        """Load a regular (non-nested) tensor from a read-only file.

        Data is written through a first tensor, the file permissions are
        reduced to ``stat.S_IREAD``, and the file is loaded again with
        ``from_filename``; the reloaded tensor must hold the same values.
        """
        tmpdir = Path(tmpdir)
        file_path = tmpdir / "elt.mmap"
        # Build a (2, 3) float64 tensor from the path and copy 0..5 into it.
        mmap = MemoryMappedTensor.from_filename(
            filename=file_path, shape=[2, 3], dtype=torch.float64
        )
        mmap.copy_(torch.arange(6).view(2, 3))

        file_path = str(file_path.absolute())

        # At this point the file is expected to be reported as writable.
        assert _is_writable(file_path)
        # Keep only the read bit (stat.S_IWRITE and stat.S_IEXEC are left out)
        new_permissions = stat.S_IREAD  # | stat.S_IWRITE | stat.S_IEXEC

        # change permission
        os.chmod(file_path, new_permissions)

        # The helper must now report the file as not writable
        assert not _is_writable(file_path)

        # Drop the first tensor before loading the file again
        # (presumably to release its mapping of the file -- confirm).
        del mmap

        # load file
        mmap = MemoryMappedTensor.from_filename(
            filename=file_path, shape=[2, 3], dtype=torch.float64
        )
        # The values written before the permission change must be readable.
        assert (mmap.reshape(-1) == torch.arange(6)).all()

    @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
    def test_read_only_nested(self, tmpdir):
        """Load a nested tensor from a read-only file.

        The nested shape ``[[2, 3], [4, 5]]`` accounts for 2*3 + 4*5 = 26
        elements, matching the ``torch.arange(26)`` written to the file.
        """
        tmpdir = Path(tmpdir)
        file_path = tmpdir / "elt.mmap"
        # Flat tensor of 26 values associated with the file.
        data = MemoryMappedTensor.from_tensor(torch.arange(26), filename=file_path)
        # Nested view built on top of the storage of ``data``.
        mmap = MemoryMappedTensor.from_storage(
            data.untyped_storage(),
            filename=file_path,
            shape=torch.tensor([[2, 3], [4, 5]]),
            dtype=data.dtype,
        )

        file_path = str(file_path.absolute())
        assert _is_writable(file_path)

        # Keep only the read bit (stat.S_IWRITE and stat.S_IEXEC are left out)
        new_permissions = stat.S_IREAD  # | stat.S_IWRITE | stat.S_IEXEC

        # change permission
        os.chmod(file_path, new_permissions)

        # The helper must now report the file as not writable
        assert not _is_writable(file_path)

        # load file
        mmap1 = MemoryMappedTensor.from_filename(
            filename=file_path, shape=torch.tensor([[2, 3], [4, 5]]), dtype=data.dtype
        )
        # First nested component holds 0..5, second holds 6..25.
        assert (mmap1[0].view(-1) == torch.arange(6)).all()
        assert (mmap1[1].view(-1) == torch.arange(6, 26)).all()
        # test filename
        assert mmap1.filename == mmap.filename
        assert mmap1.filename == data.filename
        assert mmap1.filename == data.untyped_storage().filename
        # The inner assert is expected to fail: the storage of the tensor
        # loaded from the read-only file must not report the same filename
        # as the storage of ``data``.
        # NOTE(review): presumably because the read-only load does not go
        # through a file-backed storage -- confirm.
        with pytest.raises(AssertionError):
            assert mmap1.untyped_storage().filename == data.untyped_storage().filename

        # NOTE(review): both chmod calls set the same 0o444 mode around the
        # write; confirm whether the first one was meant to differ.
        os.chmod(str(file_path), 0o444)
        data.fill_(0)
        os.chmod(str(file_path), 0o444)

        # Zeros written through ``data`` must be visible through ``mmap1``.
        assert (mmap1[0].view(-1) == 0).all()
        assert (mmap1[1].view(-1) == 0).all()


if __name__ == "__main__":
    # Flags the (empty) parser does not recognise are handed over to pytest.
    _, passthrough = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *passthrough])

2 comments on commit a088b87

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a088b87 Previous: 1df73e7 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 74134.9917302418 iter/sec (stddev: 7.019309624229432e-7) 176258.48845571827 iter/sec (stddev: 3.8876193917935657e-7) 2.38
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 73835.57288766393 iter/sec (stddev: 9.982441798151835e-7) 173832.73401344978 iter/sec (stddev: 3.5954403248480965e-7) 2.35

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a088b87 Previous: 1df73e7 Ratio
benchmarks/common/common_ops_test.py::test_keys_nested 5869.3477199516765 iter/sec (stddev: 0.0000035339182887987186) 14571.179609164024 iter/sec (stddev: 0.0000019502673289121764) 2.48
benchmarks/common/common_ops_test.py::test_keys_nested_locked 5678.566966588077 iter/sec (stddev: 0.000008516000526452174) 13477.544239686018 iter/sec (stddev: 0.000023380622280381144) 2.37
benchmarks/common/common_ops_test.py::test_keys_nested_leaf 6835.796985141076 iter/sec (stddev: 0.00000318514424818833) 16956.88545830509 iter/sec (stddev: 0.0000017183832761358163) 2.48
benchmarks/common/common_ops_test.py::test_keys_stack_nested 6048.485125941703 iter/sec (stddev: 0.00000362803650388799) 14583.993213143836 iter/sec (stddev: 0.0000018010183280050648) 2.41
benchmarks/common/common_ops_test.py::test_keys_stack_nested_leaf 7102.472964498791 iter/sec (stddev: 0.0000027884572208407356) 16846.466688748893 iter/sec (stddev: 0.000001580719178442906) 2.37
benchmarks/common/common_ops_test.py::test_keys_stack_nested_locked 5866.635779383213 iter/sec (stddev: 0.000003969429791021374) 13673.682280453733 iter/sec (stddev: 0.0000020473479542051454) 2.33
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 149191.7009726113 iter/sec (stddev: 5.262498285926441e-7) 329107.80686765345 iter/sec (stddev: 3.015061042641561e-7) 2.21
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 149327.43837619558 iter/sec (stddev: 5.131131664988721e-7) 325499.0869180446 iter/sec (stddev: 3.153770200200213e-7) 2.18

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.