diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index f8668af6..37762a68 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -723,6 +723,36 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   return output;
 }
 
+// Allocates an uninitialized uint8 tensor able to hold numFrames frames.
+// Height and width come from options when set there, otherwise from the
+// stream metadata; the channel dimension is always 3. options.shape picks the
+// layout ("NHWC" or "NCHW"); any other value throws std::runtime_error.
+// NOTE(review): *metadata.height / *metadata.width are evaluated even when
+// options overrides them - assumes metadata always has both; confirm.
+torch::Tensor VideoDecoder::getEmptyTensorForBatch(
+    int64_t numFrames,
+    const VideoStreamDecoderOptions& options,
+    const StreamMetadata& metadata) {
+  if (options.shape == "NHWC") {
+    return torch::empty(
+        {numFrames,
+         options.height.value_or(*metadata.height),
+         options.width.value_or(*metadata.width),
+         3},
+        {torch::kUInt8});
+  } else if (options.shape == "NCHW") {
+    return torch::empty(
+        {numFrames,
+         3,
+         options.height.value_or(*metadata.height),
+         options.width.value_or(*metadata.width)},
+        {torch::kUInt8});
+  } else {
+    // TODO: should this be a TORCH macro of some kind?
+    throw std::runtime_error("Unsupported frame shape=" + options.shape);
+  }
+}
+
 VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
     double seconds) {
   for (auto& [streamIndex, stream] : streams_) {
@@ -778,26 +808,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
     throw std::runtime_error(
         "Invalid stream index=" + std::to_string(streamIndex));
   }
+
   BatchDecodedOutput output;
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& options = streams_[streamIndex].options;
-  if (options.shape == "NHWC") {
-    output.frames = torch::empty(
-        {(long)frameIndexes.size(),
-         options.height.value_or(*streamMetadata.height),
-         options.width.value_or(*streamMetadata.width),
-         3},
-        {torch::kUInt8});
-  } else if (options.shape == "NCHW") {
-    output.frames = torch::empty(
-        {(long)frameIndexes.size(),
-         3,
-         options.height.value_or(*streamMetadata.height),
-         options.width.value_or(*streamMetadata.width)},
-        {torch::kUInt8});
-  } else {
-    throw std::runtime_error("Unsupported frame shape=" + options.shape);
-  }
+  output.frames =
+      getEmptyTensorForBatch(frameIndexes.size(), options, streamMetadata);
+
   int i = 0;
   if (streams_.count(streamIndex) == 0) {
     throw std::runtime_error(
@@ -817,6 +834,67 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
   return output;
 }
 
+// Returns the frames at indexes start, start+step, start+2*step, ... (all
+// strictly below stop) of the given stream, stacked into one tensor in that
+// order. An empty range (start == stop) produces a tensor holding 0 frames.
+// Every precondition is enforced with TORCH_CHECK: the stream index must be
+// valid and have an active stream, 0 <= start <= stop <= number of frames,
+// and step must be positive.
+VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
+    int streamIndex,
+    int64_t start,
+    int64_t stop,
+    int64_t step) {
+  // Both bounds must hold (&&, not ||); the sign is tested first so the cast
+  // to size_t cannot wrap a negative index.
+  TORCH_CHECK(
+      streamIndex >= 0 &&
+          static_cast<size_t>(streamIndex) < containerMetadata_.streams.size(),
+      "Invalid stream index=" + std::to_string(streamIndex));
+  TORCH_CHECK(
+      streams_.count(streamIndex) > 0,
+      "Invalid stream index=" + std::to_string(streamIndex));
+
+  const auto& streamMetadata = containerMetadata_.streams[streamIndex];
+  const auto& stream = streams_[streamIndex];
+  // Compare as signed values so a negative stop is not converted to a huge
+  // unsigned number.
+  const auto numFrames = static_cast<int64_t>(stream.allFrames.size());
+  TORCH_CHECK(
+      start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
+  TORCH_CHECK(
+      stop <= numFrames,
+      "Range stop, " + std::to_string(stop) +
+          ", is more than the number of frames, " + std::to_string(numFrames));
+  // Reject reversed ranges here; otherwise the frame count below is negative.
+  TORCH_CHECK(
+      start <= stop,
+      "Range start, " + std::to_string(start) + ", is more than range stop, " +
+          std::to_string(stop));
+  TORCH_CHECK(
+      step > 0, "Step must be greater than 0; is " + std::to_string(step));
+
+  // ceil((stop - start) / step) entries are visited by the loop below.
+  int64_t numOutputFrames =
+      std::ceil((stop - start) / static_cast<double>(step));
+  const auto& options = stream.options;
+  BatchDecodedOutput output;
+  output.frames =
+      getEmptyTensorForBatch(numOutputFrames, options, streamMetadata);
+
+  int64_t f = 0;
+  for (int64_t i = start; i < stop; i += step) {
+    int64_t pts = stream.allFrames[i].pts;
+    // NOTE(review): only timeBase.den is used for the conversion, which
+    // presumably assumes timeBase.num == 1 - confirm.
+    setCursorPtsInSeconds(1.0 * pts / stream.timeBase.den);
+    torch::Tensor frame = getNextDecodedOutput().frame;
+    output.frames[f++] = frame;
+  }
+
+  return output;
+}
+
 VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() {
   return getDecodedOutputWithFilter(
       [this](int frameStreamIndex, AVFrame* frame) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 39cd93b7..822374b2 100644
--- 
a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -183,6 +183,13 @@ class VideoDecoder {
   BatchDecodedOutput getFramesAtIndexes(
       int streamIndex,
       const std::vector& frameIndexes);
+  // Returns frames within a given range for a given stream as a single stacked
+  // Tensor. The range is defined by [start, stop). The values retrieved from
+  // the range are:
+  //    [start, start+step, start+(2*step), start+(3*step), ..., stop)
+  // step must be greater than 0; this declaration gives it no default.
+  BatchDecodedOutput
+  getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
 
   // --------------------------------------------------------------------------
   // DECODER PERFORMANCE STATISTICS API
@@ -273,6 +280,10 @@ class VideoDecoder {
   DecodedOutput convertAVFrameToDecodedOutput(
       int streamIndex,
       UniqueAVFrame frame);
+  torch::Tensor getEmptyTensorForBatch( // uint8 output tensor for numFrames frames
+      int64_t numFrames,
+      const VideoStreamDecoderOptions& options,
+      const StreamMetadata& metadata);
 
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 013d672f..0d0a4cad 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
       "get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
   m.def(
       "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
+  m.def( // schema: everything after decoder is keyword-only; step is optional
+      "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
   m.def("get_json_metadata(Tensor(a!) decoder) -> str");
 }
 
@@ -148,6 +150,18 @@ at::Tensor get_frames_at_indices(
   return result.frames;
 }
 
+at::Tensor get_frames_in_range( // op entry point: frames of [start, stop) as one stacked tensor
+    at::Tensor& decoder,
+    int64_t stream_index,
+    int64_t start,
+    int64_t stop,
+    std::optional step = std::nullopt) { // NOTE(review): VideoDecoderOps.h repeats this default argument - confirm both declarations can coexist
+  auto videoDecoder = static_cast(decoder.mutable_data_ptr()); // the decoder object is reached through the tensor's data pointer
+  auto result = videoDecoder->getFramesInRange(
+      stream_index, start, stop, step.value_or(1)); // a missing step means 1; NOTE(review): stream_index narrows from int64_t to int here
+  return result.frames; // only the stacked frames tensor is handed back
+}
+
 std::string quoteValue(const std::string& value) {
   return "\"" + value + "\"";
 }
@@ -249,6 +263,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
   m.impl("get_frame_at_pts", &get_frame_at_pts);
   m.impl("get_frame_at_index", &get_frame_at_index);
   m.impl("get_frames_at_indices", &get_frames_at_indices);
+  m.impl("get_frames_in_range", &get_frames_in_range); // CPU implementation of the schema defined above
 }
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index c029473b..fc6c8909 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -45,14 +45,23 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds);
 at::Tensor get_frame_at_index(
     at::Tensor& decoder,
     int64_t frame_index,
-    std::optional stream_index = std::nullopt);
+    int64_t stream_index);
 
 // Return the frames at a given index for a given stream as a single stacked
 // Tensor.
 at::Tensor get_frames_at_indices(
     at::Tensor& decoder,
     at::IntArrayRef frame_indices,
-    std::optional stream_index = std::nullopt);
+    int64_t stream_index);
+
+// Return the frames inside a range as a single stacked Tensor. The range is
+// defined as [start, stop); when step is not given, a step of 1 is used.
+at::Tensor get_frames_in_range(
+    at::Tensor& decoder,
+    int64_t stream_index,
+    int64_t start,
+    int64_t stop,
+    std::optional step = std::nullopt);
 
 // Get the next frame from the video as a tensor.
 at::Tensor get_next_frame(at::Tensor& decoder);
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py
index 3aba20e5..369584e5 100644
--- a/src/torchcodec/decoders/_core/video_decoder_ops.py
+++ b/src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -57,6 +57,7 @@ def load_torchcodec_extension():
 get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
 get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
 get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
+get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default  # Python handle to the get_frames_in_range op
 get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
 
 
@@ -131,6 +132,19 @@ def get_frames_at_indices_abstract(
     return torch.empty(image_size)
 
 
+@register_fake("torchcodec_ns::get_frames_in_range")  # fake implementation registered for the op
+def get_frames_in_range_abstract(
+    decoder: torch.Tensor,
+    *,
+    stream_index: int,
+    start: int,
+    stop: int,
+    step: Optional[int] = None,  # optional, like `int? step=None` in the op schema
+) -> torch.Tensor:
+    image_size = [get_ctx().new_dynamic_size() for _ in range(4)]  # four dynamic sizes, one per output dimension
+    return torch.empty(image_size)  # NOTE(review): no dtype is passed, while the C++ op allocates kUInt8 - confirm
+
+
 @register_fake("torchcodec_ns::get_json_metadata")
 def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
     return torch.empty_like("")
diff --git a/test/decoders/convert_image_to_tensor.py b/test/convert_image_to_tensor.py
similarity index 100%
rename from test/decoders/convert_image_to_tensor.py
rename to test/convert_image_to_tensor.py
diff --git a/test/decoders/generate_reference_resources.sh b/test/decoders/generate_reference_resources.sh
deleted file mode 100755
index 5d5bb765..00000000
--- a/test/decoders/generate_reference_resources.sh
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/bin/bash
-
-# Run this script to update the resources used in unit tests. The resources are all derived
-# from source media already checked into the repo.
-
-# Fail loudly on errors.
-set -x
-set -e
-
-# 1. Create a temporary directory to dump bitmaps into.
-TEMP_DIR=$(mktemp -d)
-echo "Creating all bitmaps in $TEMP_DIR"
-
-TORCHCODEC_PATH=$HOME/fbsource/fbcode/pytorch/torchcodec
-RESOURCES_DIR=$TORCHCODEC_PATH/test/decoders/resources
-VIDEO_PATH=$RESOURCES_DIR/nasa_13013.mp4
-
-# Important note: I used ffmpeg version 6.1.1 to generate these images. We
-# must have the version that matches the one that we link against in the test.
-ffmpeg -i "$VIDEO_PATH" -vf select='eq(n\,0)+eq(n\,1)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame%06d.bmp"
-ffmpeg -ss 6.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.000000.bmp"
-ffmpeg -ss 6.1 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.100000.bmp"
-ffmpeg -ss 10.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time10.000000.bmp"
-# This is the last frame of this video.
-ffmpeg -ss 12.979633 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time12.979633.bmp"
-# Audio generation in the form of an mp3.
-ffmpeg -i "$VIDEO_PATH" -b:a 192K -vn "$VIDEO_PATH.audio.mp3"
-
-for bmp in "$RESOURCES_DIR"/*.bmp
-do
-  python convert_image_to_tensor.py "$bmp"
-  rm -f "$bmp"
-done
diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py
index 06c3e2eb..478bc41c 100644
--- a/test/decoders/video_decoder_ops_test.py
+++ b/test/decoders/video_decoder_ops_test.py
@@ -17,6 +17,7 @@
     get_frame_at_index,
     get_frame_at_pts,
     get_frames_at_indices,
+    get_frames_in_range,
     get_json_metadata,
     get_next_frame,
     seek_to_pts,
@@ -103,6 +104,64 @@ def test_get_frames_at_indices(self):
         assert_equal(frames1and6[0], reference_frame1)
         assert_equal(frames1and6[1], reference_frame6)
 
+    def test_get_frames_in_range(self):  # compares get_frames_in_range output with stored reference tensors
+        decoder = create_from_file(str(get_reference_video_path()))
+        add_video_stream(decoder)
+
+        ref_frames0_9 = [
+            load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")  # file number is frame index + 1
+            for i in range(0, 9)  # frame indexes 0..8
+        ]
+        ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
+        ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
+
+        # ensure that the degenerate case of a range of size 1 works
+        bulk_frame0 = get_frames_in_range(decoder, stream_index=3, start=0, stop=1)  # stream_index=3 is hard-coded throughout this test
+        assert_equal(bulk_frame0[0], ref_frames0_9[0])
+
+        bulk_frame1 = get_frames_in_range(decoder, stream_index=3, start=1, stop=2)
+        assert_equal(bulk_frame1[0], ref_frames0_9[1])
+
+        bulk_frame180 = get_frames_in_range(
+            decoder, stream_index=3, start=180, stop=181
+        )
+        assert_equal(bulk_frame180[0], ref_frame180)
+
+        bulk_frame_last = get_frames_in_range(
+            decoder, stream_index=3, start=389, stop=390
+        )
+        assert_equal(bulk_frame_last[0], ref_frame_last)
+
+        # contiguous ranges
+        bulk_frames0_9 = get_frames_in_range(decoder, stream_index=3, start=0, stop=9)
+        for i in range(0, 9):
+            assert_equal(ref_frames0_9[i], bulk_frames0_9[i])
+
+        bulk_frames4_8 = get_frames_in_range(decoder, stream_index=3, start=4, stop=8)
+        for i, bulk_frame in enumerate(bulk_frames4_8):
+            assert_equal(ref_frames0_9[i + 4], bulk_frame)  # output position i is frame index i + 4
+
+        # ranges with a stride
+        ref_frames15_35 = [
+            load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")  # here the file number is used as the frame index directly
+            for i in range(15, 36, 5)  # 15, 20, 25, 30, 35
+        ]
+        bulk_frames15_35 = get_frames_in_range(
+            decoder, stream_index=3, start=15, stop=36, step=5
+        )
+        for i, bulk_frame in enumerate(bulk_frames15_35):
+            assert_equal(ref_frames15_35[i], bulk_frame)
+
+        bulk_frames0_9_2 = get_frames_in_range(
+            decoder, stream_index=3, start=0, stop=9, step=2
+        )
+        for i, bulk_frame in enumerate(bulk_frames0_9_2):
+            assert_equal(ref_frames0_9[i * 2], bulk_frame)  # step=2 visits frame indexes 0, 2, 4, 6, 8
+
+        # an empty range is valid!
+        empty_frame = get_frames_in_range(decoder, stream_index=3, start=5, stop=5)
+        assert_equal(empty_frame, torch.empty((0, 270, 480, 3), dtype=torch.uint8))  # zero frames in a 4-D uint8 tensor
+
     def test_throws_exception_at_eof(self):
         decoder = create_from_file(str(get_reference_video_path()))
         add_video_stream(decoder)
diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh
new file mode 100755
index 00000000..94524b21
--- /dev/null
+++ b/test/generate_reference_resources.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+# Run this script to update the resources used in unit tests. The resources are all derived
+# from source media already checked into the repo.
+
+# Fail loudly on errors.
+set -x
+set -e
+
+TORCHCODEC_PATH=$HOME/fbsource/fbcode/pytorch/torchcodec  # hard-coded checkout location under $HOME
+RESOURCES_DIR=$TORCHCODEC_PATH/test/resources
+VIDEO_PATH=$RESOURCES_DIR/nasa_13013.mp4
+
+# Important note: I used ffmpeg version 6.1.1 to generate these images. We
+# must have the version that matches the one that we link against in the test.
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,0)+eq(n\,1)+eq(n\,2)+eq(n\,3)+eq(n\,4)+eq(n\,5)+eq(n\,6)+eq(n\,7)+eq(n\,8)+eq(n\,9)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame%06d.bmp"  # selects frames n=0..9
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,15)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000015.bmp"  # single frames below are named after n itself
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,20)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000020.bmp"
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,25)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000025.bmp"
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,30)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000030.bmp"
+ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,35)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000035.bmp"
+ffmpeg -y -ss 6.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.000000.bmp"
+ffmpeg -y -ss 6.1 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.100000.bmp"
+ffmpeg -y -ss 10.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time10.000000.bmp"
+# This is the last frame of this video.
+ffmpeg -y -ss 12.979633 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time12.979633.bmp"
+# Audio generation in the form of an mp3.
+ffmpeg -y -i "$VIDEO_PATH" -b:a 192K -vn "$VIDEO_PATH.audio.mp3"
+
+for bmp in "$RESOURCES_DIR"/*.bmp
+do
+  python3 convert_image_to_tensor.py "$bmp"  # relative path, so the working directory matters
+  rm -f "$bmp"  # the bitmap is removed once it has been converted
+done
diff --git a/test/resources/nasa_13013.mp4.audio.mp3 b/test/resources/nasa_13013.mp4.audio.mp3
index e5b0b687..05681a83 100644
Binary files a/test/resources/nasa_13013.mp4.audio.mp3 and b/test/resources/nasa_13013.mp4.audio.mp3 differ
diff --git a/test/resources/nasa_13013.mp4.frame000003.pt b/test/resources/nasa_13013.mp4.frame000003.pt
new file mode 100644
index 00000000..2d9a04cd
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000003.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000004.pt b/test/resources/nasa_13013.mp4.frame000004.pt
new file mode 100644
index 00000000..1efb87e8
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000004.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000005.pt b/test/resources/nasa_13013.mp4.frame000005.pt
new file mode 100644
index 00000000..427c06cc
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000005.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000006.pt b/test/resources/nasa_13013.mp4.frame000006.pt
new file mode 100644
index 00000000..a6ceb4ce
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000006.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000007.pt b/test/resources/nasa_13013.mp4.frame000007.pt
new file mode 100644
index 00000000..7fd4d482
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000007.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000008.pt b/test/resources/nasa_13013.mp4.frame000008.pt
new file mode 100644
index 00000000..cbbdce81
Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000008.pt differ
diff --git a/test/resources/nasa_13013.mp4.frame000009.pt 
b/test/resources/nasa_13013.mp4.frame000009.pt new file mode 100644 index 00000000..7f075cf0 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000009.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000010.pt b/test/resources/nasa_13013.mp4.frame000010.pt new file mode 100644 index 00000000..ce926a7f Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000010.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000015.pt b/test/resources/nasa_13013.mp4.frame000015.pt new file mode 100644 index 00000000..2f141b3f Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000015.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000020.pt b/test/resources/nasa_13013.mp4.frame000020.pt new file mode 100644 index 00000000..15e4c530 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000020.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000025.pt b/test/resources/nasa_13013.mp4.frame000025.pt new file mode 100644 index 00000000..e9e07634 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000025.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000030.pt b/test/resources/nasa_13013.mp4.frame000030.pt new file mode 100644 index 00000000..bb3d6eb4 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000030.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000035.pt b/test/resources/nasa_13013.mp4.frame000035.pt new file mode 100644 index 00000000..70191554 Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000035.pt differ