Skip to content

Commit

Permalink
[torchcodec] Add a function to unwrap the VideoDecoder
Browse files Browse the repository at this point in the history
Summary: This improves readability, and it will also help in the future if we want to wrap a VideoDecoder* instead of a VideoDecoder in the tensor.

Reviewed By: scotts

Differential Revision: D59338247

fbshipit-source-id: 9076cc917a394328fe22e045631c5e079ecb2499
  • Loading branch information
ahmadsharif1 authored and facebook-github-bot committed Jul 3, 2024
1 parent b78faa0 commit b6c624e
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
}

// ==============================
// Implementations for the operators
// ==============================

namespace {
at::Tensor wrapDecoderPointerToTensor(
std::unique_ptr<VideoDecoder> uniqueDecoder) {
VideoDecoder* decoder = uniqueDecoder.release();
Expand All @@ -59,6 +56,18 @@ at::Tensor wrapDecoderPointerToTensor(
return tensor;
}

// Returns the tensor's mutable data pointer reinterpreted as a
// VideoDecoder*. Internally asserts that the tensor is contiguous before
// the data pointer is read.
VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) {
  TORCH_INTERNAL_ASSERT(tensor.is_contiguous());
  return static_cast<VideoDecoder*>(tensor.mutable_data_ptr());
}
} // namespace

// ==============================
// Implementations for the operators
// ==============================

at::Tensor create_from_file(c10::string_view filename) {
std::string filenameStr(filename);
std::unique_ptr<VideoDecoder> uniqueDecoder =
Expand Down Expand Up @@ -99,7 +108,7 @@ void add_video_stream(
options.shape = stdShape;
}

auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options);
}

Expand All @@ -109,7 +118,7 @@ void seek_to_pts(at::Tensor& decoder, double seconds) {
}

at::Tensor get_next_frame(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getNextDecodedOutput().frame;
if (result.sizes().size() != 3) {
throw std::runtime_error(
Expand All @@ -120,7 +129,7 @@ at::Tensor get_next_frame(at::Tensor& decoder) {
}

at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds);
return result.frame;
}
Expand All @@ -129,7 +138,7 @@ at::Tensor get_frame_at_index(
at::Tensor& decoder,
int64_t stream_index,
int64_t frame_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index);
return result.frame;
}
Expand All @@ -148,7 +157,7 @@ at::Tensor get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec);
Expand All @@ -161,7 +170,7 @@ at::Tensor get_frames_in_range(
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFramesInRange(
stream_index, start, stop, step.value_or(1));
return result.frames;
Expand Down Expand Up @@ -195,7 +204,7 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
}

std::string get_json_metadata(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

VideoDecoder::ContainerMetadata videoMetadata =
videoDecoder->getContainerMetadata();
Expand Down Expand Up @@ -261,7 +270,7 @@ std::string get_json_metadata(at::Tensor& decoder) {
}

std::string get_container_json_metadata(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

auto containerMetadata = videoDecoder->getContainerMetadata();

Expand Down Expand Up @@ -292,7 +301,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) {
std::string get_stream_json_metadata(
at::Tensor& decoder,
int64_t stream_index) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto streams = videoDecoder->getContainerMetadata().streams;
if (stream_index < 0 || stream_index >= streams.size()) {
throw std::out_of_range(
Expand Down Expand Up @@ -365,7 +374,7 @@ std::string _get_json_ffmpeg_library_versions() {
}

void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
videoDecoder->scanFileAndUpdateMetadataAndIndex();
}

Expand Down

0 comments on commit b6c624e

Please sign in to comment.