From df7a4642391664fa86d7240588a56e76f063fd53 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 25 Oct 2024 13:19:52 +0100 Subject: [PATCH] Indexing 4D FrameBatch now returns FrameBatch --- src/torchcodec/_frame.py | 32 ++++++++++-------------------- test/test_frame_dataclasses.py | 36 ++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 8137c457..fd792ceb 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -68,10 +68,9 @@ class FrameBatch(Iterable): def __post_init__(self): # This is called after __init__() when a FrameBatch is created. We can # run input validation checks here. - if self.data.ndim < 4: + if self.data.ndim < 3: raise ValueError( - f"data must be at least 4-dimensional. Got {self.data.shape = } " - "For 3-dimensional data, create a Frame object instead." + f"data must be at least 3-dimensional, got {self.data.shape = }" ) leading_dims = self.data.shape[:-3] @@ -83,33 +82,22 @@ def __post_init__(self): f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }." ) - def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]: - cls = Frame if self.data.ndim == 4 else FrameBatch + def __iter__(self) -> Iterator["FrameBatch"]: for data, pts_seconds, duration_seconds in zip( self.data, self.pts_seconds, self.duration_seconds ): - yield cls( + yield FrameBatch( data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds, ) - def __getitem__(self, key) -> Union["FrameBatch", Frame]: - data = self.data[key] - pts_seconds = self.pts_seconds[key] - duration_seconds = self.duration_seconds[key] - if self.data.ndim == 4: - return Frame( - data=data, - pts_seconds=float(pts_seconds.item()), - duration_seconds=float(duration_seconds.item()), - ) - else: - return FrameBatch( - data=data, - pts_seconds=pts_seconds, - duration_seconds=duration_seconds, - ) + def __getitem__(self, key) -> "FrameBatch": + return FrameBatch( + data=self.data[key], + pts_seconds=self.pts_seconds[key], + duration_seconds=self.duration_seconds[key], + ) def __len__(self): return len(self.data) diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index 9b79b882..a9840f22 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -23,7 +23,18 @@ def test_frame_error(): def test_framebatch_error(): - with pytest.raises(ValueError, match="data must be at least 4-dimensional"): + with pytest.raises(ValueError, match="data must be at least 3-dimensional"): + FrameBatch( + data=torch.rand(2, 3), + pts_seconds=torch.rand(1), + duration_seconds=torch.rand(1), + ) + + # Note: this is expected to fail because pts_seconds and duration_seconds + # are expected to have a shape of size([]) instead of size([1]). + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): FrameBatch( data=torch.rand(1, 2, 3), pts_seconds=torch.rand(1), @@ -82,10 +93,14 @@ def test_framebatch_iteration(): assert sub_fb.pts_seconds.shape == (N,) assert sub_fb.duration_seconds.shape == (N,) for frame in sub_fb: - assert isinstance(frame, Frame) + assert isinstance(frame, FrameBatch) assert frame.data.shape == (C, H, W) - assert isinstance(frame.pts_seconds, float) - assert isinstance(frame.duration_seconds, float) + # pts_seconds and duration_seconds are 0-dim tensors but they still + # contain a value + assert frame.pts_seconds.shape == tuple() + assert frame.duration_seconds.shape == tuple() + frame.pts_seconds.item() + frame.duration_seconds.item() # Check unpacking behavior first_sub_fb, *_ = fb @@ -107,10 +122,15 @@ def test_framebatch_indexing(): assert fb[i].pts_seconds.shape == (N,) assert fb[i].duration_seconds.shape == (N,) for j in range(len(fb[i])): - assert isinstance(fb[i][j], Frame) - assert fb[i][j].data.shape == (C, H, W) - assert isinstance(fb[i][j].pts_seconds, float) - assert isinstance(fb[i][j].duration_seconds, float) + frame = fb[i][j] + assert isinstance(frame, FrameBatch) + assert frame.data.shape == (C, H, W) + # pts_seconds and duration_seconds are 0-dim tensors but they still + # contain a value + assert frame.pts_seconds.shape == tuple() + assert frame.duration_seconds.shape == tuple() + frame.pts_seconds.item() + frame.duration_seconds.item() fb_fancy = fb[torch.arange(3)] assert isinstance(fb_fancy, FrameBatch)