Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose metadata as dataclasses - first episode #52

Closed
wants to merge 13 commits into from
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files
args: ['--maxkb=1000']

# - repo: https://github.com/omnilib/ufmt
# rev: v2.6.0
# hooks:
# - id: ufmt
# additional_dependencies:
# - black == 24.4.2
# - usort == 1.0.5
# - repo: https://github.com/omnilib/ufmt
# rev: v2.6.0
# hooks:
# - id: ufmt
# additional_dependencies:
# - black == 24.4.2
# - usort == 1.0.5

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
Expand Down
119 changes: 105 additions & 14 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ namespace facebook::torchcodec {
// ==============================
// Define the operators
// ==============================

// All instances of accepting the decoder as a tensor must be annotated with
// `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being
// mutated in place. We need it to make sure that torch.compile does not reorder
Expand All @@ -35,6 +34,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
m.def("_get_json_ffmpeg_library_versions() -> str");
}

Expand Down Expand Up @@ -159,6 +161,29 @@ std::string quoteValue(const std::string& value) {
return "\"" + value + "\"";
}

// TODO: this hand-rolled serialization should be replaced by something more
// robust. Options include a proper JSON library or packing all the info into
// tensors; ultimately we are limited to what custom ops allow us to return.
// *If* dropping the export support for metadata is acceptable, binding the C++
// structs to Python with pybind11 would also be easy.
//
// Renders `metadataMap` as a JSON object with one `"key": value` entry per
// line, in the map's iteration order. Keys are wrapped in double quotes; values
// are written verbatim, so callers must pass text that is already valid JSON
// (strings must arrive pre-quoted). An empty map produces "{\n}".
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
  std::string json = "{\n";
  // Empty before the first entry, ",\n" before every following one.
  const char* separator = "";
  for (const auto& [key, value] : metadataMap) {
    json += separator;
    json += "\"" + key + "\": " + value;
    separator = ",\n";
  }
  // The last entry (if any) is terminated by a bare newline.
  if (!metadataMap.empty()) {
    json += "\n";
  }
  json += "}";
  return json;
}

std::string get_json_metadata(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());

Expand Down Expand Up @@ -222,21 +247,85 @@ std::string get_json_metadata(at::Tensor& decoder) {
std::to_string(*videoMetadata.bestAudioStreamIndex);
}

std::stringstream ss;
ss << "{\n";
auto it = metadataMap.begin();
while (it != metadataMap.end()) {
ss << "\"" << it->first << "\": " << it->second;
++it;
if (it != metadataMap.end()) {
ss << ",\n";
} else {
ss << "\n";
}
return mapToJson(metadataMap);
}

// Serializes the container-level metadata of the decoder to a JSON object
// string (see mapToJson for the exact layout).
//
// `decoder` is treated as a handle whose data pointer is cast to a
// VideoDecoder*. The optional fields durationSeconds, bitRate,
// bestVideoStreamIndex and bestAudioStreamIndex are emitted only when they hold
// a value; numStreams is always emitted.
std::string get_container_json_metadata(at::Tensor& decoder) {
  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());

  auto containerMetadata = videoDecoder->getContainerMetadata();

  std::map<std::string, std::string> map;

  if (containerMetadata.durationSeconds.has_value()) {
    map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds);
  }
  if (containerMetadata.bitRate.has_value()) {
    map["bitRate"] = std::to_string(*containerMetadata.bitRate);
  }

  if (containerMetadata.bestVideoStreamIndex.has_value()) {
    map["bestVideoStreamIndex"] =
        std::to_string(*containerMetadata.bestVideoStreamIndex);
  }
  if (containerMetadata.bestAudioStreamIndex.has_value()) {
    map["bestAudioStreamIndex"] =
        std::to_string(*containerMetadata.bestAudioStreamIndex);
  }

  // Lets callers know how many stream indices they can query individually.
  map["numStreams"] = std::to_string(containerMetadata.streams.size());

  return mapToJson(map);
}

// Serializes the metadata of a single stream to a JSON object string (see
// mapToJson for the exact layout).
//
// `decoder` is treated as a handle whose data pointer is cast to a
// VideoDecoder*. `stream_index` selects the entry in the container's stream
// list. Every field is optional and is emitted only when it holds a value; the
// codec name is emitted under the key "codec" as a quoted string.
//
// Throws std::out_of_range when `stream_index` is negative or not smaller than
// the number of streams.
std::string get_stream_json_metadata(
    at::Tensor& decoder,
    int64_t stream_index) {
  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
  // Bind by const reference: works whether the getter returns a value (the
  // temporary's lifetime is extended) or a reference, and avoids copying the
  // stream list a second time.
  const auto& containerMetadata = videoDecoder->getContainerMetadata();
  const auto& streams = containerMetadata.streams;
  // Compare in the signed domain; the negative case is rejected first.
  if (stream_index < 0 ||
      stream_index >= static_cast<int64_t>(streams.size())) {
    throw std::out_of_range(
        "stream_index out of bounds: " + std::to_string(stream_index));
  }
  const auto& streamMetadata = streams[stream_index];

  std::map<std::string, std::string> map;

  if (streamMetadata.durationSeconds.has_value()) {
    map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds);
  }
  if (streamMetadata.bitRate.has_value()) {
    map["bitRate"] = std::to_string(*streamMetadata.bitRate);
  }
  if (streamMetadata.numFramesFromScan.has_value()) {
    map["numFramesFromScan"] =
        std::to_string(*streamMetadata.numFramesFromScan);
  }
  if (streamMetadata.numFrames.has_value()) {
    map["numFrames"] = std::to_string(*streamMetadata.numFrames);
  }
  if (streamMetadata.minPtsSecondsFromScan.has_value()) {
    map["minPtsSecondsFromScan"] =
        std::to_string(*streamMetadata.minPtsSecondsFromScan);
  }
  if (streamMetadata.maxPtsSecondsFromScan.has_value()) {
    map["maxPtsSecondsFromScan"] =
        std::to_string(*streamMetadata.maxPtsSecondsFromScan);
  }
  if (streamMetadata.codecName.has_value()) {
    // Strings must be quoted here because mapToJson writes values verbatim.
    map["codec"] = quoteValue(streamMetadata.codecName.value());
  }
  if (streamMetadata.width.has_value()) {
    map["width"] = std::to_string(*streamMetadata.width);
  }
  if (streamMetadata.height.has_value()) {
    map["height"] = std::to_string(*streamMetadata.height);
  }
  if (streamMetadata.averageFps.has_value()) {
    map["averageFps"] = std::to_string(*streamMetadata.averageFps);
  }
  return mapToJson(map);
}

std::string _get_json_ffmpeg_library_versions() {
Expand Down Expand Up @@ -277,6 +366,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("add_video_stream", &add_video_stream);
m.impl("get_next_frame", &get_next_frame);
m.impl("get_json_metadata", &get_json_metadata);
m.impl("get_container_json_metadata", &get_container_json_metadata);
m.impl("get_stream_json_metadata", &get_stream_json_metadata);
m.impl("get_frame_at_pts", &get_frame_at_pts);
m.impl("get_frame_at_index", &get_frame_at_index);
m.impl("get_frames_at_indices", &get_frames_at_indices);
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ at::Tensor get_next_frame(at::Tensor& decoder);
// Get the metadata from the video as a string.
std::string get_json_metadata(at::Tensor& decoder);

// Get the container metadata as a string.
std::string get_container_json_metadata(at::Tensor& decoder);

// Get the stream metadata as a string.
std::string get_stream_json_metadata(at::Tensor& decoder);

// Returns version information about the various FFMPEG libraries that are
// loaded in the program's address space.
std::string _get_json_ffmpeg_library_versions();
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
# TODO: Don't use import *

from .video_decoder_ops import * # noqa

from ._metadata import get_video_metadata, StreamMetadata, VideoMetadata
92 changes: 92 additions & 0 deletions src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import json

from dataclasses import dataclass
from typing import List, Optional

import torch

from torchcodec.decoders._core.video_decoder_ops import (
_get_container_json_metadata,
_get_stream_json_metadata,
)


@dataclass
class StreamMetadata:
    """Metadata describing one stream of a container.

    All fields except ``stream_index`` are typed ``Optional`` and may be
    ``None``.
    """

    duration_seconds: Optional[float]
    bit_rate: Optional[float]
    # TODO: Before release, we should come up with names that better convey the
    # " 'fast and potentially inaccurate' vs 'slower but accurate' " tradeoff.
    num_frames_retrieved: Optional[int]
    num_frames_computed: Optional[int]
    min_pts_seconds: Optional[float]
    max_pts_seconds: Optional[float]
    codec: Optional[str]
    width: Optional[int]
    height: Optional[int]
    average_fps: Optional[float]
    stream_index: int

    @property
    def num_frames(self) -> Optional[int]:
        """Frame count, preferring ``num_frames_computed`` when it is set.

        Falls back to ``num_frames_retrieved`` (which may itself be ``None``)
        when ``num_frames_computed`` is ``None``.
        """
        computed = self.num_frames_computed
        return self.num_frames_retrieved if computed is None else computed


# This may be renamed into e.g. ContainerMetadata in the future to be more generic.
@dataclass
class VideoMetadata:
    """Container-level metadata together with the metadata of each stream."""

    duration_seconds_container: Optional[float]
    bit_rate_container: Optional[float]
    best_video_stream_index: Optional[int]
    best_audio_stream_index: Optional[int]

    streams: List[StreamMetadata]

    @property
    def duration_seconds(self) -> Optional[float]:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("TODO: decide on logic and implement this!")

    @property
    def bit_rate(self) -> Optional[float]:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("TODO: decide on logic and implement this!")

    @property
    def best_video_stream(self) -> StreamMetadata:
        """Entry of ``streams`` at ``best_video_stream_index``.

        Raises:
            ValueError: if ``best_video_stream_index`` is ``None``.
        """
        index = self.best_video_stream_index
        if index is not None:
            return self.streams[index]
        raise ValueError("The best video stream is unknown.")


def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
    """Build a ``VideoMetadata`` from the JSON metadata exposed by the core ops.

    The container metadata is fetched first; its ``numStreams`` entry drives
    one per-stream query for every stream index in ``range(numStreams)``.
    Keys absent from the JSON payloads become ``None`` in the dataclasses.

    Args:
        decoder: decoder handle, forwarded as-is to
            ``_get_container_json_metadata`` and ``_get_stream_json_metadata``.

    Returns:
        A ``VideoMetadata`` whose ``streams`` list is ordered by stream index.
    """
    container_dict = json.loads(_get_container_json_metadata(decoder))
    streams_metadata = []
    for stream_index in range(container_dict["numStreams"]):
        stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
        streams_metadata.append(
            StreamMetadata(
                duration_seconds=stream_dict.get("durationSeconds"),
                bit_rate=stream_dict.get("bitRate"),
                # TODO: We should align the C++ names and the json keys with the Python names
                num_frames_retrieved=stream_dict.get("numFrames"),
                num_frames_computed=stream_dict.get("numFramesFromScan"),
                min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"),
                max_pts_seconds=stream_dict.get("maxPtsSecondsFromScan"),
                codec=stream_dict.get("codec"),
                width=stream_dict.get("width"),
                height=stream_dict.get("height"),
                average_fps=stream_dict.get("averageFps"),
                stream_index=stream_index,
            )
        )

    return VideoMetadata(
        duration_seconds_container=container_dict.get("durationSeconds"),
        bit_rate_container=container_dict.get("bitRate"),
        best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
        best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),
        streams=streams_metadata,
    )
14 changes: 14 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def load_torchcodec_extension():
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
_get_container_json_metadata = (
torch.ops.torchcodec_ns.get_container_json_metadata.default
)
_get_stream_json_metadata = torch.ops.torchcodec_ns.get_stream_json_metadata.default
_get_json_ffmpeg_library_versions = (
torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default
)
Expand Down Expand Up @@ -154,6 +158,16 @@ def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
return torch.empty_like("")


@register_fake("torchcodec_ns::get_container_json_metadata")
def get_container_json_metadata_abstract(decoder: torch.Tensor) -> str:
    """Fake implementation of the ``get_container_json_metadata`` op.

    The op's schema returns ``str``, so a placeholder empty string is
    returned. ``torch.empty_like`` only accepts tensors and would raise on a
    string argument.
    """
    return ""


@register_fake("torchcodec_ns::get_stream_json_metadata")
def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> str:
    """Fake implementation of the ``get_stream_json_metadata`` op.

    The op's schema returns ``str``, so a placeholder empty string is
    returned. ``torch.empty_like`` only accepts tensors and would raise on a
    string argument.
    """
    return ""


@register_fake("torchcodec_ns::_get_json_ffmpeg_library_versions")
def _get_json_ffmpeg_library_versions_abstract() -> str:
    """Fake implementation of the ``_get_json_ffmpeg_library_versions`` op.

    The op's schema returns ``str``, so a placeholder empty string is
    returned. ``torch.empty_like`` only accepts tensors and would raise on a
    string argument.
    """
    return ""
Expand Down
32 changes: 25 additions & 7 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Union

import torch

from torchcodec.decoders import _core as core


Expand All @@ -24,12 +24,9 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]):

core.add_video_stream(self._decoder)

# TODO: We should either implement specific core library function to
# retrieve these values, or replace this with a non-JSON metadata
# retrieval.
metadata_json = json.loads(core.get_json_metadata(self._decoder))
self._num_frames = metadata_json["numFrames"]
self._stream_index = metadata_json["bestVideoStreamIndex"]
self.stream_metadata = _get_and_validate_stream_metadata(self._decoder)
self._num_frames = self.stream_metadata.num_frames_computed
self._stream_index = self.stream_metadata.stream_index

def __len__(self) -> int:
    """Return the frame count cached in ``self._num_frames``."""
    return self._num_frames
Expand Down Expand Up @@ -80,3 +77,24 @@ def __next__(self) -> torch.Tensor:
return core.get_next_frame(self._decoder)
except RuntimeError:
raise StopIteration()


def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata:
    """Return the metadata of the best video stream after sanity-checking it.

    Args:
        decoder: decoder handle, forwarded to ``core.get_video_metadata``.

    Returns:
        The ``StreamMetadata`` stored at the container's
        ``best_video_stream_index``.

    Raises:
        ValueError: if the best video stream index is ``None``, or if the
            selected stream has no ``num_frames_computed`` value.
    """
    metadata = core.get_video_metadata(decoder)

    best_index = metadata.best_video_stream_index
    if best_index is None:
        raise ValueError(
            "The best video stream is unknown. This should never happen. "
            "Please report an issue following the steps in <TODO>"
        )

    stream = metadata.streams[best_index]
    if stream.num_frames_computed is None:
        raise ValueError(
            "The number of frames is unknown. This should never happen. "
            "Please report an issue following the steps in <TODO>"
        )

    return stream
Loading
Loading