From 9771c23cf01f39416cd649575a3329f7d696a5f7 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 13 Jun 2024 12:41:55 -0700 Subject: [PATCH] [torchcodec] refactor test utils into its own library (#25) Summary: Pull Request resolved: https://github.com/pytorch-labs/torchcodec/pull/25 Refactors utility functions that were defined directly in test source files into a library that can be shared among tests. This diff: 1. Creates a Python test utility library that can be imported by Python tests. 2. Moves the test resources to the test top-level. 3. Defines a C++ target for those resources. This is in preparation for adding new tests that will need this library. Differential Revision: D58530481 --- test/decoders/VideoDecoderOpsTest.cpp | 3 +- test/decoders/VideoDecoderTest.cpp | 3 +- test/decoders/manual_smoke_test.py | 2 +- test/decoders/video_decoder_ops_test.py | 43 +++------------ test/{decoders => }/resources/nasa_13013.mp4 | Bin .../resources/nasa_13013.mp4.audio.mp3 | Bin .../resources/nasa_13013.mp4.frame000001.pt | Bin .../resources/nasa_13013.mp4.frame000002.pt | Bin .../resources/nasa_13013.mp4.time10.000000.pt | Bin .../resources/nasa_13013.mp4.time12.979633.pt | Bin .../resources/nasa_13013.mp4.time6.000000.pt | Bin .../resources/nasa_13013.mp4.time6.100000.pt | Bin test/samplers/video_clip_sampler_test.py | 22 +------- test/test_utils.py | 52 ++++++++++++++++++ 14 files changed, 65 insertions(+), 60 deletions(-) rename test/{decoders => }/resources/nasa_13013.mp4 (100%) rename test/{decoders => }/resources/nasa_13013.mp4.audio.mp3 (100%) rename test/{decoders => }/resources/nasa_13013.mp4.frame000001.pt (100%) rename test/{decoders => }/resources/nasa_13013.mp4.frame000002.pt (100%) rename test/{decoders => }/resources/nasa_13013.mp4.time10.000000.pt (100%) rename test/{decoders => }/resources/nasa_13013.mp4.time12.979633.pt (100%) rename test/{decoders => }/resources/nasa_13013.mp4.time6.000000.pt (100%) rename test/{decoders => }/resources/nasa_13013.mp4.time6.100000.pt (100%) create mode 100644 test/test_utils.py diff --git a/test/decoders/VideoDecoderOpsTest.cpp b/test/decoders/VideoDecoderOpsTest.cpp index f23d28f4..18d7f53f 100644 --- a/test/decoders/VideoDecoderOpsTest.cpp +++ b/test/decoders/VideoDecoderOpsTest.cpp @@ -17,8 +17,7 @@ namespace facebook::torchcodec { std::string getResourcePath(const std::string& filename) { #ifdef FBCODE_BUILD - std::string filepath = - "pytorch/torchcodec/test/decoders/resources/" + filename; + std::string filepath = "pytorch/torchcodec/test/resources/" + filename; filepath = build::getResourcePath(filepath).string(); #else std::filesystem::path dirPath = std::filesystem::path(__FILE__); diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 506791ae..b5db373c 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -24,8 +24,7 @@ namespace facebook::torchcodec { std::string getResourcePath(const std::string& filename) { #ifdef FBCODE_BUILD - std::string filepath = - "pytorch/torchcodec/test/decoders/resources/" + filename; + std::string filepath = "pytorch/torchcodec/test/resources/" + filename; filepath = build::getResourcePath(filepath).string(); #else std::filesystem::path dirPath = std::filesystem::path(__FILE__); diff --git a/test/decoders/manual_smoke_test.py b/test/decoders/manual_smoke_test.py index 07869c54..7df5a790 100644 --- a/test/decoders/manual_smoke_test.py +++ b/test/decoders/manual_smoke_test.py @@ -4,7 +4,7 @@ import torchcodec decoder = torchcodec.decoders._core.create_from_file( - os.path.dirname(__file__) + "/resources/nasa_13013.mp4" + os.path.dirname(__file__) + "../resources/nasa_13013.mp4" ) torchcodec.decoders._core.add_video_stream(decoder, stream_index=3) frame = torchcodec.decoders._core.get_frame_at_index( diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py index cf541d46..8e48c273 100644 --- a/test/decoders/video_decoder_ops_test.py +++ b/test/decoders/video_decoder_ops_test.py @@ -10,7 +10,6 @@ import pytest import torch -import torchvision.transforms as transforms from PIL import Image from torchcodec.decoders._core import ( @@ -26,42 +25,14 @@ seek_to_pts, ) -torch._dynamo.config.capture_dynamic_output_shape_ops = True -IN_FBCODE = os.environ.get("IN_FBCODE_TORCHCODEC") == "1" - - -# TODO: Eventually move that as a common test util -def assert_equal(*args, **kwargs): - torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) - - -# TODO: Eventually move that as a common test util -def get_video_path(filename: str) -> pathlib.Path: - if IN_FBCODE: - resource = ( - importlib.resources.files(__package__) - .joinpath("resources") - .joinpath(filename) - ) - with importlib.resources.as_file(resource) as path: - return path - else: - return pathlib.Path(__file__).parent / "resources" / filename - - -# TODO: make this a fixture or wrap with @functools.lru_cache to avoid -# re-computing? -def load_tensor_from_file(filename: str) -> torch.Tensor: - file_path = get_video_path(filename) - return torch.load(file_path) - - -def get_reference_video_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4") - +from ..test_utils import ( + assert_equal, + get_reference_audio_path, + get_reference_video_path, + load_tensor_from_file, +) -def get_reference_audio_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4.audio.mp3") +torch._dynamo.config.capture_dynamic_output_shape_ops = True class ReferenceDecoder: diff --git a/test/decoders/resources/nasa_13013.mp4 b/test/resources/nasa_13013.mp4 similarity index 100% rename from test/decoders/resources/nasa_13013.mp4 rename to test/resources/nasa_13013.mp4 diff --git a/test/decoders/resources/nasa_13013.mp4.audio.mp3 b/test/resources/nasa_13013.mp4.audio.mp3 similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.audio.mp3 rename to test/resources/nasa_13013.mp4.audio.mp3 diff --git a/test/decoders/resources/nasa_13013.mp4.frame000001.pt b/test/resources/nasa_13013.mp4.frame000001.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.frame000001.pt rename to test/resources/nasa_13013.mp4.frame000001.pt diff --git a/test/decoders/resources/nasa_13013.mp4.frame000002.pt b/test/resources/nasa_13013.mp4.frame000002.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.frame000002.pt rename to test/resources/nasa_13013.mp4.frame000002.pt diff --git a/test/decoders/resources/nasa_13013.mp4.time10.000000.pt b/test/resources/nasa_13013.mp4.time10.000000.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.time10.000000.pt rename to test/resources/nasa_13013.mp4.time10.000000.pt diff --git a/test/decoders/resources/nasa_13013.mp4.time12.979633.pt b/test/resources/nasa_13013.mp4.time12.979633.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.time12.979633.pt rename to test/resources/nasa_13013.mp4.time12.979633.pt diff --git a/test/decoders/resources/nasa_13013.mp4.time6.000000.pt b/test/resources/nasa_13013.mp4.time6.000000.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.time6.000000.pt rename to test/resources/nasa_13013.mp4.time6.000000.pt diff --git a/test/decoders/resources/nasa_13013.mp4.time6.100000.pt b/test/resources/nasa_13013.mp4.time6.100000.pt similarity index 100% rename from test/decoders/resources/nasa_13013.mp4.time6.100000.pt rename to test/resources/nasa_13013.mp4.time6.100000.pt diff --git a/test/samplers/video_clip_sampler_test.py b/test/samplers/video_clip_sampler_test.py index 490db70a..7c7c2237 100644 --- a/test/samplers/video_clip_sampler_test.py +++ b/test/samplers/video_clip_sampler_test.py @@ -15,23 +15,7 @@ VideoClipSampler, ) - -# TODO: move this to a common util -IN_FBCODE = os.environ.get("IN_FBCODE_TORCHCODEC") == "1" - - -# TODO: Eventually rely on common util for this -@pytest.fixture() -def nasa_13013() -> torch.Tensor: - if IN_FBCODE: - video_path = importlib.resources.path(__package__, "nasa_13013.mp4") - else: - video_path = ( - Path(__file__).parent.parent / "decoders" / "resources" / "nasa_13013.mp4" - ) - arr = np.fromfile(video_path, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - return video_tensor +from ..test_utils import assert_equal, nasa_13013 # noqa: F401; see nasa_13013 use @pytest.mark.parametrize( @@ -51,13 +35,13 @@ def nasa_13013() -> torch.Tensor: ), ], ) -def test_sampler(sampler_args, nasa_13013): +def test_sampler(sampler_args, nasa_13013): # noqa: F811; linter does not see this as a use torch.manual_seed(0) desired_width, desired_height = 320, 240 video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height) sampler = VideoClipSampler(video_args, sampler_args) clips = sampler(nasa_13013) - assert len(clips) == sampler_args.clips_per_video + assert_equal(len(clips), sampler_args.clips_per_video) clip = clips[0] if isinstance(sampler_args, TimeBasedSamplerArgs): # TODO FIXME: Looks like we have an API inconsistency. diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..6eb0ec12 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,52 @@ +import importlib +import os +import pathlib + +import numpy as np +import pytest + +import torch + + +def in_fbcode() -> bool: + return os.environ.get("IN_FBCODE_TORCHCODEC") == "1" + + +IN_FBCODE = in_fbcode() + + +def assert_equal(*args, **kwargs): + torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) + + +def get_video_path(filename: str) -> pathlib.Path: + if IN_FBCODE: + resource = ( + importlib.resources.files(__spec__.parent) + .joinpath("resources") + .joinpath(filename) + ) + with importlib.resources.as_file(resource) as path: + return path + else: + return pathlib.Path(__file__).parent / "resources" / filename + + +def get_reference_video_path() -> pathlib.Path: + return get_video_path("nasa_13013.mp4") + + +def get_reference_audio_path() -> pathlib.Path: + return get_video_path("nasa_13013.mp4.audio.mp3") + + +def load_tensor_from_file(filename: str) -> torch.Tensor: + file_path = get_video_path(filename) + return torch.load(file_path) + + +@pytest.fixture() +def nasa_13013() -> torch.Tensor: + arr = np.fromfile(get_reference_video_path(), dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + return video_tensor