Skip to content

Commit

Permalink
Add option for the user to pass in ffmpeg thread count
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadsharif1 committed Oct 24, 2024
1 parent 9d7b240 commit e7c614d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 10 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class VideoDecoder:
This can be either "NCHW" (default) or "NHWC", where N is the batch
size, C is the number of channels, H is the height, and W is the
width of the frames.
num_ffmpeg_threads (int, optional): The number of threads to use for decoding.
Use 1 for single-threaded decoding which is best if you are running multiple
instances of ``VideoDecoder`` in parallel. Use a higher number for multi-threaded
decoding which is best if you are running a single instance of ``VideoDecoder``.
Default: 1.
.. note::
Expand All @@ -58,6 +63,7 @@ def __init__(
*,
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
num_ffmpeg_threads: int = 1,
):
if isinstance(source, str):
self._decoder = core.create_from_file(source)
Expand All @@ -82,7 +88,10 @@ def __init__(

core.scan_all_streams_to_update_metadata(self._decoder)
core.add_video_stream(
self._decoder, stream_index=stream_index, dimension_order=dimension_order
self._decoder,
stream_index=stream_index,
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
)

self.metadata, self.stream_index = _get_and_validate_stream_metadata(
Expand Down
5 changes: 3 additions & 2 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def test_create_fails(self):
with pytest.raises(ValueError, match="No valid stream found"):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa

def test_getitem_int(self):
decoder = VideoDecoder(NASA_VIDEO.path)
@pytest.mark.parametrize("num_ffmpeg_threads", ("int", 1, 4))
def test_getitem_int(self, num_ffmpeg_threads):
decoder = VideoDecoder(NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads)

ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1)
Expand Down

0 comments on commit e7c614d

Please sign in to comment.