Skip to content

Commit

Permalink
[BugFix] Refactor map and map_iter (#869)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 11, 2024
1 parent fb4b629 commit 02ff686
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies:
- pytest-mock
- pytest-instafail
- pytest-rerunfailures
- pytest-timeout
- expecttest
- coverage
- h5py
Expand Down
2 changes: 1 addition & 1 deletion .github/unittest/linux/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
export MKL_THREADING_LAYER=GNU

coverage run -m pytest test/smoke_test.py -v --durations 20
coverage run -m pytest --instafail -v --durations 20
coverage run -m pytest --instafail -v --durations 20 --timeout 120
coverage run -m pytest ./benchmarks --instafail -v --durations 20
coverage xml -i
10 changes: 8 additions & 2 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,13 @@ def map(
initargs=(seed, queue, worker_threads),
maxtasksperchild=max_tasks_per_child,
) as pool:
return self.map(fn, dim=dim, chunksize=chunksize, pool=pool)
return self.map(
fn,
dim=dim,
chunksize=chunksize,
pool=pool,
index_with_generator=index_with_generator,
)
num_workers = pool._processes
dim_orig = dim
if dim < 0:
Expand Down Expand Up @@ -863,7 +869,7 @@ def newfn(item_and_out):
for item in imap:
if item is not None:
if out is not None:
if chunksize:
if chunksize != 0:
end = start + item.shape[dim]
chunk = slice(start, end)
out[chunk].update_(item)
Expand Down
50 changes: 32 additions & 18 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8827,6 +8827,7 @@ def test_map_seed(self):
pytest.skip(
reason="Using max_tasks_per_child is unstable and can cause multiple processes to start over even though all jobs are completed",
)
gc.collect()

if mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
Expand Down Expand Up @@ -8870,6 +8871,7 @@ def test_map_seed(self):
)

def test_map_seed_single(self):
gc.collect()
# A cheap version of the previous test
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
Expand Down Expand Up @@ -8916,35 +8918,42 @@ def test_map_seed_single(self):
@pytest.mark.parametrize("h5", [False, True])
@pytest.mark.parametrize("has_out", [False, True])
def test_index_with_generator(self, chunksize, num_chunks, h5, has_out, tmpdir):
gc.collect()
input = TensorDict({"a": torch.arange(10), "b": torch.arange(10)}, [10])
if h5:
tmpdir = pathlib.Path(tmpdir)
input = input.to_h5(tmpdir / "file.h5")
input_h5 = input.to_h5(tmpdir / "file.h5")
assert input.shape == input_h5.shape
input = input_h5
if has_out:
output_generator = torch.zeros_like(self.selectfn(input.to_tensordict()))
output_split = torch.zeros_like(self.selectfn(input.to_tensordict()))
else:
output_generator = None
output_split = None
output_generator = input.map(
self.selectfn,
num_workers=2,
index_with_generator=True,
num_chunks=num_chunks,
chunksize=chunksize,
out=output_generator,
)
output_split = input.map(
self.selectfn,
num_workers=2,
index_with_generator=True,
num_chunks=num_chunks,
chunksize=chunksize,
out=output_split,
)
with mp.get_context("fork").Pool(2) as pool:
output_generator = input.map(
self.selectfn,
num_workers=2,
index_with_generator=True,
num_chunks=num_chunks,
chunksize=chunksize,
out=output_generator,
pool=pool,
)
output_split = input.map(
self.selectfn,
num_workers=2,
index_with_generator=True,
num_chunks=num_chunks,
chunksize=chunksize,
out=output_split,
pool=pool,
)
assert (output_generator == output_split).all()

def test_map_unbind(self):
gc.collect()
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
td0 = TensorDict({"0": 0}, [])
Expand All @@ -8961,6 +8970,7 @@ def _assert_is_memmap(data):

@pytest.mark.parametrize("chunksize", [0, 5])
def test_map_inplace(self, chunksize):
gc.collect()
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method("spawn")
# Tests that we can return None values
Expand All @@ -8978,6 +8988,7 @@ def selectfn(input):
"start_method", [None, "spawn" if torch.cuda.is_available() else "fork"]
)
def test_map_with_out(self, mmap, chunksize, tmpdir, start_method):
gc.collect()
tmpdir = Path(tmpdir)
input = TensorDict({"a": torch.arange(10), "b": torch.arange(10)}, [10])
if mmap:
Expand All @@ -9001,7 +9012,8 @@ def nontensor_check(cls, td):
)
return td

def test_non_tensor(self):
def test_map_non_tensor(self):
gc.collect()
# with NonTensorStack
td = TensorDict(
{"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10]
Expand All @@ -9026,6 +9038,7 @@ def _return_identical(td):
"chunksize,num_chunks", [[0, None], [11, None], [None, 11]]
)
def test_map_iter(self, chunksize, num_chunks, shuffle):
gc.collect()
torch.manual_seed(0)
td = TensorDict(
{
Expand Down Expand Up @@ -9075,6 +9088,7 @@ def test_map_iter(self, chunksize, num_chunks, shuffle):
"chunksize,num_chunks", [[0, None], [11, None], [None, 11]]
)
def test_map_iter_interrupt_early(self, chunksize, num_chunks, shuffle):
gc.collect()
torch.manual_seed(0)
td = TensorDict(
{
Expand Down

1 comment on commit 02ff686

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
The benchmark result of this commit is worse than the previous benchmark result, exceeding the alert threshold of 2.

Benchmark suite Current: 02ff686 Previous: fb4b629 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 87980.42647302672 iter/sec (stddev: 0.000001344804855390614) 242300.6448413797 iter/sec (stddev: 3.6013501137612525e-7) 2.75
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 87548.33064392106 iter/sec (stddev: 7.625204277943041e-7) 237417.42519426672 iter/sec (stddev: 5.168093171660373e-7) 2.71
benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle 1.2222551311971874 iter/sec (stddev: 0.3192590881575344) 2.5027626194566333 iter/sec (stddev: 0.07068272770168096) 2.05

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.