Skip to content

Commit

Permalink
[torchcodec] stream_index should not be optional for getting frame by…
Browse files Browse the repository at this point in the history
… index

Differential Revision: D58505608
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 13, 2024
1 parent daac820 commit 067e18b
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("get_next_frame(Tensor(a!) decoder) -> Tensor");
m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int? stream_index=None) -> Tensor");
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int? stream_index=None) -> Tensor");
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
}

Expand Down Expand Up @@ -131,22 +131,20 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) {
at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
std::optional<int64_t> stream_index) {
int64_t stream_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto result =
videoDecoder->getFrameAtIndex(stream_index.value_or(-1), frame_index);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return result.frame;
}

at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
std::optional<int64_t> stream_index) {
int64_t stream_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndexes(
stream_index.value_or(-1), frameIndicesVec);
auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec);
return result.frames;
}

Expand Down

0 comments on commit 067e18b

Please sign in to comment.