diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 4a03fcbc..d069c949 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -36,6 +36,11 @@ class VideoDecoder: This can be either "NCHW" (default) or "NHWC", where N is the batch size, C is the number of channels, H is the height, and W is the width of the frames. + num_ffmpeg_threads (int, optional): The number of threads to use for decoding. + Use 1 for single-threaded decoding which may be best if you are running multiple + instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded + decoding which is best if you are running a single instance of ``VideoDecoder``. + Default: 1. .. note:: @@ -58,6 +63,7 @@ def __init__( *, stream_index: Optional[int] = None, dimension_order: Literal["NCHW", "NHWC"] = "NCHW", + num_ffmpeg_threads: int = 1, ): if isinstance(source, str): self._decoder = core.create_from_file(source) @@ -82,7 +88,10 @@ def __init__( core.scan_all_streams_to_update_metadata(self._decoder) core.add_video_stream( - self._decoder, stream_index=stream_index, dimension_order=dimension_order + self._decoder, + stream_index=stream_index, + dimension_order=dimension_order, + num_threads=num_ffmpeg_threads, ) self.metadata, self.stream_index = _get_and_validate_stream_metadata( diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 55a8256d..1d5c65e2 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -54,8 +54,9 @@ def test_create_fails(self): with pytest.raises(ValueError, match="No valid stream found"): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa - def test_getitem_int(self): - decoder = VideoDecoder(NASA_VIDEO.path) + @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) + def test_getitem_int(self, num_ffmpeg_threads): + decoder = VideoDecoder(NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1)