@filename.setter
def filename(self, value):
    """Associate this tensor with the file located at ``value``.

    The path is normalized with ``Path(value).absolute()`` and stored as a
    string in ``self._filename``. A filename can only be set once: assigning
    the same (absolute) path again is accepted, anything else is rejected.

    Args:
        value: a path-like object, or ``None``. ``None`` is a no-op when no
            filename is associated yet.

    Raises:
        RuntimeError: if a filename is already associated and ``value`` is
            ``None`` or resolves to a different path.
    """
    # Read defensively, the same way ``_index_wrap`` reads this attribute.
    current = getattr(self, "_filename", None)
    if value is None:
        if current is None:
            return
        # Without this guard ``Path(None)`` below would raise an unrelated
        # TypeError instead of reporting the actual problem.
        raise RuntimeError(
            "the MemoryMappedTensor has already a filename associated."
        )
    value = str(Path(value).absolute())
    if current is not None and value != current:
        raise RuntimeError(
            "the MemoryMappedTensor has already a filename associated."
        )
    self._filename = value
) - tensor = torch.from_file( - str(filename), shared=True, dtype=dtype, size=shape.prod(-1).sum().int() - ) + if writable: + tensor = torch.from_file( + str(filename), + shared=True, + dtype=dtype, + size=shape.prod(-1).sum().int(), + ) + else: + with open(str(filename), "rb") as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + tensor = torch.frombuffer(mm, dtype=dtype) + # mm.close() tensor = torch._nested_view_from_buffer( tensor, shape, @@ -734,14 +757,21 @@ def from_filename(cls, filename, dtype, shape, index=None): ) else: shape = torch.Size(shape) - tensor = torch.from_file( - str(filename), shared=True, dtype=dtype, size=shape.numel() - ).view(shape) + # whether the file already existed + if writable: + tensor = torch.from_file( + str(filename), shared=True, dtype=dtype, size=shape.numel() + ) + else: + with open(str(filename), "rb") as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + tensor = torch.frombuffer(mm, dtype=dtype) + tensor = tensor.view(shape) if index is not None: tensor = tensor[index] out = cls(tensor) - out._filename = filename + out.filename = filename out._handler = None out.index = index out.parent_shape = shape @@ -787,7 +817,7 @@ def from_handler(cls, handler, dtype, shape, index=None): if index is not None: out = out[index] out = cls(out) - out._filename = None + out.filename = None out._handler = handler out.index = index out.parent_shape = shape @@ -880,7 +910,7 @@ def _index_wrap(self, tensor, item, check=False): return tensor tensor = MemoryMappedTensor(tensor) tensor._handler = getattr(self, "_handler", None) - tensor._filename = getattr(self, "_filename", None) + tensor.filename = getattr(self, "_filename", None) tensor.index = item tensor.parent_shape = getattr(self, "parent_shape", None) return tensor @@ -1038,3 +1068,12 @@ def _unbind(tensor, dim): @implements_for_memmap(torch.chunk) def _chunk(input, chunks, dim=0): return input.chunk(chunks, dim=dim) + + +def _is_writable(file_path): + file_path = 
str(file_path) + if os.path.exists(file_path): + st = os.stat(file_path) + return bool(st.st_mode & (stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH)) + # Assume that the file can be written in the directory + return True diff --git a/test/test_memmap.py b/test/test_memmap.py index 0452ae262..27f099508 100644 --- a/test/test_memmap.py +++ b/test/test_memmap.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. import argparse import gc +import os +import stat from contextlib import nullcontext from pathlib import Path @@ -12,7 +14,7 @@ from _utils_internal import get_available_devices from tensordict import TensorDict -from tensordict.memmap import MemoryMappedTensor +from tensordict.memmap import _is_writable, MemoryMappedTensor from torch import multiprocessing as mp TIMEOUT = 100 @@ -157,7 +159,7 @@ def test_zeros(self, shape, dtype, device, tmp_path, from_path, shape_arg): if dtype is not None: assert t.dtype is dtype if filename is not None: - assert t.filename == filename + assert t.filename == str(Path(filename).absolute()) assert (t == 0).all() @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"]) @@ -191,7 +193,7 @@ def test_ones(self, shape, dtype, device, tmp_path, from_path, shape_arg): if dtype is not None: assert t.dtype is dtype if filename is not None: - assert t.filename == filename + assert t.filename == str(Path(filename).absolute()) assert (t == 1).all() @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"]) @@ -225,7 +227,7 @@ def test_empty(self, shape, dtype, device, tmp_path, from_path, shape_arg): if dtype is not None: assert t.dtype is dtype if filename is not None: - assert t.filename == filename + assert t.filename == str(Path(filename).absolute()) @pytest.mark.parametrize("shape_arg", ["expand", "arg", "kwarg"]) def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg): @@ -258,7 +260,7 @@ def test_full(self, shape, dtype, device, tmp_path, from_path, shape_arg): if dtype is not 
class TestReadWrite:
    """Checks that memory-mapped files stripped of write permission can still be loaded."""

    def test_read_only(self, tmpdir):
        """Write a dense tensor, make its file read-only, then load it again."""
        target = Path(tmpdir) / "elt.mmap"
        writer = MemoryMappedTensor.from_filename(
            filename=target, shape=[2, 3], dtype=torch.float64
        )
        writer.copy_(torch.arange(6).view(2, 3))

        target = str(target.absolute())

        # The file is writable right after creation ...
        assert _is_writable(target)
        # ... and no longer is once only the read bit is left.
        os.chmod(target, stat.S_IREAD)
        assert not _is_writable(target)

        # Drop the writable handle before loading the file again.
        del writer

        reader = MemoryMappedTensor.from_filename(
            filename=target, shape=[2, 3], dtype=torch.float64
        )
        assert (reader.reshape(-1) == torch.arange(6)).all()

    @pytest.mark.skipif(not HAS_NESTED_TENSOR, reason="Nested tensor incomplete")
    def test_read_only_nested(self, tmpdir):
        """Same scenario as ``test_read_only`` with a nested (jagged) shape."""
        dims = [[2, 3], [4, 5]]
        target = Path(tmpdir) / "elt.mmap"
        data = MemoryMappedTensor.from_tensor(torch.arange(26), filename=target)
        nested = MemoryMappedTensor.from_storage(
            data.untyped_storage(),
            filename=target,
            shape=torch.tensor(dims),
            dtype=data.dtype,
        )

        target = str(target.absolute())
        assert _is_writable(target)

        # Leave only the read bit on the file and check it is detected.
        os.chmod(target, stat.S_IREAD)
        assert not _is_writable(target)

        # Load the now read-only file.
        loaded = MemoryMappedTensor.from_filename(
            filename=target, shape=torch.tensor(dims), dtype=data.dtype
        )
        assert (loaded[0].view(-1) == torch.arange(6)).all()
        assert (loaded[1].view(-1) == torch.arange(6, 26)).all()

        # All three tensors must report the same filename.
        assert loaded.filename == nested.filename
        assert loaded.filename == data.filename
        assert loaded.filename == data.untyped_storage().filename
        # The storage-level filenames are expected to differ, so the inner
        # assertion has to fail.
        with pytest.raises(AssertionError):
            assert loaded.untyped_storage().filename == data.untyped_storage().filename

        os.chmod(str(target), 0o444)
        # NOTE(review): presumably ``data`` keeps the writable mapping it was
        # created with, so this write is visible through ``loaded`` - confirm.
        data.fill_(0)
        os.chmod(str(target), 0o444)

        assert (loaded[0].view(-1) == 0).all()
        assert (loaded[1].view(-1) == 0).all()