diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 89527244..073c21aa 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -43,10 +43,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()"); } -// ============================== -// Implementations for the operators -// ============================== - +namespace { at::Tensor wrapDecoderPointerToTensor( std::unique_ptr uniqueDecoder) { VideoDecoder* decoder = uniqueDecoder.release(); @@ -59,6 +56,18 @@ at::Tensor wrapDecoderPointerToTensor( return tensor; } +VideoDecoder* unwrapTensorToGetDecoder(at::Tensor& tensor) { + TORCH_INTERNAL_ASSERT(tensor.is_contiguous()); + void* buffer = tensor.mutable_data_ptr(); + VideoDecoder* decoder = static_cast(buffer); + return decoder; +} +} // namespace + +// ============================== +// Implementations for the operators +// ============================== + at::Tensor create_from_file(c10::string_view filename) { std::string filenameStr(filename); std::unique_ptr uniqueDecoder = @@ -99,7 +108,7 @@ void add_video_stream( options.shape = stdShape; } - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStreamDecoder(stream_index.value_or(-1), options); } @@ -109,7 +118,7 @@ void seek_to_pts(at::Tensor& decoder, double seconds) { } at::Tensor get_next_frame(at::Tensor& decoder) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getNextDecodedOutput().frame; if (result.sizes().size() != 3) { throw std::runtime_error( @@ -120,7 +129,7 @@ at::Tensor get_next_frame(at::Tensor& decoder) { } at::Tensor get_frame_at_pts(at::Tensor& decoder, double seconds) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFrameDisplayedAtTimestamp(seconds); return result.frame; } @@ -129,7 +138,7 @@ at::Tensor get_frame_at_index( at::Tensor& decoder, int64_t stream_index, int64_t frame_index) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); return result.frame; } @@ -148,7 +157,7 @@ at::Tensor get_frames_at_indices( at::Tensor& decoder, int64_t stream_index, at::IntArrayRef frame_indices) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); auto result = videoDecoder->getFramesAtIndexes(stream_index, frameIndicesVec); @@ -161,7 +170,7 @@ at::Tensor get_frames_in_range( int64_t start, int64_t stop, std::optional step = std::nullopt) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto result = videoDecoder->getFramesInRange( stream_index, start, stop, step.value_or(1)); return result.frames; @@ -195,7 +204,7 @@ std::string mapToJson(const std::map& metadataMap) { } std::string get_json_metadata(at::Tensor& decoder) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::ContainerMetadata videoMetadata = videoDecoder->getContainerMetadata(); @@ -261,7 +270,7 @@ std::string get_json_metadata(at::Tensor& decoder) { } std::string get_container_json_metadata(at::Tensor& decoder) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto containerMetadata = videoDecoder->getContainerMetadata(); @@ -292,7 +301,7 @@ std::string get_container_json_metadata(at::Tensor& decoder) { std::string get_stream_json_metadata( at::Tensor& decoder, int64_t stream_index) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); auto streams = videoDecoder->getContainerMetadata().streams; if (stream_index < 0 || stream_index >= streams.size()) { throw std::out_of_range( @@ -365,7 +374,7 @@ std::string _get_json_ffmpeg_library_versions() { } void scan_all_streams_to_update_metadata(at::Tensor& decoder) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->scanFileAndUpdateMetadataAndIndex(); }