Skip to content

Commit

Permalink
Indexing 4D FrameBatch now returns FrameBatch
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 25, 2024
1 parent 9d7b240 commit df7a464
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
32 changes: 10 additions & 22 deletions src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ class FrameBatch(Iterable):
def __post_init__(self):
# This is called after __init__() when a FrameBatch is created. We can
# run input validation checks here.
if self.data.ndim < 4:
if self.data.ndim < 3:
raise ValueError(
f"data must be at least 4-dimensional. Got {self.data.shape = } "
"For 3-dimensional data, create a Frame object instead."
f"data must be at least 3-dimensional, got {self.data.shape = }"
)

leading_dims = self.data.shape[:-3]
Expand All @@ -83,33 +82,22 @@ def __post_init__(self):
f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
)

def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
cls = Frame if self.data.ndim == 4 else FrameBatch
def __iter__(self) -> Iterator["FrameBatch"]:
for data, pts_seconds, duration_seconds in zip(
self.data, self.pts_seconds, self.duration_seconds
):
yield cls(
yield FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)

def __getitem__(self, key) -> Union["FrameBatch", Frame]:
data = self.data[key]
pts_seconds = self.pts_seconds[key]
duration_seconds = self.duration_seconds[key]
if self.data.ndim == 4:
return Frame(
data=data,
pts_seconds=float(pts_seconds.item()),
duration_seconds=float(duration_seconds.item()),
)
else:
return FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)
def __getitem__(self, key) -> "FrameBatch":
return FrameBatch(
data=self.data[key],
pts_seconds=self.pts_seconds[key],
duration_seconds=self.duration_seconds[key],
)

def __len__(self):
return len(self.data)
Expand Down
36 changes: 28 additions & 8 deletions test/test_frame_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ def test_frame_error():


def test_framebatch_error():
with pytest.raises(ValueError, match="data must be at least 4-dimensional"):
with pytest.raises(ValueError, match="data must be at least 3-dimensional"):
FrameBatch(
data=torch.rand(2, 3),
pts_seconds=torch.rand(1),
duration_seconds=torch.rand(1),
)

# Note: this is expected to fail because pts_seconds and duration_seconds
# are expected to have a shape of size([]) instead of size([1]).
with pytest.raises(
ValueError, match="leading dimensions of the inputs do not match"
):
FrameBatch(
data=torch.rand(1, 2, 3),
pts_seconds=torch.rand(1),
Expand Down Expand Up @@ -82,10 +93,14 @@ def test_framebatch_iteration():
assert sub_fb.pts_seconds.shape == (N,)
assert sub_fb.duration_seconds.shape == (N,)
for frame in sub_fb:
assert isinstance(frame, Frame)
assert isinstance(frame, FrameBatch)
assert frame.data.shape == (C, H, W)
assert isinstance(frame.pts_seconds, float)
assert isinstance(frame.duration_seconds, float)
# pts_seconds and duration_seconds are 0-dim tensors but they still
# contain a value
assert frame.pts_seconds.shape == tuple()
assert frame.duration_seconds.shape == tuple()
frame.pts_seconds.item()
frame.duration_seconds.item()

# Check unpacking behavior
first_sub_fb, *_ = fb
Expand All @@ -107,10 +122,15 @@ def test_framebatch_indexing():
assert fb[i].pts_seconds.shape == (N,)
assert fb[i].duration_seconds.shape == (N,)
for j in range(len(fb[i])):
assert isinstance(fb[i][j], Frame)
assert fb[i][j].data.shape == (C, H, W)
assert isinstance(fb[i][j].pts_seconds, float)
assert isinstance(fb[i][j].duration_seconds, float)
frame = fb[i][j]
assert isinstance(frame, FrameBatch)
assert frame.data.shape == (C, H, W)
# pts_seconds and duration_seconds are 0-dim tensors but they still
# contain a value
assert frame.pts_seconds.shape == tuple()
assert frame.duration_seconds.shape == tuple()
frame.pts_seconds.item()
frame.duration_seconds.item()

fb_fancy = fb[torch.arange(3)]
assert isinstance(fb_fancy, FrameBatch)
Expand Down

0 comments on commit df7a464

Please sign in to comment.