From fd9bc4026dc3701553f2f174a36babb613226061 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jun 2024 16:06:39 +0100 Subject: [PATCH 01/11] WIP --- .pre-commit-config.yaml | 14 +- src/torchcodec/decoders/_core/VideoDecoder.h | 5 + .../decoders/_core/VideoDecoderOps.cpp | 122 ++++++++++++++---- .../decoders/_core/VideoDecoderOps.h | 6 + src/torchcodec/decoders/_core/__init__.py | 7 + src/torchcodec/decoders/_core/_metadata.py | 108 ++++++++++++++++ .../decoders/_core/video_decoder_ops.py | 14 ++ .../decoders/_simple_video_decoder.py | 54 ++++++-- 8 files changed, 293 insertions(+), 37 deletions(-) create mode 100644 src/torchcodec/decoders/_core/_metadata.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 100eea43..d381a5a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,13 +14,13 @@ repos: - id: check-added-large-files args: ['--maxkb=1000'] - # - repo: https://github.com/omnilib/ufmt - # rev: v2.6.0 - # hooks: - # - id: ufmt - # additional_dependencies: - # - black == 24.4.2 - # - usort == 1.0.5 + - repo: https://github.com/omnilib/ufmt + rev: v2.6.0 + hooks: + - id: ufmt + additional_dependencies: + - black == 24.4.2 + - usort == 1.0.5 - repo: https://github.com/PyCQA/flake8 rev: 7.1.0 diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 39cd93b7..8c0033af 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -101,6 +101,11 @@ class VideoDecoder { std::optional height; }; struct ContainerMetadata { + // TODO: in C++ the StreamMetadata vec is part of the ContainerMetadata. In + // Python, the equivalent list isn't part of the containers' metadata: it is + // a separate attribute of the VideoMetaData dataclass, next to the + // container metadata. We can probably align the C++ structure to reflect + // the Python one? std::vector streams; int numAudioStreams = 0; int numVideoStreams = 0; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 013d672f..1b1cd508 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -11,15 +11,6 @@ namespace facebook::torchcodec { // ============================== // Define the operators // ============================== - -torch::Tensor plus_one(torch::Tensor t) { - return t + 1; -} - -TORCH_LIBRARY(plusoneops, m) { - m.def("plus_one", plus_one); -} - // All instances of accepting the decoder as a tensor must be annotated with // `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being // mutated in place. We need it to make sure that torch.compile does not reorder @@ -41,6 +32,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); + m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); + m.def("get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"); } // ============================== @@ -152,6 +145,25 @@ std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } +std::string mapToJson(const std::map& metadataMap) { + + std::stringstream ss; + ss << "{\n"; + auto it = metadataMap.begin(); + while (it != metadataMap.end()) { + ss << "\"" << it->first << "\": " << it->second; + ++it; + if (it != metadataMap.end()) { + ss << ",\n"; + } else { + ss << "\n"; + } + } + ss << "}"; + + return ss.str(); +} + std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); @@ -219,23 +231,87 @@ std::string get_json_metadata(at::Tensor& decoder) { std::to_string(*videoMetadata.bestAudioStreamIndex); } - std::stringstream ss; - ss << "{\n"; - auto it = metadataMap.begin(); - while (it != metadataMap.end()) { - ss << "\"" << it->first << "\": " << it->second; - ++it; - if (it != metadataMap.end()) { - ss << ",\n"; - } else { - ss << "\n"; - } + return mapToJson(metadataMap); +} + +std::string get_container_json_metadata(at::Tensor &decoder) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + + auto containerMetadata = videoDecoder->getContainerMetadata(); + + std::map map; + + if (containerMetadata.durationSeconds.has_value()) { + map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds); } - ss << "}"; - return ss.str(); + if (containerMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*containerMetadata.bitRate); + } + + if (containerMetadata.bestVideoStreamIndex.has_value()) { + map["bestVideoStreamIndex"] = + std::to_string(*containerMetadata.bestVideoStreamIndex); + } + if (containerMetadata.bestAudioStreamIndex.has_value()) { + map["bestAudioStreamIndex"] = + std::to_string(*containerMetadata.bestAudioStreamIndex); + } + + // TODO: Q from Nicolas - is there a better way to retrieve and propagate the + // number of streams? + map["numStreams"] = std::to_string(containerMetadata.streams.size()); + + return mapToJson(map); +} + + +std::string get_stream_json_metadata(at::Tensor &decoder, + int64_t stream_index) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto streamMetadata = + videoDecoder->getContainerMetadata().streams[stream_index]; + + std::map map; + + if (streamMetadata.durationSeconds.has_value()) { + map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds); + } + if (streamMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*streamMetadata.bitRate); + } + if (streamMetadata.numFramesFromScan.has_value()) { + map["numFramesFromScan"] = + std::to_string(*streamMetadata.numFramesFromScan); + } + if (streamMetadata.numFrames.has_value()) { + map["numFrames"] = std::to_string(*streamMetadata.numFrames); + } + if (streamMetadata.minPtsSecondsFromScan.has_value()) { + map["minPtsSecondsFromScan"] = + std::to_string(*streamMetadata.minPtsSecondsFromScan); + } + if (streamMetadata.maxPtsSecondsFromScan.has_value()) { + map["maxPtsSecondsFromScan"] = + std::to_string(*streamMetadata.maxPtsSecondsFromScan); + } + if (streamMetadata.codecName.has_value()) { + map["codec"] = quoteValue(streamMetadata.codecName.value()); + } + if (streamMetadata.width.has_value()) { + map["width"] = std::to_string(*streamMetadata.width); + } + if (streamMetadata.height.has_value()) { + map["height"] = std::to_string(*streamMetadata.height); + } + if (streamMetadata.averageFps.has_value()) { + map["averageFps"] = std::to_string(*streamMetadata.averageFps); + } + return mapToJson(map); } + + TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); @@ -246,6 +322,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("add_video_stream", &add_video_stream); m.impl("get_next_frame", &get_next_frame); m.impl("get_json_metadata", &get_json_metadata); + m.impl("get_container_json_metadata", &get_container_json_metadata); + m.impl("get_stream_json_metadata", &get_stream_json_metadata); m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index c029473b..ace029a6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -60,4 +60,10 @@ at::Tensor get_next_frame(at::Tensor& decoder); // Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder); +// Get the container metadata as a string. +std::string get_container_json_metadata(at::Tensor& decoder); + +// Get the stream metadata as a string. +std::string get_stream_json_metadata(at::Tensor& decoder); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 84add756..eb770d7c 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -3,3 +3,10 @@ # TODO: Don't use import * from .video_decoder_ops import * # noqa + +from ._metadata import ( + ContainerMetadata, + get_video_metadata, + StreamMetadata, + VideoMetadata, +) diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py new file mode 100644 index 00000000..d7e90c42 --- /dev/null +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -0,0 +1,108 @@ +import json + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from torchcodec.decoders._core.video_decoder_ops import ( + _get_container_json_metadata, + _get_stream_json_metadata, +) + + +@dataclass +class ContainerMetadata: + duration_seconds: Optional[float] + bit_rate: Optional[float] + best_video_stream_index: Optional[int] + best_audio_stream_index: Optional[int] + + +@dataclass +class StreamMetadata: + duration_seconds: Optional[float] + bit_rate: Optional[float] + # TODO Comment from Nicolas: + # Looking at this, it's not immediately obvious to me that "retrieved" means + # "less accurate than 'computed'". + # Are we open to different names? E.g. "num_frames_from_header" and "num_frames_accurate"? + num_frames_retrieved: Optional[int] + num_frames_computed: Optional[int] + min_pts_seconds: Optional[float] + max_pts_seconds: Optional[float] + codec: Optional[str] + width: Optional[int] + height: Optional[int] + average_fps: Optional[float] + + @property + def num_frames(self) -> Optional[int]: + if self.num_frames_computed is not None: + return self.num_frames_computed + else: + return self.num_frames_retrieved + + +@dataclass +class VideoMetadata: + container: ContainerMetadata + streams: List[StreamMetadata] + + @property + def duration_seconds(self) -> Optional[float]: + if ( + self.container.best_video_stream_index is not None + and self.streams[self.container.best_video_stream_index].duration_seconds + is not None + ): + return self.streams[self.container.best_video_stream_index].duration_seconds + else: + return self.container.duration_seconds + + @property + def bit_rate(self) -> Optional[float]: + if ( + self.container.best_video_stream_index is not None + and self.streams[self.container.best_video_stream_index].bit_rate + is not None + ): + return self.streams[self.container.best_video_stream_index].bit_rate + else: + return self.contain.bit_rate + + @property + def best_video_stream(self) -> StreamMetadata: + assert self.container.best_video_stream_index is not None + return self.container.streams[self.container.best_video_stream_index] + + +def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: + + container_dict = json.loads(_get_container_json_metadata(decoder)) + container_metadata = ContainerMetadata( + duration_seconds=container_dict.get("durationSeconds"), + bit_rate=container_dict.get("bitRate"), + best_video_stream_index=container_dict.get("bestVideoStreamIndex"), + best_audio_stream_index=container_dict.get("bestAudioStreamIndex"), + ) + + streams_metadata = [] + for stream_index in range(container_dict["numStreams"]): + stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) + streams_metadata.append( + StreamMetadata( + duration_seconds=stream_dict.get("durationSeconds"), + bit_rate=stream_dict.get("bitRate"), + num_frames_retrieved=stream_dict.get("numFrames"), + num_frames_computed=stream_dict.get("numFramesFromScan"), + min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"), + max_pts_seconds=stream_dict.get("maxPtsSecondsFromScan"), + codec=stream_dict.get("codec"), + width=stream_dict.get("width"), + height=stream_dict.get("height"), + average_fps=stream_dict.get("averageFps"), + ) + ) + + return VideoMetadata(container=container_metadata, streams=streams_metadata) diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 3aba20e5..f0556fbe 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -58,6 +58,10 @@ def load_torchcodec_extension(): get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default +_get_container_json_metadata = ( + torch.ops.torchcodec_ns.get_container_json_metadata.default +) +_get_stream_json_metadata = torch.ops.torchcodec_ns.get_stream_json_metadata.default # ============================= @@ -134,3 +138,13 @@ def get_frames_at_indices_abstract( @register_fake("torchcodec_ns::get_json_metadata") def get_json_metadata_abstract(decoder: torch.Tensor) -> str: return torch.empty_like("") + + +@register_fake("torchcodec_ns::get_container_json_metadata") +def get_container_json_metadata_abstract(decoder: torch.Tensor) -> str: + return torch.empty_like("") + + +@register_fake("torchcodec_ns::get_stream_json_metadata") +def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> str: + return torch.empty_like("") diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 9dac8612..895816ea 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -1,7 +1,8 @@ -import json -from typing import Union +from dataclasses import dataclass +from typing import Optional, Union import torch + from torchcodec.decoders import _core as core @@ -24,12 +25,10 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]): core.add_video_stream(self._decoder) - # TODO: We should either implement specific core library function to - # retrieve these values, or replace this with a non-JSON metadata - # retrieval. - metadata_json = json.loads(core.get_json_metadata(self._decoder)) - self._num_frames = metadata_json["numFrames"] - self._stream_index = metadata_json["bestVideoStreamIndex"] + self.metadata = _get_and_validate_simple_video_metadata(self._decoder) + # Note: these fields exist and are not None, as validated in _get_and_validate_simple_video_metadata(). + self._num_frames = self.metadata.stream.num_frames_computed + self._stream_index = self.metadata.container.best_video_stream_index def __len__(self) -> int: return self._num_frames @@ -61,3 +60,42 @@ def __next__(self) -> torch.Tensor: return core.get_next_frame(self._decoder) except RuntimeError: raise StopIteration() + + +@dataclass +class SimpleVideoMetadata: + # TODO: ContainerMetadata and StreamMetadata should be publicly available. + # Right now they're only exposed in _core. + container: core.ContainerMetadata + stream: core.StreamMetadata + + # TODO: is the return really supposed to be Optional + @property + def duration_seconds(self) -> Optional[float]: + return self.stream.duration_seconds + + @property + def bit_rate(self) -> Optional[float]: + return self.stream.bit_rate + + +def _get_and_validate_simple_video_metadata( + decoder: torch.Tensor, +) -> SimpleVideoMetadata: + video_metadata = core.get_video_metadata(decoder) + container_metadata = video_metadata.container + + if container_metadata.best_video_stream_index is None: + raise ValueError( + "The best video stream is unknown. This should never happen. " + "Please report an issue following the steps on " + ) + + stream_metadata = video_metadata.streams[container_metadata.best_video_stream_index] + if stream_metadata.num_frames_computed is None: + raise ValueError( + "The number of frames is unknown. This should never happen. " + "Please report an issue following the steps on " + ) + + return SimpleVideoMetadata(container=container_metadata, stream=stream_metadata) From 9c9ff07a9b99abfc66503241cdca5a313037d33a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jun 2024 16:09:58 +0100 Subject: [PATCH 02/11] more stuff --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 3 --- src/torchcodec/decoders/_simple_video_decoder.py | 5 ++--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 1b1cd508..5562e17c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -265,7 +265,6 @@ std::string get_container_json_metadata(at::Tensor &decoder) { return mapToJson(map); } - std::string get_stream_json_metadata(at::Tensor &decoder, int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); @@ -310,8 +309,6 @@ std::string get_stream_json_metadata(at::Tensor &decoder, return mapToJson(map); } - - TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 895816ea..1cd6f3a1 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -69,7 +69,6 @@ class SimpleVideoMetadata: container: core.ContainerMetadata stream: core.StreamMetadata - # TODO: is the return really supposed to be Optional @property def duration_seconds(self) -> Optional[float]: return self.stream.duration_seconds @@ -88,14 +87,14 @@ def _get_and_validate_simple_video_metadata( if container_metadata.best_video_stream_index is None: raise ValueError( "The best video stream is unknown. This should never happen. " - "Please report an issue following the steps on " + "Please report an issue following the steps in " ) stream_metadata = video_metadata.streams[container_metadata.best_video_stream_index] if stream_metadata.num_frames_computed is None: raise ValueError( "The number of frames is unknown. This should never happen. " - "Please report an issue following the steps on " + "Please report an issue following the steps in " ) return SimpleVideoMetadata(container=container_metadata, stream=stream_metadata) From f1013d51c416e371746ab504e333556d4564ca95 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jun 2024 16:16:17 +0100 Subject: [PATCH 03/11] TODO --- src/torchcodec/decoders/_core/_metadata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index d7e90c42..fa71a9f1 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -94,6 +94,7 @@ def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: StreamMetadata( duration_seconds=stream_dict.get("durationSeconds"), bit_rate=stream_dict.get("bitRate"), + # TODO: We should align the C++ names and the json keys with the Python names num_frames_retrieved=stream_dict.get("numFrames"), num_frames_computed=stream_dict.get("numFramesFromScan"), min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"), From 5deb9870cff0725b53e0a0e8cd6ef7c7af2828c0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jun 2024 16:40:57 +0100 Subject: [PATCH 04/11] Add basic test --- .../decoders/_simple_video_decoder.py | 4 +-- test/decoders/test_metadata.py | 27 +++++++++++++++++++ test/decoders/video_decoder_ops_test.py | 7 +++++ 3 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 test/decoders/test_metadata.py diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 1cd6f3a1..37e0d3a6 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -71,11 +71,11 @@ class SimpleVideoMetadata: @property def duration_seconds(self) -> Optional[float]: - return self.stream.duration_seconds + return self.stream.duration_seconds or self.container.duration_seconds @property def bit_rate(self) -> Optional[float]: - return self.stream.bit_rate + return self.stream.bit_rate or self.container.bit_rate def _get_and_validate_simple_video_metadata( diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py new file mode 100644 index 00000000..61f31b67 --- /dev/null +++ b/test/decoders/test_metadata.py @@ -0,0 +1,27 @@ +import pytest + +from torchcodec.decoders._core import ( + create_from_file, + get_video_metadata, +) + +from ..test_utils import get_reference_video_path + + +def test_get_video_metadata(): + decoder = create_from_file(str(get_reference_video_path())) + metadata = get_video_metadata(decoder) + assert len(metadata.streams) == 6 + assert metadata.container.best_video_stream_index == 3 + assert metadata.container.best_audio_stream_index == 3 + + assert metadata.container.duration_seconds == pytest.approx(16.57, abs=0.001) + assert metadata.container.bit_rate == 324915 + + best_stream_metadata = metadata.streams[metadata.container.best_video_stream_index] + assert best_stream_metadata.duration_seconds == pytest.approx(13.013, abs=0.001) + assert best_stream_metadata.bit_rate == 128783 + assert best_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) + assert best_stream_metadata.codec == "h264" + assert best_stream_metadata.num_frames_computed == 390 + assert best_stream_metadata.num_frames_retrieved == 390 diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py index 06c3e2eb..823dd124 100644 --- a/test/decoders/video_decoder_ops_test.py +++ b/test/decoders/video_decoder_ops_test.py @@ -186,6 +186,13 @@ def test_create_decoder(self, create_from): reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") assert_equal(frame_time6, reference_frame_time6) + # TODO: Keeping the metadata tests below for now, but we should remove them + # once we remove get_json_metadata(). + # Note that the distinction made between test_video_get_json_metadata and + # test_video_get_json_metadata_with_stream is misleading: all of the stream + # metadata are available even without adding a video stream, because we + # always call scanFileAndUpdateMetadataAndIndex() when creating a decoder + # from the core API. def test_video_get_json_metadata(self): decoder = create_from_file(str(get_reference_video_path())) metadata = get_json_metadata(decoder) From 8119e51b3c65f02cd0c7b854775dfdf975f8c925 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 25 Jun 2024 10:17:03 +0100 Subject: [PATCH 05/11] Comments, linter --- .../decoders/_core/VideoDecoderOps.cpp | 20 ++++++++++++------- test/decoders/test_metadata.py | 5 +---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 5562e17c..b667b4fa 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -33,7 +33,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); - m.def("get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"); + m.def( + "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"); } // ============================== @@ -145,8 +146,12 @@ std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } +// TODO: we should use a more robust way to serialize the metadata. There are a +// few alternatives, but ultimately we are limited to what custom ops allow us +// to return. Current ideas are to use a proper JSON library, or to pack all the +// info into tensors. *If* we're OK to drop the export support for metadata, we +// could also easily bind the C++ structs to Python with pybind11. std::string mapToJson(const std::map& metadataMap) { - std::stringstream ss; ss << "{\n"; auto it = metadataMap.begin(); @@ -234,8 +239,8 @@ std::string get_json_metadata(at::Tensor& decoder) { return mapToJson(metadataMap); } -std::string get_container_json_metadata(at::Tensor &decoder) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); +std::string get_container_json_metadata(at::Tensor& decoder) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); auto containerMetadata = videoDecoder->getContainerMetadata(); @@ -265,9 +270,10 @@ std::string get_container_json_metadata(at::Tensor &decoder) { return mapToJson(map); } -std::string get_stream_json_metadata(at::Tensor &decoder, - int64_t stream_index) { - auto videoDecoder = static_cast(decoder.mutable_data_ptr()); +std::string get_stream_json_metadata( + at::Tensor& decoder, + int64_t stream_index) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); auto streamMetadata = videoDecoder->getContainerMetadata().streams[stream_index]; diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 61f31b67..494d646b 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -1,9 +1,6 @@ import pytest -from torchcodec.decoders._core import ( - create_from_file, - get_video_metadata, -) +from torchcodec.decoders._core import create_from_file, get_video_metadata from ..test_utils import get_reference_video_path From bc235eddd4ccf40a58e3b7d6023a7ef2eab330b6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 25 Jun 2024 17:57:53 +0100 Subject: [PATCH 06/11] Collapse ContainerMetadata into VideoMetadata and only use StreamMetadata within SimpleVideoDecoder --- src/torchcodec/decoders/_core/VideoDecoder.h | 5 -- src/torchcodec/decoders/_core/__init__.py | 7 +-- src/torchcodec/decoders/_core/_metadata.py | 50 +++++++++---------- .../decoders/_simple_video_decoder.py | 34 +++---------- test/decoders/test_metadata.py | 10 ++-- 5 files changed, 37 insertions(+), 69 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8c0033af..39cd93b7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -101,11 +101,6 @@ class VideoDecoder { std::optional height; }; struct ContainerMetadata { - // TODO: in C++ the StreamMetadata vec is part of the ContainerMetadata. In - // Python, the equivalent list isn't part of the containers' metadata: it is - // a separate attribute of the VideoMetaData dataclass, next to the - // container metadata. We can probably align the C++ structure to reflect - // the Python one? std::vector streams; int numAudioStreams = 0; int numVideoStreams = 0; diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index eb770d7c..56560208 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -4,9 +4,4 @@ from .video_decoder_ops import * # noqa -from ._metadata import ( - ContainerMetadata, - get_video_metadata, - StreamMetadata, - VideoMetadata, -) +from ._metadata import get_video_metadata, StreamMetadata, VideoMetadata diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index fa71a9f1..01ec0f9e 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -11,14 +11,6 @@ ) -@dataclass -class ContainerMetadata: - duration_seconds: Optional[float] - bit_rate: Optional[float] - best_video_stream_index: Optional[int] - best_audio_stream_index: Optional[int] - - @dataclass class StreamMetadata: duration_seconds: Optional[float] @@ -35,6 +27,7 @@ class StreamMetadata: width: Optional[int] height: Optional[int] average_fps: Optional[float] + stream_index: int @property def num_frames(self) -> Optional[int]: @@ -46,30 +39,33 @@ def num_frames(self) -> Optional[int]: @dataclass class VideoMetadata: - container: ContainerMetadata + # TODO: Is 'container' an FFmpeg term? + container_duration_seconds: Optional[float] + container_bit_rate: Optional[float] + best_video_stream_index: Optional[int] + best_audio_stream_index: Optional[int] + streams: List[StreamMetadata] @property def duration_seconds(self) -> Optional[float]: if ( - self.container.best_video_stream_index is not None - and self.streams[self.container.best_video_stream_index].duration_seconds - is not None + self.best_video_stream_index is not None + and self.streams[self.best_video_stream_index].duration_seconds is not None ): - return self.streams[self.container.best_video_stream_index].duration_seconds + return self.streams[self.best_video_stream_index].duration_seconds else: - return self.container.duration_seconds + return self.container_duration_seconds @property def bit_rate(self) -> Optional[float]: if ( - self.container.best_video_stream_index is not None - and self.streams[self.container.best_video_stream_index].bit_rate - is not None + self.best_video_stream_index is not None + and self.streams[self.best_video_stream_index].bit_rate is not None ): - return self.streams[self.container.best_video_stream_index].bit_rate + return self.streams[self.best_video_stream_index].bit_rate else: - return self.contain.bit_rate + return self.container_bit_rate @property def best_video_stream(self) -> StreamMetadata: @@ -80,13 +76,6 @@ def best_video_stream(self) -> StreamMetadata: def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: container_dict = json.loads(_get_container_json_metadata(decoder)) - container_metadata = ContainerMetadata( - duration_seconds=container_dict.get("durationSeconds"), - bit_rate=container_dict.get("bitRate"), - best_video_stream_index=container_dict.get("bestVideoStreamIndex"), - best_audio_stream_index=container_dict.get("bestAudioStreamIndex"), - ) - streams_metadata = [] for stream_index in range(container_dict["numStreams"]): stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) @@ -103,7 +92,14 @@ def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: width=stream_dict.get("width"), height=stream_dict.get("height"), average_fps=stream_dict.get("averageFps"), + stream_index=stream_index, ) ) - return VideoMetadata(container=container_metadata, streams=streams_metadata) + return VideoMetadata( + container_duration_seconds=container_dict.get("durationSeconds"), + container_bit_rate=container_dict.get("bitRate"), + best_video_stream_index=container_dict.get("bestVideoStreamIndex"), + best_audio_stream_index=container_dict.get("bestAudioStreamIndex"), + streams=streams_metadata, + ) diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 37e0d3a6..a3e8317a 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -1,5 +1,4 @@ -from dataclasses import dataclass -from typing import Optional, Union +from typing import Union import torch @@ -27,8 +26,8 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]): self.metadata = _get_and_validate_simple_video_metadata(self._decoder) # Note: these fields exist and are not None, as validated in _get_and_validate_simple_video_metadata(). - self._num_frames = self.metadata.stream.num_frames_computed - self._stream_index = self.metadata.container.best_video_stream_index + self._num_frames = self.metadata.num_frames_computed + self._stream_index = self.metadata.stream_index def __len__(self) -> int: return self._num_frames @@ -62,39 +61,22 @@ def __next__(self) -> torch.Tensor: raise StopIteration() -@dataclass -class SimpleVideoMetadata: - # TODO: ContainerMetadata and StreamMetadata should be publicly available. - # Right now they're only exposed in _core. - container: core.ContainerMetadata - stream: core.StreamMetadata - - @property - def duration_seconds(self) -> Optional[float]: - return self.stream.duration_seconds or self.container.duration_seconds - - @property - def bit_rate(self) -> Optional[float]: - return self.stream.bit_rate or self.container.bit_rate - - def _get_and_validate_simple_video_metadata( decoder: torch.Tensor, -) -> SimpleVideoMetadata: +) -> core.StreamMetadata: video_metadata = core.get_video_metadata(decoder) - container_metadata = video_metadata.container - if container_metadata.best_video_stream_index is None: + if video_metadata.best_video_stream_index is None: raise ValueError( "The best video stream is unknown. This should never happen. " "Please report an issue following the steps in " ) - stream_metadata = video_metadata.streams[container_metadata.best_video_stream_index] - if stream_metadata.num_frames_computed is None: + best_stream_metadata = video_metadata.streams[video_metadata.best_video_stream_index] + if best_stream_metadata.num_frames_computed is None: raise ValueError( "The number of frames is unknown. This should never happen. " "Please report an issue following the steps in " ) - return SimpleVideoMetadata(container=container_metadata, stream=stream_metadata) + return best_stream_metadata diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 494d646b..d3fb552b 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -9,13 +9,13 @@ def test_get_video_metadata(): decoder = create_from_file(str(get_reference_video_path())) metadata = get_video_metadata(decoder) assert len(metadata.streams) == 6 - assert metadata.container.best_video_stream_index == 3 - assert metadata.container.best_audio_stream_index == 3 + assert metadata.best_video_stream_index == 3 + assert metadata.best_audio_stream_index == 3 - assert metadata.container.duration_seconds == pytest.approx(16.57, abs=0.001) - assert metadata.container.bit_rate == 324915 + assert metadata.container_duration_seconds == pytest.approx(16.57, abs=0.001) + assert metadata.container_bit_rate == 324915 - best_stream_metadata = metadata.streams[metadata.container.best_video_stream_index] + best_stream_metadata = metadata.streams[metadata.best_video_stream_index] assert best_stream_metadata.duration_seconds == pytest.approx(13.013, abs=0.001) assert best_stream_metadata.bit_rate == 128783 assert best_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) From 8aa6d041cc7bb580afe6c96eff29da9f2ad5b099 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 25 Jun 2024 18:08:11 +0100 Subject: [PATCH 07/11] Renaming --- src/torchcodec/decoders/_simple_video_decoder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index a3e8317a..d5b334c7 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -24,10 +24,10 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]): core.add_video_stream(self._decoder) - self.metadata = _get_and_validate_simple_video_metadata(self._decoder) + self.stream_metadata = _get_and_validate_stream_metadata(self._decoder) # Note: these fields exist and are not None, as validated in _get_and_validate_simple_video_metadata(). - self._num_frames = self.metadata.num_frames_computed - self._stream_index = self.metadata.stream_index + self._num_frames = self.stream_metadata.num_frames_computed + self._stream_index = self.stream_metadata.stream_index def __len__(self) -> int: return self._num_frames @@ -61,9 +61,7 @@ def __next__(self) -> torch.Tensor: raise StopIteration() -def _get_and_validate_simple_video_metadata( - decoder: torch.Tensor, -) -> core.StreamMetadata: +def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata: video_metadata = core.get_video_metadata(decoder) if video_metadata.best_video_stream_index is None: @@ -72,7 +70,9 @@ def _get_and_validate_simple_video_metadata( "Please report an issue following the steps in " ) - best_stream_metadata = video_metadata.streams[video_metadata.best_video_stream_index] + best_stream_metadata = video_metadata.streams[ + video_metadata.best_video_stream_index + ] if best_stream_metadata.num_frames_computed is None: raise ValueError( "The number of frames is unknown. This should never happen. " From 6bdef454ec9cefbf62668d2bca7132ffefd96a06 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Jun 2024 09:28:54 +0100 Subject: [PATCH 08/11] cleanups --- .../decoders/_core/VideoDecoderOps.cpp | 2 -- src/torchcodec/decoders/_core/_metadata.py | 19 ++++++++----------- .../decoders/_simple_video_decoder.py | 1 - test/decoders/test_metadata.py | 4 ++-- 4 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index c19d1f88..198cc810 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -277,8 +277,6 @@ std::string get_container_json_metadata(at::Tensor& decoder) { std::to_string(*containerMetadata.bestAudioStreamIndex); } - // TODO: Q from Nicolas - is there a better way to retrieve and propagate the - // number of streams? map["numStreams"] = std::to_string(containerMetadata.streams.size()); return mapToJson(map); diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 01ec0f9e..57528edc 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -15,10 +15,8 @@ class StreamMetadata: duration_seconds: Optional[float] bit_rate: Optional[float] - # TODO Comment from Nicolas: - # Looking at this, it's not immediately obvious to me that "retrieved" means - # "less accurate than 'computed'". - # Are we open to different names? E.g. "num_frames_from_header" and "num_frames_accurate"? + # TODO: Before release, we should come up with names that better convey the + # " 'fast and potentially inaccurate' vs 'slower but accurate' " tradeoff. num_frames_retrieved: Optional[int] num_frames_computed: Optional[int] min_pts_seconds: Optional[float] @@ -39,9 +37,8 @@ def num_frames(self) -> Optional[int]: @dataclass class VideoMetadata: - # TODO: Is 'container' an FFmpeg term? - container_duration_seconds: Optional[float] - container_bit_rate: Optional[float] + duration_seconds_container: Optional[float] + bit_rate_container: Optional[float] best_video_stream_index: Optional[int] best_audio_stream_index: Optional[int] @@ -55,7 +52,7 @@ def duration_seconds(self) -> Optional[float]: ): return self.streams[self.best_video_stream_index].duration_seconds else: - return self.container_duration_seconds + return self.duration_seconds_container @property def bit_rate(self) -> Optional[float]: @@ -65,7 +62,7 @@ def bit_rate(self) -> Optional[float]: ): return self.streams[self.best_video_stream_index].bit_rate else: - return self.container_bit_rate + return self.bit_rate_container @property def best_video_stream(self) -> StreamMetadata: @@ -97,8 +94,8 @@ def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: ) return VideoMetadata( - container_duration_seconds=container_dict.get("durationSeconds"), - container_bit_rate=container_dict.get("bitRate"), + duration_seconds_container=container_dict.get("durationSeconds"), + bit_rate_container=container_dict.get("bitRate"), best_video_stream_index=container_dict.get("bestVideoStreamIndex"), best_audio_stream_index=container_dict.get("bestAudioStreamIndex"), streams=streams_metadata, diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index d5b334c7..a441ea22 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -25,7 +25,6 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]): core.add_video_stream(self._decoder) self.stream_metadata = _get_and_validate_stream_metadata(self._decoder) - # Note: these fields exist and are not None, as validated in _get_and_validate_simple_video_metadata(). self._num_frames = self.stream_metadata.num_frames_computed self._stream_index = self.stream_metadata.stream_index diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index d3fb552b..303bb795 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -12,8 +12,8 @@ def test_get_video_metadata(): assert metadata.best_video_stream_index == 3 assert metadata.best_audio_stream_index == 3 - assert metadata.container_duration_seconds == pytest.approx(16.57, abs=0.001) - assert metadata.container_bit_rate == 324915 + assert metadata.duration_seconds_container == pytest.approx(16.57, abs=0.001) + assert metadata.bit_rate_container == 324915 best_stream_metadata = metadata.streams[metadata.best_video_stream_index] assert best_stream_metadata.duration_seconds == pytest.approx(13.013, abs=0.001) From 02fd2289ddc7be0f108a80919e237183e30b72fd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Jun 2024 09:53:21 +0100 Subject: [PATCH 09/11] More tests --- src/torchcodec/decoders/_core/_metadata.py | 21 +++-------- test/decoders/simple_video_decoder_test.py | 43 ++++++++++++---------- test/decoders/test_metadata.py | 43 ++++++++++++++++++++-- 3 files changed, 68 insertions(+), 39 deletions(-) diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 57528edc..4a205795 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -46,28 +46,17 @@ class VideoMetadata: @property def duration_seconds(self) -> Optional[float]: - if ( - self.best_video_stream_index is not None - and self.streams[self.best_video_stream_index].duration_seconds is not None - ): - return self.streams[self.best_video_stream_index].duration_seconds - else: - return self.duration_seconds_container + raise NotImplementedError("TODO: decide on logic and implement this!") @property def bit_rate(self) -> Optional[float]: - if ( - self.best_video_stream_index is not None - and self.streams[self.best_video_stream_index].bit_rate is not None - ): - return self.streams[self.best_video_stream_index].bit_rate - else: - return self.bit_rate_container + raise NotImplementedError("TODO: decide on logic and implement this!") @property def best_video_stream(self) -> StreamMetadata: - assert self.container.best_video_stream_index is not None - return self.container.streams[self.container.best_video_stream_index] + if self.best_video_stream_index is None: + raise ValueError("The best video stream is unknown.") + return self.streams[self.best_video_stream_index] def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py index 2ee5c3e7..083a11af 100644 --- a/test/decoders/simple_video_decoder_test.py +++ b/test/decoders/simple_video_decoder_test.py @@ -1,6 +1,6 @@ import pytest -from torchcodec.decoders import SimpleVideoDecoder +from torchcodec.decoders import _core, SimpleVideoDecoder from ..test_utils import ( assert_equal, @@ -11,25 +11,28 @@ class TestSimpleDecoder: - - def test_create_from_file(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) - assert len(decoder) == 390 - assert decoder._stream_index == 3 - - def test_create_from_tensor(self): - decoder = SimpleVideoDecoder(get_reference_video_tensor()) - assert len(decoder) == 390 - assert decoder._stream_index == 3 - - def test_create_from_bytes(self): - path = str(get_reference_video_path()) - with open(path, "rb") as f: - video_bytes = f.read() - - decoder = SimpleVideoDecoder(video_bytes) - assert len(decoder) == 390 - assert decoder._stream_index == 3 + @pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes")) + def test_create(self, source_kind): + if source_kind == "path": + source = str(get_reference_video_path()) + elif source_kind == "tensor": + source = get_reference_video_tensor() + elif source_kind == "bytes": + path = str(get_reference_video_path()) + with open(path, "rb") as f: + source = f.read() + else: + raise ValueError("Oops, double check the parametrization of this test!") + + decoder = SimpleVideoDecoder(source) + assert isinstance(decoder.stream_metadata, _core.StreamMetadata) + assert ( + len(decoder) + == decoder._num_frames + == decoder.stream_metadata.num_frames_computed + == 390 + ) + assert decoder._stream_index == decoder.stream_metadata.stream_index == 3 def test_create_fails(self): with pytest.raises(TypeError, match="Unknown source type"): diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 303bb795..321dd640 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -1,6 +1,10 @@ import pytest -from torchcodec.decoders._core import create_from_file, get_video_metadata +from torchcodec.decoders._core import ( + create_from_file, + get_video_metadata, + StreamMetadata, +) from ..test_utils import get_reference_video_path @@ -12,13 +16,46 @@ def test_get_video_metadata(): assert metadata.best_video_stream_index == 3 assert metadata.best_audio_stream_index == 3 - assert metadata.duration_seconds_container == pytest.approx(16.57, abs=0.001) - assert metadata.bit_rate_container == 324915 + with pytest.raises(NotImplementedError, match="TODO: decide on logic"): + metadata.duration_seconds + with pytest.raises(NotImplementedError, match="TODO: decide on logic"): + metadata.bit_rate + + # TODO: put these checks back once D58974580 is landed. The expected values + # are different depending on the FFmpeg version. + # assert metadata.duration_seconds_container == pytest.approx(16.57, abs=0.001) + # assert metadata.bit_rate_container == 324915 best_stream_metadata = metadata.streams[metadata.best_video_stream_index] + assert best_stream_metadata is metadata.best_video_stream assert best_stream_metadata.duration_seconds == pytest.approx(13.013, abs=0.001) assert best_stream_metadata.bit_rate == 128783 assert best_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) assert best_stream_metadata.codec == "h264" assert best_stream_metadata.num_frames_computed == 390 assert best_stream_metadata.num_frames_retrieved == 390 + + +@pytest.mark.parametrize( + "num_frames_retrieved, num_frames_computed, expected_num_frames", + [(None, 10, 10), (10, None, 10), (None, None, None)], +) +def test_num_frames_fallback( + num_frames_retrieved, num_frames_computed, expected_num_frames +): + """Check that num_frames_computed always has priority when accessing `.num_frames`""" + metadata = StreamMetadata( + duration_seconds=4, + bit_rate=123, + num_frames_retrieved=num_frames_retrieved, + num_frames_computed=num_frames_computed, + min_pts_seconds=0, + max_pts_seconds=4, + codec="whatever", + width=123, + height=321, + average_fps=30, + stream_index=0, + ) + + assert metadata.num_frames == expected_num_frames From 541b55319c6192df2375f86971190a95c584a0b5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 26 Jun 2024 09:55:00 +0100 Subject: [PATCH 10/11] recomment ufmt --- .pre-commit-config.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d381a5a3..6417cfd1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,13 +14,13 @@ repos: - id: check-added-large-files args: ['--maxkb=1000'] - - repo: https://github.com/omnilib/ufmt - rev: v2.6.0 - hooks: - - id: ufmt - additional_dependencies: - - black == 24.4.2 - - usort == 1.0.5 + # - repo: https://github.com/omnilib/ufmt + # rev: v2.6.0 + # hooks: + # - id: ufmt + # additional_dependencies: + # - black == 24.4.2 + # - usort == 1.0.5 - repo: https://github.com/PyCQA/flake8 rev: 7.1.0 From b6a2ba9705a36a964ff064bb88c85a90f9dc35fb Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 27 Jun 2024 11:49:37 +0100 Subject: [PATCH 11/11] Throw std::out_of_range, add comment --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 8 ++++++-- src/torchcodec/decoders/_core/_metadata.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 198cc810..06ecfec7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -286,8 +286,12 @@ std::string get_stream_json_metadata( at::Tensor& decoder, int64_t stream_index) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); - auto streamMetadata = - videoDecoder->getContainerMetadata().streams[stream_index]; + auto streams = videoDecoder->getContainerMetadata().streams; + if (stream_index < 0 || stream_index >= streams.size()) { + throw std::out_of_range( + "stream_index out of bounds: " + std::to_string(stream_index)); + } + auto streamMetadata = streams[stream_index]; std::map map; diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 4a205795..8f0442ab 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -35,6 +35,7 @@ def num_frames(self) -> Optional[int]: return self.num_frames_retrieved +# This may be renamed into e.g. ContainerMetadata in the future to be more generic. @dataclass class VideoMetadata: duration_seconds_container: Optional[float]