From d0ebeb55573820df89fa24f6418b9e624683932d Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 3 Sep 2024 17:30:08 +0100
Subject: [PATCH] Allow decode_image to support paths (#8624)

---
 docs/source/io.rst                            |  8 +++-
 docs/source/models.rst                        | 16 ++++----
 .../others/plot_repurposing_annotations.py    | 10 ++---
 .../others/plot_scripted_tensor_transforms.py |  6 +--
 gallery/others/plot_visualization_utils.py    | 10 ++---
 .../plot_transforms_getting_started.py        |  4 +-
 test/smoke_test.py                            | 10 ++---
 test/test_image.py                            | 21 ++++++++++
 torchvision/io/image.py                       | 40 +++++--------------
 9 files changed, 66 insertions(+), 59 deletions(-)

diff --git a/docs/source/io.rst b/docs/source/io.rst
index 638f310bf69..d372091cc6a 100644
--- a/docs/source/io.rst
+++ b/docs/source/io.rst
@@ -19,7 +19,6 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
     :toctree: generated/
     :template: function.rst
 
-    read_image
     decode_image
     encode_jpeg
     decode_jpeg
@@ -38,6 +37,13 @@ For encoding, JPEG (cpu and CUDA) and PNG are supported.
 
     ImageReadMode
 
+Obsolete decoding function:
+
+.. autosummary::
+    :toctree: generated/
+    :template: function.rst
+
+    read_image
 
 
 Video

diff --git a/docs/source/models.rst b/docs/source/models.rst
index 15540778602..53e8d87609e 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -226,10 +226,10 @@ Here is an example of how to use the pre-trained image classification models:
 
 .. code:: python
 
-    from torchvision.io import read_image
+    from torchvision.io import decode_image
     from torchvision.models import resnet50, ResNet50_Weights
 
-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
 
     # Step 1: Initialize model with the best available weights
     weights = ResNet50_Weights.DEFAULT
@@ -283,10 +283,10 @@ Here is an example of how to use the pre-trained quantized image classification
 
 .. code:: python
 
-    from torchvision.io import read_image
+    from torchvision.io import decode_image
     from torchvision.models.quantization import resnet50, ResNet50_QuantizedWeights
 
-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
 
     # Step 1: Initialize model with the best available weights
     weights = ResNet50_QuantizedWeights.DEFAULT
@@ -339,11 +339,11 @@ Here is an example of how to use the pre-trained semantic segmentation models:
 
 .. code:: python
 
-    from torchvision.io.image import read_image
+    from torchvision.io.image import decode_image
     from torchvision.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights
     from torchvision.transforms.functional import to_pil_image
 
-    img = read_image("gallery/assets/dog1.jpg")
+    img = decode_image("gallery/assets/dog1.jpg")
 
     # Step 1: Initialize model with the best available weights
     weights = FCN_ResNet50_Weights.DEFAULT
@@ -411,12 +411,12 @@ Here is an example of how to use the pre-trained object detection models:
 
 .. code:: python
 
-    from torchvision.io.image import read_image
+    from torchvision.io.image import decode_image
     from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
     from torchvision.utils import draw_bounding_boxes
     from torchvision.transforms.functional import to_pil_image
 
-    img = read_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
+    img = decode_image("test/assets/encode_jpeg/grace_hopper_517x606.jpg")
 
     # Step 1: Initialize model with the best available weights
     weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT

diff --git a/gallery/others/plot_repurposing_annotations.py b/gallery/others/plot_repurposing_annotations.py
index 9d723064ee4..2c2e10ffb2a 100644
--- a/gallery/others/plot_repurposing_annotations.py
+++ b/gallery/others/plot_repurposing_annotations.py
@@ -66,12 +66,12 @@ def show(imgs):
 
 # We will take images and masks from the `PenFudan Dataset <https://www.cis.upenn.edu/~jshi/ped_html/>`_.
 
-from torchvision.io import read_image
+from torchvision.io import decode_image
 
 img_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054.png")
 mask_path = os.path.join(ASSETS_DIRECTORY, "FudanPed00054_mask.png")
-img = read_image(img_path)
-mask = read_image(mask_path)
+img = decode_image(img_path)
+mask = decode_image(mask_path)
 
 
 # %%
@@ -181,8 +181,8 @@ def __getitem__(self, idx):
         img_path = os.path.join(self.root, "PNGImages", self.imgs[idx])
         mask_path = os.path.join(self.root, "PedMasks", self.masks[idx])
-        img = read_image(img_path)
-        mask = read_image(mask_path)
+        img = decode_image(img_path)
+        mask = decode_image(mask_path)
 
         img = F.convert_image_dtype(img, dtype=torch.float)
         mask = F.convert_image_dtype(mask, dtype=torch.float)

diff --git a/gallery/others/plot_scripted_tensor_transforms.py b/gallery/others/plot_scripted_tensor_transforms.py
index 5c49a7ca894..da2213347e3 100644
--- a/gallery/others/plot_scripted_tensor_transforms.py
+++ b/gallery/others/plot_scripted_tensor_transforms.py
@@ -21,7 +21,7 @@
 import torch.nn as nn
 import torchvision.transforms as v1
-from torchvision.io import read_image
+from torchvision.io import decode_image
 
 plt.rcParams["savefig.bbox"] = 'tight'
 torch.manual_seed(1)
@@ -39,8 +39,8 @@
 # :class:`torch.nn.Sequential` instead of
 # :class:`~torchvision.transforms.v2.Compose`:
 
-dog1 = read_image(str(ASSETS_PATH / 'dog1.jpg'))
-dog2 = read_image(str(ASSETS_PATH / 'dog2.jpg'))
+dog1 = decode_image(str(ASSETS_PATH / 'dog1.jpg'))
+dog2 = decode_image(str(ASSETS_PATH / 'dog2.jpg'))
 
 transforms = torch.nn.Sequential(
     v1.RandomCrop(224),

diff --git a/gallery/others/plot_visualization_utils.py b/gallery/others/plot_visualization_utils.py
index d0a214a7340..72c35b53717 100644
--- a/gallery/others/plot_visualization_utils.py
+++ b/gallery/others/plot_visualization_utils.py
@@ -42,11 +42,11 @@ def show(imgs):
 # image of dtype ``uint8`` as input.
 from torchvision.utils import make_grid
-from torchvision.io import read_image
+from torchvision.io import decode_image
 from pathlib import Path
 
-dog1_int = read_image(str(Path('../assets') / 'dog1.jpg'))
-dog2_int = read_image(str(Path('../assets') / 'dog2.jpg'))
+dog1_int = decode_image(str(Path('../assets') / 'dog1.jpg'))
+dog2_int = decode_image(str(Path('../assets') / 'dog2.jpg'))
 dog_list = [dog1_int, dog2_int]
 
 grid = make_grid(dog_list)
@@ -362,9 +362,9 @@ def show(imgs):
 #
 
 from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
-from torchvision.io import read_image
+from torchvision.io import decode_image
 
-person_int = read_image(str(Path("../assets") / "person1.jpg"))
+person_int = decode_image(str(Path("../assets") / "person1.jpg"))
 
 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()

diff --git a/gallery/transforms/plot_transforms_getting_started.py b/gallery/transforms/plot_transforms_getting_started.py
index 0faf79c46af..2696a9e57e7 100644
--- a/gallery/transforms/plot_transforms_getting_started.py
+++ b/gallery/transforms/plot_transforms_getting_started.py
@@ -21,14 +21,14 @@
 plt.rcParams["savefig.bbox"] = 'tight'
 
 from torchvision.transforms import v2
-from torchvision.io import read_image
+from torchvision.io import decode_image
 
 torch.manual_seed(1)
 
 # If you're trying to run that on Colab, you can download the assets and the
 # helpers from https://github.com/pytorch/vision/tree/main/gallery/
 from helpers import plot
-img = read_image(str(Path('../assets') / 'astronaut.jpg'))
+img = decode_image(str(Path('../assets') / 'astronaut.jpg'))
 print(f"{type(img) = }, {img.dtype = }, {img.shape = }")

diff --git a/test/smoke_test.py b/test/smoke_test.py
index f98d019bea5..3a44ae3efe9 100644
--- a/test/smoke_test.py
+++ b/test/smoke_test.py
@@ -6,7 +6,7 @@
 
 import torch
 import torchvision
-from torchvision.io import decode_jpeg, decode_webp, read_file, read_image
+from torchvision.io import decode_image, decode_jpeg, decode_webp, read_file
 from torchvision.models import resnet50, ResNet50_Weights
 
 
@@ -21,13 +21,13 @@ def smoke_test_torchvision() -> None:
 
 
 def smoke_test_torchvision_read_decode() -> None:
-    img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
+    img_jpg = decode_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
     if img_jpg.shape != (3, 606, 517):
         raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
-    img_png = read_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
+    img_png = decode_image(str(SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png"))
     if img_png.shape != (4, 471, 354):
         raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
-    img_webp = read_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
+    img_webp = decode_image(str(SCRIPT_DIR / "assets/fakedata/logos/rgb_pytorch.webp"))
     if img_webp.shape != (3, 100, 100):
         raise RuntimeError(f"Unexpected shape of img_webp: {img_webp.shape}")
 
@@ -54,7 +54,7 @@ def smoke_test_compile() -> None:
 
 
 def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
-    img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
+    img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
/ "gallery" / "assets" / "dog2.jpg")).to(device) # Step 1: Initialize model with the best available weights weights = ResNet50_Weights.DEFAULT diff --git a/test/test_image.py b/test/test_image.py index 4d14af638a0..c817b7b831c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -1044,5 +1044,26 @@ def test_decode_heic(decode_fun, scripted): img += 123 # make sure image buffer wasn't freed by underlying decoding lib +@pytest.mark.parametrize("input_type", ("Path", "str", "tensor")) +@pytest.mark.parametrize("scripted", (False, True)) +def test_decode_image_path(input_type, scripted): + # Check that decode_image can support not just tensors as input + path = next(get_images(IMAGE_ROOT, ".jpg")) + if input_type == "Path": + input = Path(path) + elif input_type == "str": + input = path + elif input_type == "tensor": + input = read_file(path) + else: + raise ValueError("Oops") + + if scripted and input_type == "Path": + pytest.xfail(reason="Can't pass a Path when scripting") + + decode_fun = torch.jit.script(decode_image) if scripted else decode_image + decode_fun(input) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchvision/io/image.py b/torchvision/io/image.py index f1df0d52672..8805846df23 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -277,13 +277,13 @@ def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75): def decode_image( - input: torch.Tensor, + input: Union[torch.Tensor, str], mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False, ) -> torch.Tensor: - """ - Detect whether an image is a JPEG, PNG, WEBP, or GIF and performs the - appropriate operation to decode the image into a Tensor. + """Decode an image into a tensor. + + Currently supported image formats are jpeg, png, gif and webp. The values of the output tensor are in uint8 in [0, 255] for most cases. @@ -295,8 +295,9 @@ def decode_image( tensor. Args: - input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the - image. + input (Tensor or str or ``pathlib.Path``): The image to decode. If a + tensor is passed, it must be one dimensional uint8 tensor containing + the raw bytes of the image. Otherwise, this must be a path to the image file. mode (ImageReadMode): the read mode used for optionally converting the image. Default: ``ImageReadMode.UNCHANGED``. See ``ImageReadMode`` class for more information on various @@ -309,6 +310,8 @@ def decode_image( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(decode_image) + if not isinstance(input, torch.Tensor): + input = read_file(str(input)) output = torch.ops.image.decode_image(input, mode.value, apply_exif_orientation) return output @@ -318,30 +321,7 @@ def read_image( mode: ImageReadMode = ImageReadMode.UNCHANGED, apply_exif_orientation: bool = False, ) -> torch.Tensor: - """ - Reads a JPEG, PNG, WEBP, or GIF image into a Tensor. - - The values of the output tensor are in uint8 in [0, 255] for most cases. - - If the image is a 16-bit png, then the output tensor is uint16 in [0, 65535] - (supported from torchvision ``0.21``. Since uint16 support is limited in - pytorch, we recommend calling - :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True`` - after this function to convert the decoded image into a uint8 or float - tensor. - - Args: - path (str or ``pathlib.Path``): path of the image. - mode (ImageReadMode): the read mode used for optionally converting the image. - Default: ``ImageReadMode.UNCHANGED``. 
-            See ``ImageReadMode`` class for more information on various
-            available modes. Only applies to JPEG and PNG images.
-        apply_exif_orientation (bool): apply EXIF orientation transformation to the output tensor.
-            Only applies to JPEG and PNG images. Default: False.
-
-    Returns:
-        output (Tensor[image_channels, image_height, image_width])
-    """
+    """[OBSOLETE] Use :func:`~torchvision.io.decode_image` instead."""
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(read_image)
     data = read_file(path)
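
Editorial note (not part of the patch): the snippet below is a minimal sketch
of the API surface this change enables, added here for reviewers. The file
path is a placeholder; any local jpeg/png/gif/webp file works.

    from pathlib import Path

    from torchvision.io import ImageReadMode, decode_image, read_file

    # New with this patch: pass a path (str or pathlib.Path) directly;
    # decode_image() reads the raw bytes itself via read_file().
    img = decode_image("some_image.jpg")
    img = decode_image(Path("some_image.jpg"), mode=ImageReadMode.RGB)

    # Unchanged: a one-dimensional uint8 tensor of raw, encoded bytes
    # is still accepted.
    data = read_file("some_image.jpg")
    img = decode_image(data)

    # Caveat covered by the new test: a pathlib.Path cannot be passed to a
    # torch.jit.script-ed decode_image; pass a str (or a tensor) there instead.
    print(img.shape, img.dtype)  # e.g. torch.Size([3, H, W]), torch.uint8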