From 067e18b03b46e176129305e43aaee1fa3f93c59d Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 12 Jun 2024 19:30:53 -0700 Subject: [PATCH] [torchcodec] stream_index should not be optional for getting frame by index Differential Revision: D58505608 --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index b2f69ed1..013d672f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -37,9 +37,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("get_next_frame(Tensor(a!) decoder) -> Tensor"); m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor"); m.def( - "get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int? stream_index=None) -> Tensor"); + "get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int? stream_index=None) -> Tensor"); + "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); } @@ -131,22 +131,20 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) { at::Tensor get_frame_at_index( at::Tensor& decoder, int64_t frame_index, - std::optional stream_index) { + int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); - auto result = - videoDecoder->getFrameAtIndex(stream_index.value_or(-1), frame_index); + auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); return result.frame; } at::Tensor get_frames_at_indices( at::Tensor& decoder, at::IntArrayRef frame_indices, - std::optional stream_index) { + int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndexes( - stream_index.value_or(-1), frameIndicesVec); + auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec); return result.frames; }