Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchcodec] initial version of simple video decoder #42

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files
args: ['--maxkb=1000']

- repo: https://github.com/omnilib/ufmt
rev: v2.6.0
hooks:
- id: ufmt
additional_dependencies:
- black == 24.4.2
- usort == 1.0.5
# - repo: https://github.com/omnilib/ufmt
# rev: v2.6.0
# hooks:
# - id: ufmt
# additional_dependencies:
# - black == 24.4.2
# - usort == 1.0.5

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._simple_video_decoder import SimpleVideoDecoder # noqa
62 changes: 62 additions & 0 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json
from typing import Union

import torch
from torchcodec.decoders import _core as core


class SimpleVideoDecoder:

def __init__(self, source: Union[str, bytes, torch.Tensor]):
# TODO: support Path objects.
if isinstance(source, str):
self._decoder = core.create_from_file(source)
elif isinstance(source, bytes):
self._decoder = core.create_from_bytes(source)
elif isinstance(source, torch.Tensor):
self._decoder = core.create_from_tensor(source)
else:
raise TypeError(
f"Unknown source type: {type(source)}. "
"Supported types are str, bytes and Tensor."
)

core.add_video_stream(self._decoder)

# TODO: We should either implement specific core library function to
# retrieve these values, or replace this with a non-JSON metadata
# retrieval.
metadata_json = json.loads(core.get_json_metadata(self._decoder))
self._num_frames = metadata_json["numFrames"]
self._stream_index = metadata_json["bestVideoStreamIndex"]

def __len__(self) -> int:
return self._num_frames

def __getitem__(self, key: int) -> torch.Tensor:
if not isinstance(key, int):
raise TypeError(
f"Unsupported key type: {type(key)}. Supported type is int."
)

if key < 0:
key += self._num_frames
if key >= self._num_frames or key < 0:
raise IndexError(
f"Index {key} is out of bounds; length is {self._num_frames}"
)

return core.get_frame_at_index(
self._decoder, frame_index=key, stream_index=self._stream_index
)

def __iter__(self) -> "SimpleVideoDecoder":
return self

def __next__(self) -> torch.Tensor:
# TODO: We should distinguish between expected end-of-file and unexpected
# runtime error.
try:
return core.get_next_frame(self._decoder)
except RuntimeError:
raise StopIteration()
83 changes: 83 additions & 0 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest

from torchcodec.decoders import SimpleVideoDecoder

from ..test_utils import (
assert_equal,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
)


class TestSimpleDecoder:

def test_create_from_file(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_from_tensor(self):
decoder = SimpleVideoDecoder(get_reference_video_tensor())
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_from_bytes(self):
path = str(get_reference_video_path())
with open(path, "rb") as f:
video_bytes = f.read()

decoder = SimpleVideoDecoder(video_bytes)
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_fails(self):
with pytest.raises(TypeError, match="Unknown source type"):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

assert_equal(ref_frame0, decoder[0])
assert_equal(ref_frame1, decoder[1])
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[-1000] # noqa

with pytest.raises(TypeError, match="Unsupported key type"):
frame = decoder["0"] # noqa

def test_next(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

for i, frame in enumerate(decoder):
if i == 0:
assert_equal(ref_frame0, frame)
elif i == 1:
assert_equal(ref_frame1, frame)
elif i == 180:
assert_equal(ref_frame180, frame)
elif i == 389:
assert_equal(ref_frame_last, frame)


if __name__ == "__main__":
pytest.main()
Loading