diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c3c2ebb0c..8255b89412 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Add support for MVTec LOCO AD dataset and sPRO metric by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1686 - 🚀 Update OpenVINO and ONNX export to support fixed input shape by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2006 - Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018 - 🚀 Add compression and quantization for OpenVINO export by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2052 diff --git a/configs/README.md b/configs/README.md index bd3ecb618e..4334eea678 100644 --- a/configs/README.md +++ b/configs/README.md @@ -15,6 +15,7 @@ configs/ │ ├── kolektor.yaml │ ├── mvtec_3d.yaml │ ├── mvtec.yaml +│ ├── mvtec_loco.yaml │ ├── shanghaitec.yaml │ ├── ucsd_ped.yaml │ └── visa.yaml diff --git a/configs/data/mvtec_loco.yaml b/configs/data/mvtec_loco.yaml new file mode 100644 index 0000000000..92c04542c0 --- /dev/null +++ b/configs/data/mvtec_loco.yaml @@ -0,0 +1,13 @@ +class_path: anomalib.data.MVTecLoco +init_args: + root: ./datasets/MVTec_LOCO + category: breakfast_box + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + task: SEGMENTATION + test_split_mode: FROM_DIR + test_split_ratio: 0.2 + val_split_mode: FROM_DIR + val_split_ratio: 0.5 + seed: null diff --git a/docs/source/markdown/guides/reference/data/image/index.md b/docs/source/markdown/guides/reference/data/image/index.md index 2525d0d914..ee75fa0daf 100644 --- a/docs/source/markdown/guides/reference/data/image/index.md +++ b/docs/source/markdown/guides/reference/data/image/index.md @@ -30,6 +30,13 @@ Learn more about Kolektor dataset. Learn more about MVTec 2D dataset ::: +:::{grid-item-card} MVTec LOCO +:link: ./mvtec_loco +:link-type: doc + +Learn more about MVTec LOCO dataset +::: + :::{grid-item-card} Visa :link: ./visa :link-type: doc @@ -47,5 +54,6 @@ Learn more about Visa dataset. ./folder ./kolektor ./mvtec +./mvtec_loco ./visa ``` diff --git a/docs/source/markdown/guides/reference/data/image/mvtec_loco.md b/docs/source/markdown/guides/reference/data/image/mvtec_loco.md new file mode 100644 index 0000000000..8854f27b80 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/image/mvtec_loco.md @@ -0,0 +1,7 @@ +# MVTec LOCO Data + +```{eval-rst} +.. automodule:: anomalib.data.image.mvtec_loco + :members: + :show-inheritance: +``` diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 6a7b173272..8c32e6ec40 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -13,7 +13,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from anomalib import TaskType -from anomalib.metrics import AnomalibMetricCollection, create_metric_collection +from anomalib.metrics import create_metric_collection from anomalib.models import AnomalyModule logger = logging.getLogger(__name__) @@ -67,8 +67,7 @@ def setup( pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. stage (str | None, optional): fit, validate, test or predict. Defaults to None. """ - del trainer, stage # These variables are not used. - + del stage, trainer # this variable is not used. image_metric_names = [] if self.image_metric_names is None else self.image_metric_names if isinstance(image_metric_names, str): image_metric_names = [image_metric_names] @@ -85,9 +84,25 @@ def setup( ) else: pixel_metric_names = ( - self.pixel_metric_names if not isinstance(self.pixel_metric_names, str) else [self.pixel_metric_names] + self.pixel_metric_names.copy() + if not isinstance(self.pixel_metric_names, str) + else [self.pixel_metric_names] ) + # create a separate metric collection for metrics that operate over the semantic segmentation mask + # (segmentation mask with a separate channel for each defect type) + semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = [] + # currently only SPRO metric is supported as semantic segmentation metric + if "SPRO" in pixel_metric_names: + if isinstance(pixel_metric_names, list): + pixel_metric_names.remove("SPRO") + semantic_pixel_metric_names = ["SPRO"] + elif isinstance(pixel_metric_names, dict): + spro_metric = pixel_metric_names.pop("SPRO") + semantic_pixel_metric_names = {"SPRO": spro_metric} + else: + logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names)) + if isinstance(pl_module, AnomalyModule): pl_module.image_metrics = create_metric_collection(image_metric_names, "image_") if hasattr(pl_module, "pixel_metrics"): # incase metrics are loaded from model checkpoint @@ -97,6 +112,7 @@ def setup( pl_module.pixel_metrics.add_metrics(new_metrics[name]) else: pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") + pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_") self._set_threshold(pl_module) def on_validation_epoch_start( @@ -108,6 +124,7 @@ def on_validation_epoch_start( pl_module.image_metrics.reset() pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() def on_validation_batch_end( self, @@ -122,7 +139,7 @@ def on_validation_batch_end( if outputs is not None: self._outputs_to_device(outputs) - self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) + self._update_metrics(pl_module, outputs) def on_validation_epoch_end( self, @@ -143,6 +160,7 @@ def on_test_epoch_start( pl_module.image_metrics.reset() pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() def on_test_batch_end( self, @@ -157,7 +175,7 @@ def on_test_batch_end( if outputs is not None: self._outputs_to_device(outputs) - self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) + self._update_metrics(pl_module, outputs) def on_test_epoch_end( self, @@ -171,18 +189,21 @@ def on_test_epoch_end( def _set_threshold(self, pl_module: AnomalyModule) -> None: pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) + pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) def _update_metrics( self, - image_metric: AnomalibMetricCollection, - pixel_metric: AnomalibMetricCollection, + pl_module: AnomalyModule, output: STEP_OUTPUT, ) -> None: - image_metric.to(self.device) - image_metric.update(output["pred_scores"], output["label"].int()) + pl_module.image_metrics.to(self.device) + pl_module.image_metrics.update(output["pred_scores"], output["label"].int()) if "mask" in output and "anomaly_maps" in output: - pixel_metric.to(self.device) - pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + pl_module.pixel_metrics.to(self.device) + pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + if "semantic_mask" in output and "anomaly_maps" in output: + pl_module.semantic_pixel_metrics.to(self.device) + pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: if isinstance(output, dict): @@ -190,13 +211,16 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any output[key] = self._outputs_to_device(value) elif isinstance(output, torch.Tensor): output = output.to(self.device) + elif isinstance(output, list): + for i, value in enumerate(output): + output[i] = self._outputs_to_device(value) return output @staticmethod def _log_metrics(pl_module: AnomalyModule) -> None: """Log computed performance metrics.""" - if pl_module.pixel_metrics._update_called: # noqa: SLF001 - pl_module.log_dict(pl_module.pixel_metrics, prog_bar=True) - pl_module.log_dict(pl_module.image_metrics, prog_bar=False) - else: - pl_module.log_dict(pl_module.image_metrics, prog_bar=True) + pl_module.log_dict(pl_module.image_metrics, prog_bar=True) + if pl_module.pixel_metrics.update_called: + pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False) + if pl_module.semantic_pixel_metrics.update_called: + pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False) diff --git a/src/anomalib/callbacks/normalization/min_max_normalization.py b/src/anomalib/callbacks/normalization/min_max_normalization.py index 4ff8c9b6e5..f22a36afd0 100644 --- a/src/anomalib/callbacks/normalization/min_max_normalization.py +++ b/src/anomalib/callbacks/normalization/min_max_normalization.py @@ -39,7 +39,7 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None: """Call when the test begins.""" del trainer # `trainer` variable is not used. - for metric in (pl_module.image_metrics, pl_module.pixel_metrics): + for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics): if metric is not None: metric.set_threshold(0.5) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index b619b8317c..85fa9cd52a 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -141,8 +141,17 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: parser.add_function_arguments(get_normalization_callback, "normalization") parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) - parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) - parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) + parser.add_argument( + "--metrics.image", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=["F1Score", "AUROC"], + ) + parser.add_argument( + "--metrics.pixel", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=None, + required=False, + ) parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold") parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False) if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"): diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index 85a4fd1589..9a58f5e589 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -15,7 +15,7 @@ from .base import AnomalibDataModule, AnomalibDataset from .depth import DepthDataFormat, Folder3D, MVTec3D -from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa +from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa from .predict import PredictDataset from .utils import LabelName from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat @@ -63,6 +63,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: "Kolektor", "MVTec", "MVTec3D", + "MVTecLoco", "Avenue", "UCSDped", "ShanghaiTech", diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index a3ab4a5c72..bc5063d8ab 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -28,7 +28,9 @@ def collate_fn(batch: list) -> dict[str, Any]: """Collate bounding boxes as lists. - Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries. + Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` exists, + the `mask_path` is also collated as a list since each element in the batch could be unequal. + For all other entries, the default collate function is used. Args: batch (List): list of items in the batch where len(batch) is equal to the batch size. @@ -42,6 +44,12 @@ def collate_fn(batch: list) -> dict[str, Any]: if "boxes" in elem: # collate boxes as list out_dict["boxes"] = [item.pop("boxes") for item in batch] + if "semantic_mask" in elem: + # semantic masks have a variable number of channels, so we collate them as a list + out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch] + if "mask_path" in elem and isinstance(elem["mask_path"], list): + # collate mask paths as list + out_dict["mask_path"] = [item.pop("mask_path") for item in batch] # collate other data normally out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem}) return out_dict @@ -213,6 +221,12 @@ def _create_val_split(self) -> None: # converted from random training sample self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed) self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) + elif self.val_split_mode == ValSplitMode.FROM_DIR: + # the val_data is prepared in subclass + assert hasattr( + self, + "val_data", + ), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup." elif self.val_split_mode != ValSplitMode.NONE: msg = f"Unknown validation split mode: {self.val_split_mode}" raise ValueError(msg) diff --git a/src/anomalib/data/image/__init__.py b/src/anomalib/data/image/__init__.py index 6b30ac9ac8..4db05b51df 100644 --- a/src/anomalib/data/image/__init__.py +++ b/src/anomalib/data/image/__init__.py @@ -13,6 +13,7 @@ from .folder import Folder from .kolektor import Kolektor from .mvtec import MVTec +from .mvtec_loco import MVTecLoco from .visa import Visa @@ -21,6 +22,7 @@ class ImageDataFormat(str, Enum): MVTEC = "mvtec" MVTEC_3D = "mvtec_3d" + MVTEC_LOCO = "mvtec_loco" BTECH = "btech" KOLEKTOR = "kolektor" FOLDER = "folder" @@ -28,4 +30,4 @@ class ImageDataFormat(str, Enum): VISA = "visa" -__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"] +__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"] diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py new file mode 100644 index 0000000000..4ef2b4ab8e --- /dev/null +++ b/src/anomalib/data/image/mvtec_loco.py @@ -0,0 +1,480 @@ +"""MVTec LOCO AD Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch Lightning + DataModule for the MVTec LOCO AD dataset. If the dataset is not on the file system, + the script downloads and extracts the dataset and create PyTorch data objects. + +License: + MVTec LOCO AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + +References: + - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, and Carsten Steger: + Beyond Dents and Scratches: Logical Constraints in Unsupervised Anomaly Detection and Localization; + in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 +""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +from collections.abc import Sequence +from pathlib import Path + +import torch +from pandas import DataFrame +from PIL import Image +from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2.functional import to_image +from torchvision.tv_tensors import Mask + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import ( + DownloadInfo, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + masks_to_boxes, + read_image, + validate_path, +) + +logger = logging.getLogger(__name__) + + +IMG_EXTENSIONS = (".png", ".PNG") + +DOWNLOAD_INFO = DownloadInfo( + name="mvtec_loco", + url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701" + "/mvtec_loco_anomaly_detection.tar.xz", + hashsum="9e7c84dba550fd2e59d8e9e231c929c45ba737b6b6a6d3814100f54d63aae687", +) + +CATEGORIES = ( + "breakfast_box", + "juice_bottle", + "pushpins", + "screw_bag", + "splicing_connectors", +) + + +def make_mvtec_loco_dataset( + root: str | Path, + split: str | Split | None = None, + extensions: Sequence[str] = IMG_EXTENSIONS, +) -> DataFrame: + """Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure. + + The files are expected to follow the structure: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/image_filename/000.png + + where there can be multiple ground-truth masks for the corresponding anomalous images. + + This function creates a dataframe to store the parsed information based on the following format: + + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | + +===+===============+=======+=========+===============+=======================================+=============+ + | 0 | datasets/name | test | defect | path/to/image/file.png | [path/to/masks/file.png] | 1 | + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ + + Args: + root (str | Path): Path to dataset + split (str | Split | None): Dataset split (ie., either train or test). + Defaults to ``None``. + extensions (Sequence[str]): List of file extensions to be included in the dataset. + Defaults to ``None``. + + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + + Examples: + The following example shows how to get test samples from MVTec LOCO AD pushpins category: + + >>> root = Path('./MVTec_LOCO') + >>> category = 'pushpins' + >>> path = root / category + >>> samples = make_mvtec_loco_dataset(path, split='test') + """ + root = validate_path(root) + + # Retrieve the image and mask files + samples_list = [] + for f in root.glob("**/*"): + if f.suffix in extensions: + parts = f.parts + # 'ground_truth' and non 'ground_truth' path have a different structure + if "ground_truth" not in parts: + split_folder, label_folder, image_file = parts[-3:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, "", image_path)) + else: + split_folder, label_folder, image_folder, image_file = parts[-4:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, image_folder, image_path)) + + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + samples = DataFrame(samples_list, columns=["path", "split", "label", "image_folder", "image_path"]) + + # Replace validation to Split.VAL.value in the split column + samples["split"] = samples["split"].replace("validation", Split.VAL.value) + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # separate ground-truth masks from samples + mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True) + samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True) + + # Group masks and aggregate the path into a list + mask_samples = ( + mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"] + .agg(list) + .reset_index() + .rename(columns={"image_path": "mask_path"}) + ) + + # assign mask paths to anomalous test images + samples["mask_path"] = "" + samples.loc[ + (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), + "mask_path", + ] = mask_samples.mask_path.to_numpy() + + # validate that the right mask files are associated with the right test images + if len(samples.loc[samples.label_index == LabelName.ABNORMAL]): + image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem) + mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply( + lambda x: {Path(mask_path).parent.stem for mask_path in x}, + ) + + if not all( + next(iter(mask_stems)) == image_stem + for image_stem, mask_stems in zip(image_stems, mask_parent_stems, strict=True) + ): + error_message = ( + "Mismatch between anomalous images and ground truth masks. " + "Make sure the parent folder of the mask files in 'ground_truth' folder " + "follows the same naming convention as the anomalous images in the dataset " + "(e.g., image: '005.png', mask: '005/000.png')." + ) + raise ValueError(error_message) + + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class MVTecLocoDataset(AnomalibDataset): + """MVTec LOCO dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + root (Path | str): Path to the root of the dataset. + Defaults to ``./datasets/MVTec_LOCO``. + category (str): Sub-category of the dataset, e.g. 'breakfast_box' + Defaults to ``breakfast_box``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + split (str | Split | None): Split of the dataset, Split.TRAIN, Split.VAL, or Split.TEST + Defaults to ``None``. + + Examples: + .. code-block:: python + + from anomalib.data.image.mvtec_loco import MVTecLocoDataset + from anomalib.data.utils.transforms import get_transforms + from torchvision.transforms.v2 import Resize + + transform = Resize((256, 256)) + dataset = MVTecLocoDataset( + task="classification", + transform=transform, + root='./datasets/MVTec_LOCO', + category='breakfast_box', + ) + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image']) + + When the task is segmentation, the dataset will also contain the mask: + + .. code-block:: python + + dataset.task = "segmentation" + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W). + + .. code-block:: python + + print(dataset[0]["image"].shape, dataset[0]["mask"].shape) + # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256])) + """ + + def __init__( + self, + task: TaskType, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + transform: Transform | None = None, + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / category + self.split = split + self.samples = make_mvtec_loco_dataset( + self.root_category, + split=self.split, + extensions=IMG_EXTENSIONS, + ) + + @staticmethod + def _read_mask(mask_path: str | Path) -> Mask: + image = Image.open(mask_path).convert("L") + return Mask(to_image(image).squeeze(), dtype=torch.uint8) + + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Get dataset item for the index ``index``. + + This method is mostly based on the super class implementation, with some different as follows: + - Using 'torch.where' to make sure the 'mask' in the return item is binarized + - An additional 'masks' is added, the non-binary masks with original size for the SPRO metric calculation + Args: + index (int): Index to get the item. + + Returns: + dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path, + target path, image tensor, label and transformed bounding box. + """ + image_path = self.samples.iloc[index].image_path + mask_path = self.samples.iloc[index].mask_path + label_index = self.samples.iloc[index].label_index + + image = read_image(image_path, as_tensor=True) + item = {"image_path": image_path, "label": label_index} + + if self.task == TaskType.CLASSIFICATION: + item["image"] = self.transform(image) if self.transform else image + elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. + if isinstance(mask_path, str): + mask_path = [mask_path] + semantic_mask = ( + Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) + if label_index == LabelName.NORMAL + else Mask(torch.stack([self._read_mask(path) for path in mask_path])) + ) + + binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)) + item["image"], item["mask"] = self.transform(image, binary_mask) if self.transform else (image, binary_mask) + + item["mask_path"] = mask_path + # List of masks with the original size for saturation based metrics calculation + item["semantic_mask"] = semantic_mask + + if self.task == TaskType.DETECTION: + # create boxes from masks for detection task + boxes, _ = masks_to_boxes(item["mask"]) + item["boxes"] = boxes[0] + else: + msg = f"Unknown task type: {self.task}" + raise ValueError(msg) + + return item + + +class MVTecLoco(AnomalibDataModule): + """MVTec LOCO Datamodule. + + Args: + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/MVTec_LOCO"``. + category (str): Category of the MVTec LOCO dataset (e.g. "breakfast_box"). + Defaults to ``"breakfast_box"``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + image_size (tuple[int, int], optional): Size to which input images should be resized. + Defaults to ``None``. + transform (Transform, optional): Transforms that should be applied to the input images. + Defaults to ``None``. + train_transform (Transform, optional): Transforms that should be applied to the input images during training. + Defaults to ``None``. + eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_DIR``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + + Examples: + To create an MVTec LOCO AD datamodule with default settings: + + >>> datamodule = MVTecLoco(root="anomalib/datasets/MVTec_LOCO") + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + To change the category of the dataset: + + >>> datamodule = MVTecLoco(category="pushpins") + + To change the image and batch size: + + >>> datamodule = MVTecLoco(image_size=(512, 512), train_batch_size=16, eval_batch_size=8) + + MVTec LOCO AD dataset provide an independent validation set with normal images only in the 'validation' folder. + If you would like to use a different validation set splitted from train or test set, + you can use the ``val_split_mode`` and ``val_split_ratio`` arguments to create a new validation set. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1) + + This will subsample the test set by 10% and use it as the validation set. + If you would like to create a validation set synthetically that would + not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2) + """ + + def __init__( + self, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType = TaskType.SEGMENTATION, + image_size: tuple[int, int] | None = None, + transform: Transform | None = None, + train_transform: Transform | None = None, + eval_transform: Transform | None = None, + test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode = ValSplitMode.FROM_DIR, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + image_size=image_size, + transform=transform, + train_transform=train_transform, + eval_transform=eval_transform, + num_workers=num_workers, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + self.task = task + self.root = Path(root) + self.category = category + + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets, configs, and perform dynamic subset splitting. + + This method overrides the parent class's method to also setup the val dataset. + The MVTec LOCO dataset provides an independent validation subset. + """ + self.train_data = MVTecLocoDataset( + task=self.task, + transform=self.train_transform, + split=Split.TRAIN, + root=self.root, + category=self.category, + ) + self.val_data = MVTecLocoDataset( + task=self.task, + transform=self.eval_transform, + split=Split.VAL, + root=self.root, + category=self.category, + ) + self.test_data = MVTecLocoDataset( + task=self.task, + transform=self.eval_transform, + split=Split.TEST, + root=self.root, + category=self.category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── MVTec_LOCO + ├── breakfast_box + ├── ... + └── splicing_connectors + """ + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py index 9e560691d5..566dba8c28 100644 --- a/src/anomalib/data/utils/split.py +++ b/src/anomalib/data/utils/split.py @@ -50,6 +50,7 @@ class ValSplitMode(str, Enum): FROM_TRAIN = "from_train" FROM_TEST = "from_test" SYNTHETIC = "synthetic" + FROM_DIR = "from_dir" def concatenate_datasets(datasets: Sequence["data.AnomalibDataset"]) -> "data.AnomalibDataset": diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 4c3eafa811..3a59b7a846 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -21,6 +21,7 @@ from .min_max import MinMax from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO +from .spro import SPRO from .threshold import F1AdaptiveThreshold, ManualThreshold __all__ = [ @@ -35,6 +36,7 @@ "ManualThreshold", "MinMax", "PRO", + "SPRO", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index 020aebd9e9..47c17a3a44 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -3,8 +3,12 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import logging + from torchmetrics import MetricCollection +logger = logging.getLogger(__name__) + class AnomalibMetricCollection(MetricCollection): """Extends the MetricCollection class for use in the Anomalib pipeline.""" @@ -21,6 +25,10 @@ def set_threshold(self, threshold_value: float) -> None: if hasattr(metric, "threshold"): metric.threshold = threshold_value + def set_update_called(self, val: bool) -> None: + """Set the flag indicating whether the update method has been called.""" + self._update_called = val + def update(self, *args, **kwargs) -> None: """Add data to the metrics.""" super().update(*args, **kwargs) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py new file mode 100644 index 0000000000..c59091ee5f --- /dev/null +++ b/src/anomalib/metrics/spro.py @@ -0,0 +1,215 @@ +"""Implementation of SPRO metric based on TorchMetrics.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import logging +from pathlib import Path +from typing import Any + +import torch +from torchmetrics import Metric + +from anomalib.data.utils import validate_path + +logger = logging.getLogger(__name__) + + +class SPRO(Metric): + """Saturated Per-Region Overlap (SPRO) Score. + + This metric computes the macro average of the saturated per-region overlap between the + predicted anomaly masks and the ground truth masks. + + Args: + threshold (float): Threshold used to binarize the predictions. + Defaults to ``0.5``. + saturation_config (str | Path): Path to the saturation configuration file. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. + kwargs: Additional arguments to the TorchMetrics base class. + + Example: + Import the metric from the package: + + >>> import torch + >>> from anomalib.metrics import SPRO + + Create random ``preds`` and ``labels`` tensors: + + >>> labels = torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32) + >>> labels = [labels] + >>> preds = torch.rand_like(labels[0][:1]) + + Compute the SPRO score for labels and preds: + + >>> spro = SPRO(threshold=0.5) + >>> spro.update(preds, labels) + >>> spro.compute() + tensor(0.6333) + + .. note:: + Note that the example above shows random predictions and labels. + Therefore, the SPRO score above may not be reproducible. + + """ + + def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None: + super().__init__(**kwargs) + self.threshold = threshold + self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None + if self.saturation_config is None: + logger.warning( + "The saturation_config attribute is empty, the threshold is set to the defect area." + "This is equivalent to PRO metric but with the 'region' are separated by mask files", + ) + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None: + """Compute the SPRO score for the current batch. + + Args: + predictions (torch.Tensor): Predicted anomaly masks. + masks (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. + + Example: + To update the metric state for the current batch, use the ``update`` method: + + >>> spro.update(preds, labels) + """ + score, total = spro_score( + predictions=predictions, + targets=masks, + threshold=self.threshold, + saturation_config=self.saturation_config, + ) + self.score += score + self.total += total + + def compute(self) -> torch.Tensor: + """Compute the macro average of the SPRO score across all masks in all batches. + + Example: + To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: + + >>> spro.compute() + tensor(0.5433) + """ + if self.total == 0: # only background/normal images + return torch.Tensor([1.0]) + return self.score / self.total + + +def spro_score( + predictions: torch.Tensor, + targets: list[torch.Tensor], + threshold: float = 0.5, + saturation_config: dict | None = None, +) -> torch.Tensor: + """Calculate the SPRO score for a batch of predictions. + + Args: + predictions (torch.Tensor): Predicted anomaly masks. + targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. + threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + Defaults: ``0.5``. + saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. + + Returns: + torch.Tensor: Scalar value representing the average SPRO score for the input batch. + """ + # Add batch dim if not exist + if len(predictions.shape) == 2: + predictions = predictions.unsqueeze(0) + + # Resize the prediction to have the same size as the target mask + predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets[0].shape[-2:]) + + # Apply threshold to binary predictions + if predictions.dtype == torch.float: + predictions = predictions > threshold + + score = torch.tensor(0.0) + total = 0 + # Iterate for each image in the batch + for i, target in enumerate(targets): + # Iterate for each ground-truth mask per image + for mask in target: + label = torch.max(mask) + if label == 0: # Skip if only normal/background + continue + # Calculate true positive + target_per_label = mask == label + true_pos = torch.sum(predictions[i] & target_per_label) + + # Calculate the anomalous area of the ground-truth + defect_area = torch.sum(target_per_label) + + if saturation_config is not None: + # Adjust saturation threshold based on configuration + saturation_per_label = saturation_config[label.int().item()] + saturation_threshold = saturation_per_label["saturation_threshold"] + + if saturation_per_label["relative_saturation"]: + saturation_threshold *= defect_area + + # Check if threshold is larger than defect area + if saturation_threshold > defect_area: + warning_msg = ( + f"Saturation threshold for label {label.int().item()} is larger than defect area. " + "Setting it to defect area." + ) + logger.warning(warning_msg) + saturation_threshold = defect_area + else: + # Handle case when saturation_config is empty + saturation_threshold = defect_area + + # Update score with minimum of true_pos/saturation_threshold and 1.0 + score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) + total += 1 + return score, total + + +def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: + """Load saturation configurations from a JSON file. + + Args: + config_path (str | Path): Path to the saturation configuration file. + + Returns: + Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. + + Example JSON format in the config file of MVTec LOCO dataset: + [ + { + "defect_name": "1_additional_pushpin", + "pixel_value": 255, + "saturation_threshold": 6300, + "relative_saturation": false + }, + { + "defect_name": "2_additional_pushpins", + "pixel_value": 254, + "saturation_threshold": 12600, + "relative_saturation": false + }, + ... + ] + """ + try: + config_path = validate_path(config_path) + with Path.open(config_path) as file: + configs = json.load(file) + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + except FileNotFoundError: + logger.warning("The saturation config file %s does not exist. Returning None.", config_path) + return None diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index 471db72f15..3034300748 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -52,6 +52,7 @@ def __init__(self) -> None: self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection + self.semantic_pixel_metrics: AnomalibMetricCollection self._transform: Transform | None = None self._input_size: tuple[int, int] | None = None diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 51b683acab..4af678a829 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -216,6 +216,7 @@ def __init__( data_format: DataFormat | str, root: Path | str | None = None, num_train: int = 5, + num_val: int = 5, num_test: int = 5, seed: int | None = None, ) -> None: @@ -230,6 +231,7 @@ def __init__( self.root = Path(mkdtemp() if root is None else root) self.dataset_root = self.root / self.data_format.value self.num_train = num_train + self.num_val = num_val self.num_test = num_test self.rng = np.random.default_rng(seed) @@ -298,6 +300,7 @@ def __init__( normal_category: str = "good", abnormal_category: str = "bad", num_train: int = 5, + num_val: int = 5, num_test: int = 5, image_shape: tuple[int, int] = (256, 256), num_channels: int = 3, @@ -308,6 +311,7 @@ def __init__( data_format=data_format, root=root, num_train=num_train, + num_val=num_val, num_test=num_test, seed=seed, ) @@ -389,6 +393,41 @@ def _generate_dummy_mvtec_3d_dataset(self) -> None: else: self.image_generator.save_image(filename=filename, image=image) + def _generate_dummy_mvtec_loco_dataset(self) -> None: + """Generates dummy MVTec LOCO AD dataset in a temporary directory using the same convention as MVTec LOCO AD.""" + # MVTec LOCO has multiple subcategories within the dataset. + dataset_category = "dummy" + + extension = ".png" + + # Create normal images. + for split in ("train", "validation", "test"): + path = self.dataset_root / dataset_category / split / "good" + if split == "train": + num_images = self.num_train + elif split == "val": + num_images = self.num_val + else: + num_images = self.num_test + + for i in range(num_images): + label = LabelName.NORMAL + image_filename = path / f"{i:03}{extension}" + self.image_generator.generate_image(label=label, image_filename=image_filename) + + # Create abnormal test images and masks. + for abnormal_dir in ("logical_anomalies", "structural_anomalies"): + path = self.dataset_root / dataset_category / "test" / abnormal_dir + mask_path = self.dataset_root / dataset_category / "ground_truth" / abnormal_dir + + for i in range(self.num_test): + label = LabelName.ABNORMAL + image_filename = path / f"{i:03}{extension}" + # Here, only one ground-truth mask for each abnormal image is generated + # the structure follows the same convention as MVTec LOCO AD, e.g., image_filename/000.png + mask_filename = mask_path / f"{i:03}/000{extension}" + self.image_generator.generate_image(label, image_filename, mask_filename) + def _generate_dummy_kolektor_dataset(self) -> None: """Generate dummy Kolektor dataset in directory using the same convention as Kolektor AD.""" # Emulating the first two categories of Kolektor dataset. diff --git a/tests/unit/data/image/test_mvtec_loco.py b/tests/unit/data/image/test_mvtec_loco.py new file mode 100644 index 0000000000..64275319fe --- /dev/null +++ b/tests/unit/data/image/test_mvtec_loco.py @@ -0,0 +1,39 @@ +"""Unit Tests - MVTecLoco Datamodule.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest + +from anomalib import TaskType +from anomalib.data import MVTecLoco +from tests.unit.data.base.image import _TestAnomalibImageDatamodule + + +class TestMVTecLoco(_TestAnomalibImageDatamodule): + """MVTecLoco Datamodule Unit Tests.""" + + @pytest.fixture() + def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTecLoco: + """Create and return a MVTecLoco datamodule.""" + _datamodule = MVTecLoco( + root=dataset_path / "mvtec_loco", + category="dummy", + task=task_type, + image_size=256, + train_batch_size=4, + eval_batch_size=4, + ) + _datamodule.prepare_data() + _datamodule.setup() + + return _datamodule + + def test_mask_is_binary(self, datamodule: MVTecLoco) -> None: + """Test if the mask tensor is binary.""" + if datamodule.test_data.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + mask_tensor = datamodule.test_data[0]["mask"] + is_binary = (mask_tensor.eq(0) | mask_tensor.eq(1)).all() + assert is_binary.item() is True diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py new file mode 100644 index 0000000000..b7dd8e043d --- /dev/null +++ b/tests/unit/metrics/test_spro.py @@ -0,0 +1,83 @@ +"""Test SPRO metric.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import json +import pathlib +import tempfile + +import torch + +from anomalib.metrics.spro import SPRO + + +def test_spro() -> None: + """Checks if SPRO metric computes the score utilizing the given saturation configs.""" + saturation_config = [ + { + "pixel_value": 255, + "saturation_threshold": 10, + "relative_saturation": False, + }, + { + "pixel_value": 254, + "saturation_threshold": 0.5, + "relative_saturation": True, + }, + ] + + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump(saturation_config, f) + saturation_config_json = f.name + + masks = [ + torch.Tensor( + [ + [ + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + [ + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + ], + ], + ), + ] + + masks[0][0] *= 255 + masks[0][1] *= 254 + + preds = (torch.arange(8) / 10) + 0.05 + # metrics receive squeezed predictions (N, H, W) + preds = preds.unsqueeze(1).repeat(1, 5).view(1, 8, 5) + + thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + targets = [1.0, 1.0, 1.0, 0.75, 0.0, 0.0] + targets_wo_saturation = [1.0, 0.625, 0.5, 0.375, 0.0, 0.0] + for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True): + # test using saturation_cofig + spro = SPRO(threshold=threshold, saturation_config=saturation_config_json) + spro.update(preds, masks) + assert spro.compute() == target + + # test without saturation_config + spro_wo_saturaton = SPRO(threshold=threshold) + spro_wo_saturaton.update(preds, masks) + assert spro_wo_saturaton.compute() == target_wo_saturation + + # Remove the temporary config file + pathlib.Path(saturation_config_json).unlink()