[torchcodec] new core function, get_frames_in_range #47

Closed · wants to merge 1 commit
86 changes: 69 additions & 17 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -723,6 +723,30 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
return output;
}

torch::Tensor VideoDecoder::getEmptyTensorForBatch(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata) {
if (options.shape == "NHWC") {
return torch::empty(
{numFrames,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
{torch::kUInt8});
} else if (options.shape == "NCHW") {
return torch::empty(
{numFrames,
3,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width)},
{torch::kUInt8});
} else {
// TODO: should this be a TORCH macro of some kind?
throw std::runtime_error("Unsupported frame shape=" + options.shape);
}
}

VideoDecoder::DecodedOutput VideoDecoder::getFrameDisplayedAtTimestamp(
double seconds) {
for (auto& [streamIndex, stream] : streams_) {
@@ -778,26 +802,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
throw std::runtime_error(
"Invalid stream index=" + std::to_string(streamIndex));
}

BatchDecodedOutput output;
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& options = streams_[streamIndex].options;
if (options.shape == "NHWC") {
output.frames = torch::empty(
{(long)frameIndexes.size(),
options.height.value_or(*streamMetadata.height),
options.width.value_or(*streamMetadata.width),
3},
{torch::kUInt8});
} else if (options.shape == "NCHW") {
output.frames = torch::empty(
{(long)frameIndexes.size(),
3,
options.height.value_or(*streamMetadata.height),
options.width.value_or(*streamMetadata.width)},
{torch::kUInt8});
} else {
throw std::runtime_error("Unsupported frame shape=" + options.shape);
}
output.frames =
getEmptyTensorForBatch(frameIndexes.size(), options, streamMetadata);

int i = 0;
if (streams_.count(streamIndex) == 0) {
throw std::runtime_error(
@@ -817,6 +828,47 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndexes(
return output;
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
int streamIndex,
int64_t start,
int64_t stop,
int64_t step) {
TORCH_CHECK(
streamIndex >= 0 && streamIndex < containerMetadata_.streams.size(),
"Invalid stream index=" + std::to_string(streamIndex));
TORCH_CHECK(
streams_.count(streamIndex) > 0,
"Invalid stream index=" + std::to_string(streamIndex));

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];
TORCH_CHECK(
start >= 0, "Range start, " + std::to_string(start) + ", is less than 0.");
TORCH_CHECK(
stop <= stream.allFrames.size(),
"Range stop, " + std::to_string(stop) +
", is more than the number of frames, " +
std::to_string(stream.allFrames.size()));
TORCH_CHECK(
step > 0, "Step must be greater than 0; is " + std::to_string(step));

int64_t numOutputFrames = std::ceil((stop - start) / double(step));
const auto& options = stream.options;
BatchDecodedOutput output;
output.frames =
getEmptyTensorForBatch(numOutputFrames, options, streamMetadata);

int64_t f = 0;
for (int64_t i = start; i < stop; i += step) {
int64_t pts = stream.allFrames[i].pts;
setCursorPtsInSeconds(1.0 * pts / stream.timeBase.den);
torch::Tensor frame = getNextDecodedOutput().frame;
output.frames[f++] = frame;
}

return output;
}

VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutput() {
return getDecodedOutputWithFilter(
[this](int frameStreamIndex, AVFrame* frame) {
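The pre-allocation in getFramesInRange above relies on a half-open range [start, stop) with a positive step containing exactly ceil((stop - start) / step) elements, the same count Python's built-in range produces. A quick plain-Python check of that arithmetic (illustration only, not part of the diff):

import math

start, stop, step = 15, 36, 5
selected = list(range(start, stop, step))  # frames 15, 20, 25, 30, 35
num_output_frames = math.ceil((stop - start) / step)  # mirrors numOutputFrames above
assert num_output_frames == len(selected) == 5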
11 changes: 11 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -183,6 +183,13 @@ class VideoDecoder {
BatchDecodedOutput getFramesAtIndexes(
int streamIndex,
const std::vector<int64_t>& frameIndexes);
// Returns frames within a given range for a given stream as a single stacked
// Tensor. The range is defined by [start, stop). The values retrieved from
// the range are:
// [start, start+step, start+(2*step), start+(3*step), ..., stop)
// The default for step is 1.
BatchDecodedOutput
getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);

// --------------------------------------------------------------------------
// DECODER PERFORMANCE STATISTICS API
@@ -273,6 +280,10 @@ class VideoDecoder {
DecodedOutput convertAVFrameToDecodedOutput(
int streamIndex,
UniqueAVFrame frame);
torch::Tensor getEmptyTensorForBatch(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
15 changes: 15 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
}

@@ -148,6 +150,18 @@ at::Tensor get_frames_at_indices(
return result.frames;
}

at::Tensor get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto result = videoDecoder->getFramesInRange(
stream_index, start, stop, step.value_or(1));
return result.frames;
}

std::string quoteValue(const std::string& value) {
return "\"" + value + "\"";
}
@@ -249,6 +263,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("get_frame_at_pts", &get_frame_at_pts);
m.impl("get_frame_at_index", &get_frame_at_index);
m.impl("get_frames_at_indices", &get_frames_at_indices);
m.impl("get_frames_in_range", &get_frames_in_range);
}

} // namespace facebook::torchcodec
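With the schema registered in TORCH_LIBRARY and the CPU kernel bound in TORCH_LIBRARY_IMPL, the new op becomes callable from Python through the thin wrapper re-exported in video_decoder_ops.py. A minimal usage sketch, assuming the torchcodec extension loads and that these helpers are importable from the same module the tests use:

from torchcodec.decoders._core import (  # import path is an assumption
    add_video_stream,
    create_from_file,
    get_frames_in_range,
)

decoder = create_from_file("test/resources/nasa_13013.mp4")
add_video_stream(decoder)

# Decode frames 0, 2, 4, 6 and 8 of stream 3 into one stacked uint8 tensor.
frames = get_frames_in_range(decoder, stream_index=3, start=0, stop=9, step=2)
print(frames.shape)  # (5, 270, 480, 3) for the default NHWC layout of this video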
13 changes: 11 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -45,14 +45,23 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds);
at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
std::optional<int64_t> stream_index = std::nullopt);
int64_t stream_index);

// Return the frames at the given indices for a given stream as a single
// stacked Tensor.
at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
std::optional<int64_t> stream_index = std::nullopt);
int64_t stream_index);

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
at::Tensor get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt);

// Get the next frame from the video as a tensor.
at::Tensor get_next_frame(at::Tensor& decoder);
14 changes: 14 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -57,6 +57,7 @@ def load_torchcodec_extension():
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default


@@ -131,6 +132,19 @@ def get_frames_at_indices_abstract(
return torch.empty(image_size)


@register_fake("torchcodec_ns::get_frames_in_range")
def get_frames_in_range_abstract(
decoder: torch.Tensor,
*,
stream_index: int,
start: int,
stop: int,
step: Optional[int] = None,
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)


@register_fake("torchcodec_ns::get_json_metadata")
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
return torch.empty_like("")
33 changes: 0 additions & 33 deletions test/decoders/generate_reference_resources.sh

This file was deleted.

59 changes: 59 additions & 0 deletions test/decoders/video_decoder_ops_test.py
@@ -17,6 +17,7 @@
get_frame_at_index,
get_frame_at_pts,
get_frames_at_indices,
get_frames_in_range,
get_json_metadata,
get_next_frame,
seek_to_pts,
@@ -103,6 +104,64 @@ def test_get_frames_at_indices(self):
assert_equal(frames1and6[0], reference_frame1)
assert_equal(frames1and6[1], reference_frame6)

def test_get_frames_in_range(self):
decoder = create_from_file(str(get_reference_video_path()))
add_video_stream(decoder)

ref_frames0_9 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")
for i in range(0, 9)
]
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

# ensure that the degenerate case of a range of size 1 works
bulk_frame0 = get_frames_in_range(decoder, stream_index=3, start=0, stop=1)
assert_equal(bulk_frame0[0], ref_frames0_9[0])

bulk_frame1 = get_frames_in_range(decoder, stream_index=3, start=1, stop=2)
assert_equal(bulk_frame1[0], ref_frames0_9[1])

bulk_frame180 = get_frames_in_range(
decoder, stream_index=3, start=180, stop=181
)
assert_equal(bulk_frame180[0], ref_frame180)

bulk_frame_last = get_frames_in_range(
decoder, stream_index=3, start=389, stop=390
)
assert_equal(bulk_frame_last[0], ref_frame_last)

# contiguous ranges
bulk_frames0_9 = get_frames_in_range(decoder, stream_index=3, start=0, stop=9)
for i in range(0, 9):
assert_equal(ref_frames0_9[i], bulk_frames0_9[i])

bulk_frames4_8 = get_frames_in_range(decoder, stream_index=3, start=4, stop=8)
for i, bulk_frame in enumerate(bulk_frames4_8):
assert_equal(ref_frames0_9[i + 4], bulk_frame)

# ranges with a stride
ref_frames15_35 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(15, 36, 5)
]
bulk_frames15_35 = get_frames_in_range(
decoder, stream_index=3, start=15, stop=36, step=5
)
for i, bulk_frame in enumerate(bulk_frames15_35):
assert_equal(ref_frames15_35[i], bulk_frame)

bulk_frames0_9_2 = get_frames_in_range(
decoder, stream_index=3, start=0, stop=9, step=2
)
for i, bulk_frame in enumerate(bulk_frames0_9_2):
assert_equal(ref_frames0_9[i * 2], bulk_frame)

# an empty range is valid!
empty_frame = get_frames_in_range(decoder, stream_index=3, start=5, stop=5)
assert_equal(empty_frame, torch.empty((0, 270, 480, 3), dtype=torch.uint8))

def test_throws_exception_at_eof(self):
decoder = create_from_file(str(get_reference_video_path()))
add_video_stream(decoder)
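One further property worth covering (a sketch of a possible extra test, not something this PR adds): a strided range should produce exactly the frames that get_frame_at_index returns for the corresponding Python range.

def test_get_frames_in_range_matches_single_frame_op(self):
    decoder = create_from_file(str(get_reference_video_path()))
    add_video_stream(decoder)

    start, stop, step = 15, 36, 5
    bulk_frames = get_frames_in_range(
        decoder, stream_index=3, start=start, stop=stop, step=step
    )
    for bulk_frame, index in zip(bulk_frames, range(start, stop, step)):
        single_frame = get_frame_at_index(
            decoder, frame_index=index, stream_index=3
        )
        assert_equal(bulk_frame, single_frame)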
34 changes: 34 additions & 0 deletions test/generate_reference_resources.sh
@@ -0,0 +1,34 @@
#!/bin/bash

# Run this script to update the resources used in unit tests. The resources are all derived
# from source media already checked into the repo.

# Print each command and fail loudly on errors.
set -x
set -e

TORCHCODEC_PATH=$HOME/fbsource/fbcode/pytorch/torchcodec
RESOURCES_DIR=$TORCHCODEC_PATH/test/resources
VIDEO_PATH=$RESOURCES_DIR/nasa_13013.mp4

# Important note: I used ffmpeg version 6.1.1 to generate these images. The
# ffmpeg version must match the one the tests link against.
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,0)+eq(n\,1)+eq(n\,2)+eq(n\,3)+eq(n\,4)+eq(n\,5)+eq(n\,6)+eq(n\,7)+eq(n\,8)+eq(n\,9)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame%06d.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,15)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000015.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,20)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000020.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,25)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000025.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,30)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000030.bmp"
ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,35)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000035.bmp"
ffmpeg -y -ss 6.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.000000.bmp"
ffmpeg -y -ss 6.1 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.100000.bmp"
ffmpeg -y -ss 10.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time10.000000.bmp"
# This is the last frame of this video.
ffmpeg -y -ss 12.979633 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time12.979633.bmp"
# Audio generation in the form of an mp3.
ffmpeg -y -i "$VIDEO_PATH" -b:a 192K -vn "$VIDEO_PATH.audio.mp3"

for bmp in "$RESOURCES_DIR"/*.bmp
do
python3 convert_image_to_tensor.py "$bmp"
rm -f "$bmp"
done
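The loop above depends on convert_image_to_tensor.py, which is not part of this diff. A minimal sketch of what such a helper could look like (the Pillow/numpy dependencies and the exact behavior are assumptions): it reads each .bmp as an HWC uint8 image and saves it as the .pt tensor that load_tensor_from_file() reads back in the tests.

import sys

import numpy as np
import torch
from PIL import Image  # assumed dependency; any BMP reader would do

def main(path: str) -> None:
    # Load the BMP as an (H, W, 3) uint8 array and save it next to the source
    # file with the .bmp suffix replaced by .pt.
    image = np.array(Image.open(path).convert("RGB"), dtype=np.uint8)
    tensor = torch.from_numpy(image)
    torch.save(tensor, path[: -len(".bmp")] + ".pt")

if __name__ == "__main__":
    main(sys.argv[1])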
Binary file modified test/resources/nasa_13013.mp4.audio.mp3
Binary file added test/resources/nasa_13013.mp4.frame000003.pt
Binary file added test/resources/nasa_13013.mp4.frame000004.pt
Binary file added test/resources/nasa_13013.mp4.frame000005.pt
Binary file added test/resources/nasa_13013.mp4.frame000006.pt
Binary file added test/resources/nasa_13013.mp4.frame000007.pt
Binary file added test/resources/nasa_13013.mp4.frame000008.pt
Binary file added test/resources/nasa_13013.mp4.frame000009.pt
Binary file added test/resources/nasa_13013.mp4.frame000010.pt
Binary file added test/resources/nasa_13013.mp4.frame000015.pt
Binary file added test/resources/nasa_13013.mp4.frame000020.pt
Binary file added test/resources/nasa_13013.mp4.frame000025.pt
Binary file added test/resources/nasa_13013.mp4.frame000030.pt
Binary file added test/resources/nasa_13013.mp4.frame000035.pt