Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing 4D FrameBatch now returns FrameBatch #296

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ class FrameBatch(Iterable):
def __post_init__(self):
# This is called after __init__() when a FrameBatch is created. We can
# run input validation checks here.
if self.data.ndim < 4:
if self.data.ndim < 3:
raise ValueError(
f"data must be at least 4-dimensional. Got {self.data.shape = } "
"For 3-dimensional data, create a Frame object instead."
f"data must be at least 3-dimensional, got {self.data.shape = }"
)

leading_dims = self.data.shape[:-3]
Expand All @@ -83,33 +82,22 @@ def __post_init__(self):
f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
)

def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
cls = Frame if self.data.ndim == 4 else FrameBatch
def __iter__(self) -> Iterator["FrameBatch"]:
for data, pts_seconds, duration_seconds in zip(
self.data, self.pts_seconds, self.duration_seconds
):
yield cls(
yield FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)

def __getitem__(self, key) -> Union["FrameBatch", Frame]:
data = self.data[key]
pts_seconds = self.pts_seconds[key]
duration_seconds = self.duration_seconds[key]
if self.data.ndim == 4:
return Frame(
data=data,
pts_seconds=float(pts_seconds.item()),
duration_seconds=float(duration_seconds.item()),
)
else:
return FrameBatch(
data=data,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
)
def __getitem__(self, key) -> "FrameBatch":
return FrameBatch(
data=self.data[key],
pts_seconds=self.pts_seconds[key],
duration_seconds=self.duration_seconds[key],
)

def __len__(self):
return len(self.data)
Expand Down
36 changes: 28 additions & 8 deletions test/test_frame_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,18 @@ def test_frame_error():


def test_framebatch_error():
with pytest.raises(ValueError, match="data must be at least 4-dimensional"):
with pytest.raises(ValueError, match="data must be at least 3-dimensional"):
FrameBatch(
data=torch.rand(2, 3),
pts_seconds=torch.rand(1),
duration_seconds=torch.rand(1),
)

# Note: this is expected to fail because pts_seconds and duration_seconds
# are expected to have a shape of size([]) instead of size([1]).
with pytest.raises(
ValueError, match="leading dimensions of the inputs do not match"
):
FrameBatch(
data=torch.rand(1, 2, 3),
pts_seconds=torch.rand(1),
Expand Down Expand Up @@ -82,10 +93,14 @@ def test_framebatch_iteration():
assert sub_fb.pts_seconds.shape == (N,)
assert sub_fb.duration_seconds.shape == (N,)
for frame in sub_fb:
assert isinstance(frame, Frame)
assert isinstance(frame, FrameBatch)
assert frame.data.shape == (C, H, W)
assert isinstance(frame.pts_seconds, float)
assert isinstance(frame.duration_seconds, float)
# pts_seconds and duration_seconds are 0-dim tensors but they still
# contain a value
assert frame.pts_seconds.shape == tuple()
assert frame.duration_seconds.shape == tuple()
frame.pts_seconds.item()
frame.duration_seconds.item()

# Check unpacking behavior
first_sub_fb, *_ = fb
Expand All @@ -107,10 +122,15 @@ def test_framebatch_indexing():
assert fb[i].pts_seconds.shape == (N,)
assert fb[i].duration_seconds.shape == (N,)
for j in range(len(fb[i])):
assert isinstance(fb[i][j], Frame)
assert fb[i][j].data.shape == (C, H, W)
assert isinstance(fb[i][j].pts_seconds, float)
assert isinstance(fb[i][j].duration_seconds, float)
frame = fb[i][j]
assert isinstance(frame, FrameBatch)
assert frame.data.shape == (C, H, W)
# pts_seconds and duration_seconds are 0-dim tensors but they still
# contain a value
assert frame.pts_seconds.shape == tuple()
assert frame.duration_seconds.shape == tuple()
frame.pts_seconds.item()
frame.duration_seconds.item()

fb_fancy = fb[torch.arange(3)]
assert isinstance(fb_fancy, FrameBatch)
Expand Down
Loading