From 3500ab400dad9bd8409be41c1c75eeaac5819ba1 Mon Sep 17 00:00:00 2001 From: Milan Zhou Date: Thu, 8 Feb 2024 14:20:25 -0800 Subject: [PATCH] [TorchAudio][stream_reader] Make StreamingMediaDecoderBytes available for C++ usage (#3742) Summary: This diff moves `StreamingMediaDecoderBytes` from the `pybind` file to the `stream_reader` files, enabling usage for all C++ use cases. We can similarly migrate the `StreamingMediaDecoderFile` and `StreamingMediaEncoderFile` in the future. Interestingly we do not have a `StreamingMediaEncoderBytes` implementation. Differential Revision: D53585768 --- src/libtorio/ffmpeg/ffmpeg.h | 6 ++ src/libtorio/ffmpeg/pybind/pybind.cpp | 56 ------------------- .../ffmpeg/stream_reader/stream_reader.cpp | 49 ++++++++++++++++ .../ffmpeg/stream_reader/stream_reader.h | 22 ++++++++ 4 files changed, 77 insertions(+), 56 deletions(-) diff --git a/src/libtorio/ffmpeg/ffmpeg.h b/src/libtorio/ffmpeg/ffmpeg.h index 1d5ed152537..bc03feb85dc 100644 --- a/src/libtorio/ffmpeg/ffmpeg.h +++ b/src/libtorio/ffmpeg/ffmpeg.h @@ -208,6 +208,12 @@ struct StreamParams { AVRational time_base{}; int stream_index{}; }; + +struct BytesWrapper { + std::string_view src; + size_t index = 0; +}; + } // namespace io } // namespace torio diff --git a/src/libtorio/ffmpeg/pybind/pybind.cpp b/src/libtorio/ffmpeg/pybind/pybind.cpp index 47738a4b4b8..6693fa308bf 100644 --- a/src/libtorio/ffmpeg/pybind/pybind.cpp +++ b/src/libtorio/ffmpeg/pybind/pybind.cpp @@ -188,62 +188,6 @@ struct StreamingMediaEncoderFileObj : private FileObj, py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {} }; -////////////////////////////////////////////////////////////////////////////// -// StreamingMediaDecoder/Encoder Bytes -////////////////////////////////////////////////////////////////////////////// -struct BytesWrapper { - std::string_view src; - size_t index = 0; -}; - -static int read_bytes(void* opaque, uint8_t* buf, int buf_size) { - BytesWrapper* wrapper = static_cast(opaque); - - auto num_read = FFMIN(wrapper->src.size() - wrapper->index, buf_size); - if (num_read == 0) { - return AVERROR_EOF; - } - auto head = wrapper->src.data() + wrapper->index; - memcpy(buf, head, num_read); - wrapper->index += num_read; - return num_read; -} - -static int64_t seek_bytes(void* opaque, int64_t offset, int whence) { - BytesWrapper* wrapper = static_cast(opaque); - if (whence == AVSEEK_SIZE) { - return wrapper->src.size(); - } - - if (whence == SEEK_SET) { - wrapper->index = offset; - } else if (whence == SEEK_CUR) { - wrapper->index += offset; - } else if (whence == SEEK_END) { - wrapper->index = wrapper->src.size() + offset; - } else { - TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence); - } - return static_cast(wrapper->index); -} - -struct StreamingMediaDecoderBytes : private BytesWrapper, - public StreamingMediaDecoderCustomIO { - StreamingMediaDecoderBytes( - std::string_view src, - const c10::optional& format, - const c10::optional>& option, - int64_t buffer_size) - : BytesWrapper{src}, - StreamingMediaDecoderCustomIO( - this, - format, - buffer_size, - read_bytes, - seek_bytes, - option) {} -}; - #ifndef TORIO_FFMPEG_EXT_NAME #error TORIO_FFMPEG_EXT_NAME must be defined. #endif diff --git a/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp b/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp index cb85df0c96d..0e96c2856df 100644 --- a/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp +++ b/src/libtorio/ffmpeg/stream_reader/stream_reader.cpp @@ -610,4 +610,53 @@ StreamingMediaDecoderCustomIO::StreamingMediaDecoderCustomIO( : CustomInput(opaque, buffer_size, read_packet, seek), StreamingMediaDecoder(io_ctx, format, option) {} +namespace { +static int read_bytes(void* opaque, uint8_t* buf, int buf_size) { + BytesWrapper* wrapper = static_cast(opaque); + + auto num_read = FFMIN(wrapper->src.size() - wrapper->index, buf_size); + if (num_read == 0) { + return AVERROR_EOF; + } + auto head = wrapper->src.data() + wrapper->index; + memcpy(buf, head, num_read); + wrapper->index += num_read; + return num_read; +} + +static int64_t seek_bytes(void* opaque, int64_t offset, int whence) { + BytesWrapper* wrapper = static_cast(opaque); + if (whence == AVSEEK_SIZE) { + return wrapper->src.size(); + } + + if (whence == SEEK_SET) { + wrapper->index = offset; + } else if (whence == SEEK_CUR) { + wrapper->index += offset; + } else if (whence == SEEK_END) { + wrapper->index = wrapper->src.size() + offset; + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence); + } + return static_cast(wrapper->index); +} +} // namespace + +////////////////////////////////////////////////////////////////////////////// +// StreamingMediaDecoder Bytes +////////////////////////////////////////////////////////////////////////////// +StreamingMediaDecoderBytes::StreamingMediaDecoderBytes( + std::string_view src, + const c10::optional& format, + const c10::optional>& option, + int64_t buffer_size) + : BytesWrapper{src}, + StreamingMediaDecoderCustomIO( + this, + format, + buffer_size, + read_bytes, + seek_bytes, + option) {} } // namespace torio::io diff --git a/src/libtorio/ffmpeg/stream_reader/stream_reader.h b/src/libtorio/ffmpeg/stream_reader/stream_reader.h index 9d910ff0015..54fb4c49ec0 100644 --- a/src/libtorio/ffmpeg/stream_reader/stream_reader.h +++ b/src/libtorio/ffmpeg/stream_reader/stream_reader.h @@ -386,6 +386,28 @@ class StreamingMediaDecoderCustomIO : private detail::CustomInput, const c10::optional& option = c10::nullopt); }; +////////////////////////////////////////////////////////////////////////////// +// StreamingMediaDecoder Bytes +////////////////////////////////////////////////////////////////////////////// +struct StreamingMediaDecoderBytes : private BytesWrapper, + public StreamingMediaDecoderCustomIO { + public: + /// + /// Construct StreamingMediaDecoder with read and seek functions that read + /// from in memory buffer + /// + /// @param src In memory bytes buffer + /// @param format Specify input format. + /// @param option Custom option passed when initializing format context. + /// @param buffer_size The size of the intermediate buffer, which FFmpeg uses + /// to pass data to function read_packet. + StreamingMediaDecoderBytes( + std::string_view src, + const c10::optional& format, + const c10::optional>& option, + int64_t buffer_size); +}; + // For BC using StreamReader = StreamingMediaDecoder; using StreamReaderCustomIO = StreamingMediaDecoderCustomIO;