Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement probe_video_metadata_from_header #68

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,9 @@

from .video_decoder_ops import * # noqa

from ._metadata import get_video_metadata, StreamMetadata, VideoMetadata
from ._metadata import (
get_video_metadata,
get_video_metadata_from_header,
StreamMetadata,
VideoMetadata,
)
16 changes: 14 additions & 2 deletions src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import json
import pathlib

from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union

import torch

from torchcodec.decoders._core.video_decoder_ops import (
_get_container_json_metadata,
_get_stream_json_metadata,
create_from_file,
)


Expand Down Expand Up @@ -83,6 +85,11 @@ def best_video_stream(self) -> StreamMetadata:


def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
"""Return video metadata from a video decoder.

The accuracy of the metadata and the availability of some returned fields
depends on whether a full scan was performed by the decoder.
"""

container_dict = json.loads(_get_container_json_metadata(decoder))
streams_metadata = []
Expand All @@ -92,7 +99,8 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
StreamMetadata(
duration_seconds=stream_dict.get("durationSeconds"),
bit_rate=stream_dict.get("bitRate"),
# TODO_OPEN_ISSUE: We should align the C++ names and the json keys with the Python names
# TODO_OPEN_ISSUE: We should align the C++ names and the json
# keys with the Python names
num_frames_retrieved=stream_dict.get("numFrames"),
num_frames_computed=stream_dict.get("numFramesFromScan"),
min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"),
Expand All @@ -112,3 +120,7 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),
streams=streams_metadata,
)


def get_video_metadata_from_header(filename: Union[str, pathlib.Path]) -> VideoMetadata:
return get_video_metadata(create_from_file(str(filename)))
34 changes: 29 additions & 5 deletions test/decoders/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,43 @@
import functools

import pytest

from torchcodec.decoders._core import (
create_from_file,
get_ffmpeg_library_versions,
get_video_metadata,
get_video_metadata_from_header,
scan_all_streams_to_update_metadata,
StreamMetadata,
)

from ..utils import NASA_VIDEO


def test_get_video_metadata():
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
metadata = get_video_metadata(decoder)
def _get_video_metadata(path, with_scan: bool):
decoder = create_from_file(str(path))
if with_scan:
scan_all_streams_to_update_metadata(decoder)
return get_video_metadata(decoder)


@pytest.mark.parametrize(
"metadata_getter",
(
get_video_metadata_from_header,
functools.partial(_get_video_metadata, with_scan=False),
functools.partial(_get_video_metadata, with_scan=True),
),
)
def test_get_metadata(metadata_getter):
with_scan = (
metadata_getter.keywords["with_scan"]
if isinstance(metadata_getter, functools.partial)
else False
)

metadata = metadata_getter(NASA_VIDEO.path)

assert len(metadata.streams) == 6
assert metadata.best_video_stream_index == 3
assert metadata.best_audio_stream_index == 4
Expand Down Expand Up @@ -43,8 +66,9 @@ def test_get_video_metadata():
assert best_stream_metadata.bit_rate == 128783
assert best_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
assert best_stream_metadata.codec == "h264"
assert best_stream_metadata.num_frames_computed == 390
assert best_stream_metadata.num_frames_computed == (390 if with_scan else None)
assert best_stream_metadata.num_frames_retrieved == 390
assert best_stream_metadata.num_frames == 390


@pytest.mark.parametrize(
Expand Down
5 changes: 0 additions & 5 deletions test/decoders/test_video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,6 @@ def test_create_decoder(self, create_from):

# Keeping the metadata tests below for now, but we should remove them
# once we remove get_json_metadata().
# Note that the distinction made between test_video_get_json_metadata and
# test_video_get_json_metadata_with_stream is misleading: all of the stream
# metadata are available even without adding a video stream, because we
# always call scanFileAndUpdateMetadataAndIndex() when creating a decoder
# from the core API.
def test_video_get_json_metadata(self):
decoder = create_from_file(str(NASA_VIDEO.path))
metadata = get_json_metadata(decoder)
Expand Down
Loading