Skip to content

Commit

Permalink
[TorchAudio][stream_reader] Make StreamingMediaDecoderBytes available…
Browse files Browse the repository at this point in the history
… for C++ usage (pytorch#3742)

Summary:

This diff moves `StreamingMediaDecoderBytes` from the `pybind` file to the `stream_reader` files, enabling usage for all C++ use cases.

We can similarly migrate `StreamingMediaDecoderFile` and `StreamingMediaEncoderFile` in the future. Interestingly, there is currently no `StreamingMediaEncoderBytes` implementation.

Differential Revision: D53585768
  • Loading branch information
Milan Zhou authored and facebook-github-bot committed Feb 8, 2024
1 parent 02586da commit 3500ab4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 56 deletions.
6 changes: 6 additions & 0 deletions src/libtorio/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ struct StreamParams {
AVRational time_base{};
int stream_index{};
};

/// Cursor over an in-memory byte buffer.
///
/// `src` is a non-owning view, so the bytes it refers to must stay alive for
/// as long as the wrapper is in use.
struct BytesWrapper {
  /// The bytes being consumed (not owned).
  std::string_view src;
  /// Offset into `src` of the next byte to read; starts at the beginning.
  size_t index{0};
};

} // namespace io
} // namespace torio

Expand Down
56 changes: 0 additions & 56 deletions src/libtorio/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,62 +188,6 @@ struct StreamingMediaEncoderFileObj : private FileObj,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
};

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder/Encoder Bytes
//////////////////////////////////////////////////////////////////////////////
/// Cursor over an in-memory byte buffer.
///
/// `src` is a non-owning view, so the bytes it refers to must stay alive for
/// as long as the wrapper is in use.
struct BytesWrapper {
  /// The bytes being consumed (not owned).
  std::string_view src;
  /// Offset into `src` of the next byte to read; starts at the beginning.
  size_t index{0};
};

/// Read callback for FFmpeg custom I/O backed by an in-memory buffer.
///
/// Copies up to `buf_size` bytes from the current position of the
/// `BytesWrapper` pointed to by `opaque` into `buf` and advances the position.
///
/// @param opaque Pointer to the `BytesWrapper` being read.
/// @param buf Destination buffer.
/// @param buf_size Capacity of `buf` in bytes.
/// @return The number of bytes copied, or `AVERROR_EOF` when nothing can be
/// read (cursor at or past the end of the data, or `buf_size` not positive).
static int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);

  // The cursor may sit past the end of the data after a seek. Check this
  // before subtracting: `size() - index` is unsigned and would wrap around to
  // a huge value, turning the memcpy below into an out-of-bounds read.
  if (buf_size <= 0 || wrapper->index >= wrapper->src.size()) {
    return AVERROR_EOF;
  }

  const size_t remaining = wrapper->src.size() - wrapper->index;
  const size_t capacity = static_cast<size_t>(buf_size);
  const size_t num_read = remaining < capacity ? remaining : capacity;

  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so the narrowing conversion is lossless.
  return static_cast<int>(num_read);
}

/// Seek callback for FFmpeg custom I/O backed by an in-memory buffer.
///
/// @param opaque Pointer to the `BytesWrapper` being read.
/// @param offset Signed displacement relative to the origin chosen by
/// `whence`.
/// @param whence `SEEK_SET`, `SEEK_CUR`, `SEEK_END`, or `AVSEEK_SIZE` (which
/// only queries the total size and does not move the cursor).
/// @return The total size for `AVSEEK_SIZE`; otherwise the new absolute
/// position, or a negative value if the requested position is before the start
/// of the buffer or not representable. On failure the cursor is left
/// unchanged. Positions past the end of the data are accepted.
/// Any other `whence` value trips `TORCH_INTERNAL_ASSERT`.
static int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const int64_t size = static_cast<int64_t>(wrapper->src.size());
  if (whence == AVSEEK_SIZE) {
    return size;
  }

  int64_t base = 0;
  if (whence == SEEK_SET) {
    base = 0;
  } else if (whence == SEEK_CUR) {
    base = static_cast<int64_t>(wrapper->index);
  } else if (whence == SEEK_END) {
    base = size;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }

  // Do the arithmetic in the signed domain and validate before committing.
  // Adding a negative offset directly to the unsigned cursor would wrap it
  // around to a huge index and leave the wrapper in a corrupted state.
  if (base < 0 || offset < -base || offset > INT64_MAX - base) {
    return -1; // Negative return value signals failure to the caller.
  }

  const int64_t target = base + offset;
  wrapper->index = static_cast<size_t>(target);
  return target;
}

/// StreamingMediaDecoder that reads its input from an in-memory byte buffer.
///
/// The private `BytesWrapper` base holds the view of the data and the read
/// cursor; it is listed first so that it is initialized before
/// `StreamingMediaDecoderCustomIO`, which receives a pointer to it.
/// `src` is a non-owning view: the underlying bytes must outlive this object.
struct StreamingMediaDecoderBytes : private BytesWrapper,
                                    public StreamingMediaDecoderCustomIO {
  /// @param src In-memory bytes to decode (not copied).
  /// @param format Optional input format.
  /// @param option Optional key/value options forwarded to the base class.
  /// @param buffer_size Size of the intermediate buffer, forwarded to the
  /// base class.
  StreamingMediaDecoderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size)
      : BytesWrapper{src},
        StreamingMediaDecoderCustomIO(
            // read_bytes/seek_bytes cast the opaque pointer straight back to
            // BytesWrapper*, so hand over the address of that base subobject
            // explicitly instead of relying on it sharing `this`'s address.
            static_cast<BytesWrapper*>(this),
            format,
            buffer_size,
            read_bytes,
            seek_bytes,
            option) {}
};

#ifndef TORIO_FFMPEG_EXT_NAME
#error TORIO_FFMPEG_EXT_NAME must be defined.
#endif
Expand Down
49 changes: 49 additions & 0 deletions src/libtorio/ffmpeg/stream_reader/stream_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,53 @@ StreamingMediaDecoderCustomIO::StreamingMediaDecoderCustomIO(
: CustomInput(opaque, buffer_size, read_packet, seek),
StreamingMediaDecoder(io_ctx, format, option) {}

namespace {
/// Read callback for FFmpeg custom I/O backed by an in-memory buffer.
///
/// Copies up to `buf_size` bytes from the current position of the
/// `BytesWrapper` pointed to by `opaque` into `buf` and advances the position.
///
/// @param opaque Pointer to the `BytesWrapper` being read.
/// @param buf Destination buffer.
/// @param buf_size Capacity of `buf` in bytes.
/// @return The number of bytes copied, or `AVERROR_EOF` when nothing can be
/// read (cursor at or past the end of the data, or `buf_size` not positive).
static int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);

  // The cursor may sit past the end of the data after a seek. Check this
  // before subtracting: `size() - index` is unsigned and would wrap around to
  // a huge value, turning the memcpy below into an out-of-bounds read.
  if (buf_size <= 0 || wrapper->index >= wrapper->src.size()) {
    return AVERROR_EOF;
  }

  const size_t remaining = wrapper->src.size() - wrapper->index;
  const size_t capacity = static_cast<size_t>(buf_size);
  const size_t num_read = remaining < capacity ? remaining : capacity;

  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so the narrowing conversion is lossless.
  return static_cast<int>(num_read);
}

/// Seek callback for FFmpeg custom I/O backed by an in-memory buffer.
///
/// @param opaque Pointer to the `BytesWrapper` being read.
/// @param offset Signed displacement relative to the origin chosen by
/// `whence`.
/// @param whence `SEEK_SET`, `SEEK_CUR`, `SEEK_END`, or `AVSEEK_SIZE` (which
/// only queries the total size and does not move the cursor).
/// @return The total size for `AVSEEK_SIZE`; otherwise the new absolute
/// position, or a negative value if the requested position is before the start
/// of the buffer or not representable. On failure the cursor is left
/// unchanged. Positions past the end of the data are accepted.
/// Any other `whence` value trips `TORCH_INTERNAL_ASSERT`.
static int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const int64_t size = static_cast<int64_t>(wrapper->src.size());
  if (whence == AVSEEK_SIZE) {
    return size;
  }

  int64_t base = 0;
  if (whence == SEEK_SET) {
    base = 0;
  } else if (whence == SEEK_CUR) {
    base = static_cast<int64_t>(wrapper->index);
  } else if (whence == SEEK_END) {
    base = size;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }

  // Do the arithmetic in the signed domain and validate before committing.
  // Adding a negative offset directly to the unsigned cursor would wrap it
  // around to a huge index and leave the wrapper in a corrupted state.
  if (base < 0 || offset < -base || offset > INT64_MAX - base) {
    return -1; // Negative return value signals failure to the caller.
  }

  const int64_t target = base + offset;
  wrapper->index = static_cast<size_t>(target);
  return target;
}
} // namespace

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder Bytes
//////////////////////////////////////////////////////////////////////////////
// Builds a decoder that reads from the in-memory bytes viewed by `src`.
// The view is stored in the BytesWrapper base (not copied), so the underlying
// bytes must outlive this object. `format`, `buffer_size` and `option` are
// forwarded unchanged to StreamingMediaDecoderCustomIO.
StreamingMediaDecoderBytes::StreamingMediaDecoderBytes(
    std::string_view src,
    const c10::optional<std::string>& format,
    const c10::optional<std::map<std::string, std::string>>& option,
    int64_t buffer_size)
    : BytesWrapper{src},
      StreamingMediaDecoderCustomIO(
          // read_bytes/seek_bytes cast the opaque pointer straight back to
          // BytesWrapper*, so hand over the address of that base subobject
          // explicitly instead of relying on it sharing `this`'s address.
          static_cast<BytesWrapper*>(this),
          format,
          buffer_size,
          read_bytes,
          seek_bytes,
          option) {}
} // namespace torio::io
22 changes: 22 additions & 0 deletions src/libtorio/ffmpeg/stream_reader/stream_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,28 @@ class StreamingMediaDecoderCustomIO : private detail::CustomInput,
const c10::optional<OptionDict>& option = c10::nullopt);
};

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder Bytes
//////////////////////////////////////////////////////////////////////////////
struct StreamingMediaDecoderBytes : private BytesWrapper,
                                    public StreamingMediaDecoderCustomIO {
  ///
  /// Construct a StreamingMediaDecoder whose read and seek callbacks operate
  /// on a buffer held in memory.
  ///
  /// @param src View of the in-memory bytes to decode. The view is not
  /// copied, so the underlying bytes must outlive this object.
  /// @param format Specify input format.
  /// @param option Custom option passed when initializing format context.
  /// @param buffer_size The size of the intermediate buffer, which FFmpeg uses
  /// to pass data to function read_packet.
  StreamingMediaDecoderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size);
};

// For BC (backward compatibility): keep the StreamReader names usable as
// aliases of the corresponding StreamingMediaDecoder types.
using StreamReader = StreamingMediaDecoder;
using StreamReaderCustomIO = StreamingMediaDecoderCustomIO;
Expand Down

0 comments on commit 3500ab4

Please sign in to comment.