Skip to content

Commit

Permalink
[torchcodec] add slice support to SimpleVideoDecoder (#53)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #53

Adds slice support to `SimpleVideoDecoder`. This diff enables the following use-cases:

* `first_few_frames = decoder[0:3]`
* `last_few_frames = decoder[-3:]`
* `multiples_of_five = decoder[0:100:5]`
* `all_frames = decoder[:]`

Most of the heavy lifting in the implementation is being done by the previously-implemented `core.get_frames_in_range` and the Python slice object: https://docs.python.org/3/reference/datamodel.html#slice-objects.

Reviewed By: ahmadsharif1

Differential Revision: D59059431

fbshipit-source-id: d3c2aa3fea32f9adda43e6e9cf8e4ad32201281f
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 26, 2024
1 parent 1b8035e commit a57bbee
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 6 deletions.
29 changes: 24 additions & 5 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,8 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]):
def __len__(self) -> int:
return self._num_frames

def __getitem__(self, key: int) -> torch.Tensor:
if not isinstance(key, int):
raise TypeError(
f"Unsupported key type: {type(key)}. Supported type is int."
)
def _getitem_int(self, key: int) -> torch.Tensor:
assert isinstance(key, int)

if key < 0:
key += self._num_frames
Expand All @@ -51,6 +48,28 @@ def __getitem__(self, key: int) -> torch.Tensor:
self._decoder, frame_index=key, stream_index=self._stream_index
)

def _getitem_slice(self, key: slice) -> torch.Tensor:
assert isinstance(key, slice)

start, stop, step = key.indices(len(self))
return core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
start=start,
stop=stop,
step=step,
)

def __getitem__(self, key: Union[int, slice]) -> torch.Tensor:
if isinstance(key, int):
return self._getitem_int(key)
elif isinstance(key, slice):
return self._getitem_slice(key)

raise TypeError(
f"Unsupported key type: {type(key)}. Supported types are int and slice."
)

def __iter__(self) -> "SimpleVideoDecoder":
return self

Expand Down
82 changes: 82 additions & 0 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import pytest
import torch

from torchcodec.decoders import SimpleVideoDecoder

from ..test_utils import (
assert_equal,
EMPTY_REF_TENSOR,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
REF_DIMS,
)


Expand Down Expand Up @@ -48,6 +51,85 @@ def test_getitem_int(self):
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])

def test_getitem_slice(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frames0_9 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")
for i in range(0, 9)
]

# Ensure that the degenerate case of a range of size 1 works; note that we get
# a tensor which CONTAINS a single frame, rather than a tensor that itself IS a
# single frame. Hence we have to access the 0th element of the return tensor.
slice_0 = decoder[0:1]
assert slice_0.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[0], slice_0[0])

slice_4 = decoder[4:5]
assert slice_4.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[4], slice_4[0])

slice_8 = decoder[8:9]
assert slice_8.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[8], slice_8[0])

ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
slice_180 = decoder[180:181]
assert slice_180.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frame180, slice_180[0])

# contiguous ranges
slice_frames0_9 = decoder[0:9]
assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS])
for i, slice_frame in enumerate(slice_frames0_9):
assert_equal(ref_frames0_9[i], slice_frame)

slice_frames4_8 = decoder[4:8]
assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS])
for i, slice_frame in enumerate(slice_frames4_8):
assert_equal(ref_frames0_9[i + 4], slice_frame)

# ranges with a stride
ref_frames15_35 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(15, 36, 5)
]
slice_frames15_35 = decoder[15:36:5]
assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS])
for i, slice_frame in enumerate(slice_frames15_35):
assert_equal(ref_frames15_35[i], slice_frame)

slice_frames0_9_2 = decoder[0:9:2]
assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS])
for i, slice_frame in enumerate(slice_frames0_9_2):
assert_equal(ref_frames0_9[i * 2], slice_frame)

# negative numbers in the slice
ref_frames386_389 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(386, 390)
]

slice_frames386_389 = decoder[-4:]
assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS])
for i, slice_frame in enumerate(slice_frames386_389):
assert_equal(ref_frames386_389[i], slice_frame)

# an empty range is valid!
empty_frame = decoder[5:5]
assert_equal(empty_frame, EMPTY_REF_TENSOR)

# slices that are out-of-range are also valid - they return an empty tensor
also_empty = decoder[10000:]
assert_equal(also_empty, EMPTY_REF_TENSOR)

# should be just a copy
all_frames = decoder[:]
assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS])
for sliced, ref in zip(all_frames, decoder):
assert_equal(sliced, ref)

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

Expand Down
3 changes: 2 additions & 1 deletion test/decoders/video_decoder_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from ..test_utils import (
assert_equal,
EMPTY_REF_TENSOR,
get_reference_audio_path,
get_reference_video_path,
load_tensor_from_file,
Expand Down Expand Up @@ -161,7 +162,7 @@ def test_get_frames_in_range(self):

# an empty range is valid!
empty_frame = get_frames_in_range(decoder, stream_index=3, start=5, stop=5)
assert_equal(empty_frame, torch.empty((0, 270, 480, 3), dtype=torch.uint8))
assert_equal(empty_frame, EMPTY_REF_TENSOR)

def test_throws_exception_at_eof(self):
decoder = create_from_file(str(get_reference_video_path()))
Expand Down
8 changes: 8 additions & 0 deletions test/generate_reference_resources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ VIDEO_PATH=$RESOURCES_DIR/nasa_13013.mp4

# Important note: I used ffmpeg version 6.1.1 to generate these images. We
# must have the version that matches the one that we link against in the test.
# TODO: The first 10 frames are numbered starting from 1, so their name is one more
# than their index. This is confusing. We should unify the naming so files are
# named by their index. This will inovlve also updating the tests that load
# these files.
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,0)+eq(n\,1)+eq(n\,2)+eq(n\,3)+eq(n\,4)+eq(n\,5)+eq(n\,6)+eq(n\,7)+eq(n\,8)+eq(n\,9)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame%06d.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,15)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000015.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,20)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000020.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,25)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000025.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,30)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000030.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,35)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000035.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,386)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000386.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,387)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000387.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,388)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000388.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,389)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000389.bmp"
ffmpeg -y -ss 6.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.000000.bmp"
ffmpeg -y -ss 6.1 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.100000.bmp"
ffmpeg -y -ss 10.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time10.000000.bmp"
Expand Down
Binary file added test/resources/nasa_13013.mp4.frame000386.pt
Binary file not shown.
Binary file added test/resources/nasa_13013.mp4.frame000387.pt
Binary file not shown.
Binary file added test/resources/nasa_13013.mp4.frame000388.pt
Binary file not shown.
Binary file added test/resources/nasa_13013.mp4.frame000389.pt
Binary file not shown.
4 changes: 4 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@

import torch

# The dimensions and type have to match the frames in our reference video.
REF_DIMS = (270, 480, 3)
EMPTY_REF_TENSOR = torch.empty([0, *REF_DIMS], dtype=torch.uint8)


def in_fbcode() -> bool:
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"
Expand Down

0 comments on commit a57bbee

Please sign in to comment.