Skip to content

Commit

Permalink
Allow decode_image to support paths
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Sep 3, 2024
1 parent c36025a commit 08a8df3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 31 deletions.
8 changes: 7 additions & 1 deletion docs/source/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
:toctree: generated/
:template: function.rst

read_image
decode_image
encode_jpeg
decode_jpeg
Expand All @@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.

ImageReadMode

Obsolete decoding function:

.. autosummary::
:toctree: generated/
:template: function.rst

read_image


Video
Expand Down
21 changes: 21 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted):
img += 123 # make sure image buffer wasn't freed by underlying decoding lib


@pytest.mark.parametrize("input_type", ("Path", "str", "tensor"))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_image_path(input_type, scripted):
    """Smoke-test that ``decode_image`` accepts paths as well as tensors.

    The first ``.jpg`` produced by ``get_images(IMAGE_ROOT, ".jpg")`` is handed
    to ``decode_image`` wrapped in a ``pathlib.Path``, as returned by
    ``get_images``, or as the tensor returned by ``read_file``. Only the
    absence of an exception is checked; the decoded output is not inspected.

    Args:
        input_type: one of ``"Path"``, ``"str"`` or ``"tensor"``.
        scripted: when True, call ``torch.jit.script(decode_image)`` instead
            of the eager function.
    """
    sample_path = next(get_images(IMAGE_ROOT, ".jpg"))

    # One lazy builder per supported input kind, so that only the requested
    # representation is actually constructed.
    input_builders = {
        "Path": lambda: Path(sample_path),
        "str": lambda: sample_path,
        "tensor": lambda: read_file(sample_path),
    }
    if input_type not in input_builders:
        raise ValueError("Oops")
    image_input = input_builders[input_type]()

    # The scripted + Path combination is reported as an expected failure
    # rather than exercised.
    if scripted and input_type == "Path":
        pytest.xfail(reason="Can't pass a Path when scripting")

    decoder = decode_image
    if scripted:
        decoder = torch.jit.script(decode_image)
    decoder(image_input)


# Allow running this test module directly as a script: hand this file to pytest.
if __name__ == "__main__":
    pytest.main(args=[__file__])
40 changes: 10 additions & 30 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75):


def decode_image(
input: torch.Tensor,
input: Union[torch.Tensor, str],

Check warning on line 280 in torchvision/io/image.py

View workflow job for this annotation

GitHub Actions / bc

Function decode_image: input changed from torch.Tensor to Union[torch.Tensor, str]
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the
appropriate operation to decode the image into a Tensor.
"""Decode an image into a tensor.
Currently supported image formats are jpeg, png, gif and webp.
The values of the output tensor are in uint8 in [0, 255] for most cases.
Expand All @@ -295,8 +295,9 @@ def decode_image(
tensor.
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
image.
input (Tensor or str or ``pathlib.Path``): The image to decode. If a
tensor is passed, it must be one dimensional uint8 tensor containing
the raw bytes of the image. Otherwise, this must be a path to the image file.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
Expand All @@ -309,6 +310,8 @@ def decode_image(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_image)
if not isinstance(input, torch.Tensor):
input = read_file(str(input))
output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation)
return output

Expand All @@ -318,30 +321,7 @@ def read_image(
mode: ImageReadMode = ImageReadMode.UNCHANGED,
apply_exif_orientation: bool = False,
) -> torch.Tensor:
"""
Reads a JPEG, PNG, WEBP, or GIF image into a Tensor.
The values of the output tensor are in uint8 in [0, 255] for most cases.
If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``. Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.
Args:
path (str or ``pathlib.Path``): path of the image.
mode (ImageReadMode): the read mode used for optionally converting the image.
Default: ``ImageReadMode.UNCHANGED``.
See ``ImageReadMode`` class for more information on various
available modes. Only applies to JPEG and PNG images.
apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
Only applies to JPEG and PNG images. Default: False.
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
"""[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(read_image)
data = read_file(path)
Expand Down

0 comments on commit 08a8df3

Please sign in to comment.