Skip to content

Commit

Permalink
[torchcodec] unify stream_index placement in core API (#49)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #49

We initially agreed on placing `stream_index` last in the core APIs, but that was based on a mistaken understanding that it was optional. Since it is not optional, it makes more sense for it to appear early, always after the decoder. It still remains keyword-only, as it is an integer parameter among many others.

Differential Revision: D58952296
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 24, 2024
1 parent b3ddc5c commit 7abf948
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 16 deletions.
12 changes: 6 additions & 6 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("get_next_frame(Tensor(a!) decoder) -> Tensor");
m.def("get_frame_at_pts(Tensor(a!) decoder, float seconds) -> Tensor");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index, int stream_index) -> Tensor");
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> Tensor");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> Tensor");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
Expand Down Expand Up @@ -132,17 +132,17 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) {

at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
int64_t stream_index) {
int64_t stream_index,
int64_t frame_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return result.frame;
}

at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
int64_t stream_index) {
int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
Expand Down
8 changes: 4 additions & 4 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds);
// Return the frame that is visible at a given index in the video.
at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t frame_index,
int64_t stream_index);
int64_t stream_index,
int64_t frame_index);

// Return the frames at the given indices for a given stream as a single
// stacked Tensor.
at::Tensor get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices,
int64_t stream_index);
int64_t stream_index,
at::IntArrayRef frame_indices);

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def get_frame_at_pts_abstract(decoder: torch.Tensor, seconds: float) -> torch.Te

@register_fake("torchcodec_ns::get_frame_at_index")
def get_frame_at_index_abstract(
decoder: torch.Tensor, *, frame_index: int, stream_index: int
decoder: torch.Tensor, *, stream_index: int, frame_index: int
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(3)]
return torch.empty(image_size)
Expand All @@ -125,8 +125,8 @@ def get_frame_at_index_abstract(
def get_frames_at_indices_abstract(
decoder: torch.Tensor,
*,
frame_indices: List[int],
stream_index: int,
frame_indices: List[int],
) -> torch.Tensor:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return torch.empty(image_size)
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/samplers/video_clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def _get_clips_for_index_based_sampling(
]
frames = get_frames_at_indices(
video_decoder,
frame_indices=batch_indexes,
stream_index=metadata_json["bestVideoStreamIndex"],
frame_indices=batch_indexes,
)
clips.append(frames)

Expand Down
6 changes: 3 additions & 3 deletions test/decoders/video_decoder_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,19 @@ def test_get_frame_at_pts(self):
def test_get_frame_at_index(self):
decoder = create_from_file(str(get_reference_video_path()))
add_video_stream(decoder)
frame1 = get_frame_at_index(decoder, frame_index=0, stream_index=3)
frame1 = get_frame_at_index(decoder, stream_index=3, frame_index=0)
reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
assert_equal(frame1, reference_frame1)
# The frame that is displayed at 6 seconds is frame 180 from a 0-based index.
frame6 = get_frame_at_index(decoder, frame_index=180, stream_index=3)
frame6 = get_frame_at_index(decoder, stream_index=3, frame_index=180)
reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
assert_equal(frame6, reference_frame6)

def test_get_frames_at_indices(self):
decoder = create_from_file(str(get_reference_video_path()))
add_video_stream(decoder)
frames1and6 = get_frames_at_indices(
decoder, frame_indices=[0, 180], stream_index=3
decoder, stream_index=3, frame_indices=[0, 180]
)
reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
Expand Down

0 comments on commit 7abf948

Please sign in to comment.