From c5c6bc7238ba3feff08af70a27c5f008ca04e6c6 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 19:01:45 +0900 Subject: [PATCH 01/63] add FROM_DIR option to `val split mode` to support a provided val directory Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/utils/split.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/data/utils/split.py b/src/anomalib/data/utils/split.py index 27d1b4d770..9838b6251c 100644 --- a/src/anomalib/data/utils/split.py +++ b/src/anomalib/data/utils/split.py @@ -49,6 +49,7 @@ class ValSplitMode(str, Enum): SAME_AS_TEST = "same_as_test" FROM_TEST = "from_test" SYNTHETIC = "synthetic" + FROM_DIR = "from_dir" def concatenate_datasets(datasets: Sequence["data.AnomalibDataset"]) -> "data.AnomalibDataset": From 4c76092850e1dacace0f711466eb8efa8a4c397e Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 19:03:18 +0900 Subject: [PATCH 02/63] add a conditional check for the FROM_DIR option of val split mode Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/base/datamodule.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 3eae2974b7..13975b047c 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -164,6 +164,9 @@ def _create_val_split(self) -> None: # converted from random training sample self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed) self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) + elif self.val_split_mode == ValSplitMode.FROM_DIR: + # the val_data is prepared in subclass + pass elif self.val_split_mode != ValSplitMode.NONE: msg = f"Unknown validation split mode: {self.val_split_mode}" raise ValueError(msg) From ecc7169d14dd5c4c09748b73f07a1c2465200f65 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 19:09:47 +0900 Subject: [PATCH 03/63] add the mvtec loco ad dataset classes Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 499 ++++++++++++++++++++++++++ 1 file changed, 499 insertions(+) create mode 100644 src/anomalib/data/image/mvtec_loco.py diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py new file mode 100644 index 0000000000..47d6bc64e2 --- /dev/null +++ b/src/anomalib/data/image/mvtec_loco.py @@ -0,0 +1,499 @@ +"""MVTec LOCO AD Dataset (CC BY-NC-SA 4.0). + +Description: + This script contains PyTorch Dataset, Dataloader and PyTorch Lightning + DataModule for the MVTec LOCO AD dataset. If the dataset is not on the file system, + the script downloads and extracts the dataset and create PyTorch data objects. + +License: + MVTec LOCO AD dataset is released under the Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License + (CC BY-NC-SA 4.0)(https://creativecommons.org/licenses/by-nc-sa/4.0/). + +References: + - Paul Bergmann, Kilian Batzner, Michael Fauser, David Sattlegger, and Carsten Steger: + Beyond Dents and Scratches: Logical Constraints in Unsupervised Anomaly Detection and Localization; + in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 +""" + +import logging +from collections.abc import Sequence +from pathlib import Path + +import albumentations as A # noqa: N812 +import cv2 +import numpy as np +from pandas import DataFrame + +from anomalib import TaskType +from anomalib.data.base import AnomalibDataModule, AnomalibDataset +from anomalib.data.utils import ( + DownloadInfo, + InputNormalizationMethod, + LabelName, + Split, + TestSplitMode, + ValSplitMode, + download_and_extract, + get_transforms, +) + +logger = logging.getLogger(__name__) + + +IMG_EXTENSIONS = (".png", ".PNG") + +DOWNLOAD_INFO = DownloadInfo( + name="mvtec_loco", + url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701" + "/mvtec_loco_anomaly_detection.tar.xz", + checksum="d40f092ac6f88433f609583c4a05f56f", +) + +CATEGORIES = ( + "breakfast_box", + "juice_bottle", + "pushpins", + "screw_bag", + "splicing_connectors", +) + +GT_MERGED_DIR = "ground_truth_merged" + + +def _merge_gt_mask( + root: str | Path, + extensions: Sequence[str] = IMG_EXTENSIONS, + gt_merged_dir: str | Path = GT_MERGED_DIR, +) -> None: + """Merges ground truth masks within specified directories and saves the merged masks. + + Args: + root (str | Path): Root directory containing the 'ground_truth' folder. + extensions (Sequence[str]): Allowed file extensions for ground truth masks. + Default is IMG_EXTENSIONS. + gt_merged_dir (str | Path]): Directory where merged masks will be saved. + Default is GT_MERGED_DIR. + + Returns: + None + + Example: + >>> _merge_gt_mask('path/to/breakfast_box/') + + This function reads ground truth masks from the specified directories, merges them into + a single mask for each corresponding images (e.g. merge 059/000.png and 059/001.png into 059.png), + and saves the merged masks in the default GT_MERGED_DIR directory. + + Note: The merged masks are saved with the same filename structure as the corresponding anomalous image files. + """ + root = Path(root) + gt_mask_paths = {f.parent for f in root.glob("ground_truth/**/*") if f.suffix in extensions} + + for mask_path in gt_mask_paths: + # Merge each mask inside mask_path into a single mask + merged_mask = None + for mask_file in mask_path.glob("*"): + if mask_file.suffix in extensions: + mask = cv2.imread(str(mask_file), cv2.IMREAD_UNCHANGED) + if merged_mask is None: + merged_mask = np.zeros_like(mask) + merged_mask = np.maximum(merged_mask, mask) + + # Binarize masks + merged_mask = np.minimum(merged_mask, 255) + + # Define the path for the new merged mask + _, anomaly_dir, image_filename = mask_path.parts[-3:] + new_mask_path = root / Path(gt_merged_dir) / anomaly_dir / (image_filename + ".png") + + # Create the necessary directories if they do not exist + new_mask_path.parent.mkdir(parents=True, exist_ok=True) + + # Save the merged mask + cv2.imwrite(str(new_mask_path), merged_mask) + + +def make_mvtec_loco_dataset( + root: str | Path, + split: str | Split | None = None, + extensions: Sequence[str] = IMG_EXTENSIONS, + gt_merged_dir: str | Path = GT_MERGED_DIR, +) -> DataFrame: + """Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure. + + The files are expected to follow the structure: + path/to/dataset/split/category/image_filename.png + path/to/dataset/ground_truth/category/image_filename/000.png + + where there can be multiple ground-truth masks for the corresponding anomalous images. + + This function first merges the multiple ground-truth-masks by executing _merge_gt_mask(), + it then creates a dataframe to store the parsed information based on the following format: + + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | + +===+===============+=======+=========+===============+=======================================+=============+ + | 0 | datasets/name | test | defect | filename.png | path/to/merged_masks/filename.png | 1 | + +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ + + Note: the final image_path is converted to full path by combining it with the path, split, and label columns + Example, datasets/name/test/defect/filename.png + + Args: + root (str | Path): Path to dataset + split (str | Split | None): Dataset split (ie., either train or test). + Defaults to ``None``. + extensions (Sequence[str]): List of file extensions to be included in the dataset. + Defaults to ``None``. + gt_merged_dir (str | Path]): Directory where merged masks will be saved. + Default is GT_MERGED_DIR. + + Returns: + DataFrame: an output dataframe containing the samples of the dataset. + + Examples: + The following example shows how to get test samples from MVTec LOCO AD pushpins category: + + >>> root = Path('./MVTec_LOCO') + >>> category = 'pushpins' + >>> path = root / category + >>> samples = make_mvtec_loco_dataset(path, split='test') + """ + root = Path(root) + gt_merged_dir = Path(gt_merged_dir) + + # assert the directory to store the merged ground-truth masks is different than the original gt directory + assert gt_merged_dir != "ground_truth" + + # Merge ground-truth masks for each corresponding images and store into the 'gt_merged_dir' folder + if (root / gt_merged_dir).is_dir(): + logger.info(f"Found the directory of the merged ground-truth masks: {root / gt_merged_dir!s}") + else: + logger.info("Merging the multiple ground-truth masks for each corresponding images.") + _merge_gt_mask(root, gt_merged_dir=gt_merged_dir) + + # Retrieve the image and mask files + samples_list = [] + for f in root.glob("**/*"): + if f.suffix in extensions: + parts = f.parts + # Ignore original 'ground_truth' folder because the 'gt_merged_dir' is used instead + if "ground_truth" not in parts: + split_folder, label_folder, image_path = parts[-3:] + samples_list.append((str(root), split_folder, label_folder, image_path)) + + if not samples_list: + msg = f"Found 0 images in {root}" + raise RuntimeError(msg) + + samples = DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) + + # Modify image_path column by converting to absolute path + samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path + + # Replace validation to Split.VAL.value in the split column + samples["split"] = samples["split"].replace("validation", Split.VAL.value) + + # Create label index for normal (0) and anomalous (1) images. + samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL + samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL + samples.label_index = samples.label_index.astype(int) + + # separate masks from samples + mask_samples = samples.loc[samples.split == str(gt_merged_dir)].sort_values(by="image_path", ignore_index=True) + samples = samples[samples.split != str(gt_merged_dir)].sort_values(by="image_path", ignore_index=True) + + # assign mask paths to anomalous test images + samples["mask_path"] = "" + samples.loc[ + (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), + "mask_path", + ] = mask_samples.image_path.to_numpy() + + # assert that the right mask files are associated with the right test images + if len(samples.loc[samples.label_index == LabelName.ABNORMAL]): + assert ( + samples.loc[samples.label_index == LabelName.ABNORMAL] + .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) + .all() + ), f"Mismatch between anomalous images and ground truth masks. Make sure the mask files in '{gt_merged_dir!s}' \ + folder follow the same naming convention as the anomalous images in the dataset (e.g. image: \ + '000.png', mask: '000.png')." + + if split: + samples = samples[samples.split == split].reset_index(drop=True) + + return samples + + +class MVTecLocoDataset(AnomalibDataset): + """MVTec LOCO dataset class. + + Args: + task (TaskType): Task type, ``classification``, ``detection`` or ``segmentation``. + transform (A.Compose): Albumentations Compose object describing the transforms that are applied to the inputs. + root (Path | str): Path to the root of the dataset. + Defaults to ``./datasets/MVTec_LOCO``. + category (str): Sub-category of the dataset, e.g. 'breakfast_box' + Defaults to ``breakfast_box``. + split (str | Split | None): Split of the dataset, Split.TRAIN, Split.VAL, or Split.TEST + Defaults to ``None``. + + Examples: + .. code-block:: python + + from anomalib.data.image.mvtec_loco import MVTecLocoDataset + from anomalib.data.utils.transforms import get_transforms + + transform = get_transforms(image_size=256) + dataset = MVTecLocoDataset( + task="classification", + transform=transform, + root='./datasets/MVTec_LOCO', + category='breakfast_box', + ) + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image']) + + When the task is segmentation, the dataset will also contain the mask: + + .. code-block:: python + + dataset.task = "segmentation" + dataset.setup() + print(dataset[0].keys()) + # Output: dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + The image is a torch tensor of shape (C, H, W) and the mask is a torch tensor of shape (H, W). + + .. code-block:: python + + print(dataset[0]["image"].shape, dataset[0]["mask"].shape) + # Output: (torch.Size([3, 256, 256]), torch.Size([256, 256])) + """ + + def __init__( + self, + task: TaskType, + transform: A.Compose, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + split: str | Split | None = None, + ) -> None: + super().__init__(task=task, transform=transform) + + self.root_category = Path(root) / Path(category) + self.split = split + + def _setup(self) -> None: + self.samples = make_mvtec_loco_dataset( + self.root_category, + split=self.split, + extensions=IMG_EXTENSIONS, + gt_merged_dir=GT_MERGED_DIR, + ) + + +class MVTecLoco(AnomalibDataModule): + """MVTec LOCO Datamodule. + + Args: + root (Path | str): Path to the root of the dataset. + Defaults to ``"./datasets/MVTec_LOCO"``. + category (str): Category of the MVTec LOCO dataset (e.g. "breakfast_box"). + Defaults to ``"breakfast_box"``. + image_size (int | tuple[int, int] | None, optional): Size of the input image. + Defaults to ``(256, 256)``. + center_crop (int | tuple[int, int] | None, optional): When provided, the images will be center-cropped + to the provided dimensions. + Defaults to ``None``. + normalization (InputNormalizationMethod | str): Normalization method to be applied to the input images. + Defaults to ``InputNormalizationMethod.IMAGENET``. + train_batch_size (int, optional): Training batch size. + Defaults to ``32``. + eval_batch_size (int, optional): Test batch size. + Defaults to ``32``. + num_workers (int, optional): Number of workers. + Defaults to ``8``. + task TaskType): Task type, 'classification', 'detection' or 'segmentation' + Defaults to ``TaskType.SEGMENTATION``. + transform_config_train (str | A.Compose | None, optional): Config for pre-processing during training. + Defaults to ``None``. + transform_config_val (str | A.Compose | None, optional): Config for pre-processing + during validation. + Defaults to ``None``. + test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained. + Defaults to ``TestSplitMode.FROM_DIR``. + test_split_ratio (float): Fraction of images from the train set that will be reserved for testing. + Defaults to ``0.2``. + val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained. + Defaults to ``ValSplitMode.FROM_DIR``. + val_split_ratio (float): Fraction of train or test images that will be reserved for validation. + Defaults to ``0.5``. + seed (int | None, optional): Seed which may be set to a fixed value for reproducibility. + Defaults to ``None``. + + Examples: + To create an MVTec LOCO AD datamodule with default settings: + + >>> datamodule = MVTecLoco(root="anomalib/datasets/MVTec_LOCO") + >>> datamodule.setup() + >>> i, data = next(enumerate(datamodule.train_dataloader())) + >>> data.keys() + dict_keys(['image_path', 'label', 'image', 'mask_path', 'mask']) + + >>> data["image"].shape + torch.Size([32, 3, 256, 256]) + + To change the category of the dataset: + + >>> datamodule = MVTecLoco(category="pushpins") + + To change the image and batch size: + + >>> datamodule = MVTecLoco(image_size=(512, 512), train_batch_size=16, eval_batch_size=8) + + MVTec LOCO AD dataset provide an independent validation set with normal images only in the 'validation' folder. + If you would like to use a different validation set splitted from train or test set, + you can use the ``val_split_mode`` and ``val_split_ratio`` arguments to create a new validation set. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.FROM_TEST, val_split_ratio=0.1) + + This will subsample the test set by 10% and use it as the validation set. + If you would like to create a validation set synthetically that would + not change the test set, you can use the ``ValSplitMode.SYNTHETIC`` option. + + >>> datamodule = MVTecLoco(val_split_mode=ValSplitMode.SYNTHETIC, val_split_ratio=0.2) + """ + + def __init__( + self, + root: Path | str = "./datasets/MVTec_LOCO", + category: str = "breakfast_box", + image_size: int | tuple[int, int] = (256, 256), + center_crop: int | tuple[int, int] | None = None, + normalization: InputNormalizationMethod | str = InputNormalizationMethod.IMAGENET, + train_batch_size: int = 32, + eval_batch_size: int = 32, + num_workers: int = 8, + task: TaskType = TaskType.SEGMENTATION, + transform_config_train: str | A.Compose | None = None, + transform_config_eval: str | A.Compose | None = None, + test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR, + test_split_ratio: float = 0.2, + val_split_mode: ValSplitMode = ValSplitMode.FROM_DIR, + val_split_ratio: float = 0.5, + seed: int | None = None, + ) -> None: + super().__init__( + train_batch_size=train_batch_size, + eval_batch_size=eval_batch_size, + num_workers=num_workers, + test_split_mode=test_split_mode, + test_split_ratio=test_split_ratio, + val_split_mode=val_split_mode, + val_split_ratio=val_split_ratio, + seed=seed, + ) + + self.root = Path(root) + self.category = Path(category) + + transform_train = get_transforms( + config=transform_config_train, + image_size=image_size, + center_crop=center_crop, + normalization=InputNormalizationMethod(normalization), + ) + transform_eval = get_transforms( + config=transform_config_eval, + image_size=image_size, + center_crop=center_crop, + normalization=InputNormalizationMethod(normalization), + ) + + self.train_data = MVTecLocoDataset( + task=task, + transform=transform_train, + split=Split.TRAIN, + root=root, + category=category, + ) + self.val_data = MVTecLocoDataset( + task=task, + transform=transform_eval, + split=Split.VAL, + root=root, + category=category, + ) + self.test_data = MVTecLocoDataset( + task=task, + transform=transform_eval, + split=Split.TEST, + root=root, + category=category, + ) + + def prepare_data(self) -> None: + """Download the dataset if not available. + + This method checks if the specified dataset is available in the file system. + If not, it downloads and extracts the dataset into the appropriate directory. + + Example: + Assume the dataset is not available on the file system. + Here's how the directory structure looks before and after calling the + `prepare_data` method: + + Before: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + └── dataset2 + + Calling the method: + + .. code-block:: python + + >> datamodule = MVTecLoco(root="./datasets/MVTec_LOCO", category="breakfast_box") + >> datamodule.prepare_data() + + After: + + .. code-block:: bash + + $ tree datasets + datasets + ├── dataset1 + ├── dataset2 + └── MVTec_LOCO + ├── breakfast_box + ├── ... + └── splicing_connectors + """ + if (self.root / self.category).is_dir(): + logger.info("Found the dataset.") + else: + download_and_extract(self.root, DOWNLOAD_INFO) + + def _setup(self, _stage: str | None = None) -> None: + """Set up the datasets and perform dynamic subset splitting. + + This method overrides the parent class's method to also setup the val dataset. + The MVTec LOCO dataset provides an independent validation subset. + """ + assert self.train_data is not None + assert self.val_data is not None + assert self.test_data is not None + + self.train_data.setup() + self.val_data.setup() + self.test_data.setup() + + self._create_test_split() + self._create_val_split() From 9d89a68d962d2132ad4f01ab2be9f04edec258c9 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 19:12:07 +0900 Subject: [PATCH 04/63] add the default config file for mvtec loco ad dataset Signed-off-by: Willy Fitra Hendria --- src/configs/data/mvtec_loco.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 src/configs/data/mvtec_loco.yaml diff --git a/src/configs/data/mvtec_loco.yaml b/src/configs/data/mvtec_loco.yaml new file mode 100644 index 0000000000..2f60f00817 --- /dev/null +++ b/src/configs/data/mvtec_loco.yaml @@ -0,0 +1,18 @@ +class_path: anomalib.data.MVTecLoco +init_args: + root: ./datasets/MVTec_LOCO + category: breakfast_box + image_size: [256, 256] + center_crop: null + normalization: imagenet + train_batch_size: 32 + eval_batch_size: 32 + num_workers: 8 + task: SEGMENTATION + transform_config_train: null + transform_config_eval: null + test_split_mode: FROM_DIR + test_split_ratio: 0.2 + val_split_mode: FROM_DIR + val_split_ratio: 0.5 + seed: null From a312be02ae71fc5e1b3cbe8135f4715918e3f05c Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 19:17:05 +0900 Subject: [PATCH 05/63] update initialization files to include MVTec LOCO dataset Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/__init__.py | 3 ++- src/anomalib/data/image/__init__.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anomalib/data/__init__.py b/src/anomalib/data/__init__.py index 532fb5bb7c..854ae4c7c2 100644 --- a/src/anomalib/data/__init__.py +++ b/src/anomalib/data/__init__.py @@ -15,7 +15,7 @@ from .base import AnomalibDataModule, AnomalibDataset from .depth import DepthDataFormat, Folder3D, MVTec3D -from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa +from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa from .predict import PredictDataset from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat @@ -62,6 +62,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule: "Kolektor", "MVTec", "MVTec3D", + "MVTecLoco", "Avenue", "UCSDped", "ShanghaiTech", diff --git a/src/anomalib/data/image/__init__.py b/src/anomalib/data/image/__init__.py index 6b30ac9ac8..4db05b51df 100644 --- a/src/anomalib/data/image/__init__.py +++ b/src/anomalib/data/image/__init__.py @@ -13,6 +13,7 @@ from .folder import Folder from .kolektor import Kolektor from .mvtec import MVTec +from .mvtec_loco import MVTecLoco from .visa import Visa @@ -21,6 +22,7 @@ class ImageDataFormat(str, Enum): MVTEC = "mvtec" MVTEC_3D = "mvtec_3d" + MVTEC_LOCO = "mvtec_loco" BTECH = "btech" KOLEKTOR = "kolektor" FOLDER = "folder" @@ -28,4 +30,4 @@ class ImageDataFormat(str, Enum): VISA = "visa" -__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"] +__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"] From 37a6f6b0a54adf4a875006a6e56e068b411531e4 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 21:52:38 +0900 Subject: [PATCH 06/63] remove unnecessary Path conversion Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 47d6bc64e2..ffe3079236 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -105,7 +105,7 @@ def _merge_gt_mask( # Define the path for the new merged mask _, anomaly_dir, image_filename = mask_path.parts[-3:] - new_mask_path = root / Path(gt_merged_dir) / anomaly_dir / (image_filename + ".png") + new_mask_path = root / gt_merged_dir / anomaly_dir / (image_filename + ".png") # Create the necessary directories if they do not exist new_mask_path.parent.mkdir(parents=True, exist_ok=True) From a733bf76612b7f3bda890f798695e0e767eba78b Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 21:53:35 +0900 Subject: [PATCH 07/63] add mvtec_loco.yaml to the readme documentation of configs Signed-off-by: Willy Fitra Hendria --- src/configs/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/configs/README.md b/src/configs/README.md index bd3ecb618e..4334eea678 100644 --- a/src/configs/README.md +++ b/src/configs/README.md @@ -15,6 +15,7 @@ configs/ │ ├── kolektor.yaml │ ├── mvtec_3d.yaml │ ├── mvtec.yaml +│ ├── mvtec_loco.yaml │ ├── shanghaitec.yaml │ ├── ucsd_ped.yaml │ └── visa.yaml From 38ad1f4c595e766414952c1310e4a100979d3ae0 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 21:54:45 +0900 Subject: [PATCH 08/63] add dummy image generation for mvtec loco dataset Signed-off-by: Willy Fitra Hendria --- tests/helpers/data.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/helpers/data.py b/tests/helpers/data.py index 134a863312..a9a64b05e0 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -216,6 +216,7 @@ def __init__( data_format: DataFormat | str, root: Path | str | None = None, num_train: int = 5, + num_val: int = 5, num_test: int = 5, seed: int | None = None, ) -> None: @@ -230,6 +231,7 @@ def __init__( self.root = Path(mkdtemp() if root is None else root) self.dataset_root = self.root / self.data_format.value self.num_train = num_train + self.num_val = num_val self.num_test = num_test self.rng = np.random.default_rng(seed) @@ -298,6 +300,7 @@ def __init__( normal_category: str = "good", abnormal_category: str = "bad", num_train: int = 5, + num_val: int = 5, num_test: int = 5, image_shape: tuple[int, int] = (256, 256), num_channels: int = 3, @@ -308,6 +311,7 @@ def __init__( data_format=data_format, root=root, num_train=num_train, + num_val=num_val, num_test=num_test, seed=seed, ) @@ -389,6 +393,40 @@ def _generate_dummy_mvtec_3d_dataset(self) -> None: else: self.image_generator.save_image(filename=filename, image=image) + def _generate_dummy_mvtec_loco_dataset(self) -> None: + """Generates dummy MVTec LOCO AD dataset in a temporary directory using the same convention as MVTec LOCO AD.""" + # MVTec LOCO has multiple subcategories within the dataset. + dataset_category = "dummy" + + extension = ".png" + + # Create normal images. + for split in ("train", "validation", "test"): + path = self.dataset_root / dataset_category / split / "good" + if split == "train": + num_images = self.num_train + elif split == "val": + num_images = self.num_val + else: + num_images = self.num_test + for i in range(num_images): + label = LabelName.NORMAL + image_filename = path / f"{i:03}{extension}" + self.image_generator.generate_image(label=label, image_filename=image_filename) + + # Create abnormal test images and masks. + for abnormal_dir in ("logical_anomalies", "structural_anomalies"): + path = self.dataset_root / dataset_category / "test" / abnormal_dir + mask_path = self.dataset_root / dataset_category / "ground_truth" / abnormal_dir + + for i in range(self.num_test): + label = LabelName.ABNORMAL + image_filename = path / f"{i:03}{extension}" + # Here, only one ground-truth mask for each abnormal image is generated + # the structure follows the same convention as MVTec LOCO AD, e.g., image_filename/000.png + mask_filename = mask_path / f"{i:03}/000{extension}" + self.image_generator.generate_image(label, image_filename, mask_filename) + def _generate_dummy_kolektor_dataset(self) -> None: """Generate dummy Kolektor dataset in directory using the same convention as Kolektor AD.""" # Emulating the first two categories of Kolektor dataset. From 05e3558c53ab2e55ffcbe77a923454f83eb18908 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 21:55:06 +0900 Subject: [PATCH 09/63] add unit test for mvtec loco dataset Signed-off-by: Willy Fitra Hendria --- tests/unit/data/image/test_mvtec_loco.py | 32 ++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/unit/data/image/test_mvtec_loco.py diff --git a/tests/unit/data/image/test_mvtec_loco.py b/tests/unit/data/image/test_mvtec_loco.py new file mode 100644 index 0000000000..f58355c8e4 --- /dev/null +++ b/tests/unit/data/image/test_mvtec_loco.py @@ -0,0 +1,32 @@ +"""Unit Tests - MVTecLoco Datamodule.""" + +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pytest + +from anomalib import TaskType +from anomalib.data import MVTecLoco +from tests.unit.data.base.image import _TestAnomalibImageDatamodule + + +class TestMVTecLoco(_TestAnomalibImageDatamodule): + """MVTecLoco Datamodule Unit Tests.""" + + @pytest.fixture() + def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTecLoco: + """Create and return a MVTecLoco datamodule.""" + _datamodule = MVTecLoco( + root=dataset_path / "mvtec_loco", + category="dummy", + task=task_type, + image_size=256, + train_batch_size=4, + eval_batch_size=4, + ) + _datamodule.prepare_data() + _datamodule.setup() + + return _datamodule From 519780a9400163d2a58a02d6d01be2d688cf69aa Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sat, 13 Jan 2024 21:58:58 +0900 Subject: [PATCH 10/63] update changelog to include the addition of mvtec loco dataset Signed-off-by: Willy Fitra Hendria --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7af5aec234..c5c6a86022 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added +- Add support for MVTec LOCO AD dataset by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1635 + ### Changed - Changed default inference device to AUTO in https://github.com/openvinotoolkit/anomalib/pull/1534 From c3c6a394f88b90e1598d2a4fc97be4ea1aae4e55 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 14 Jan 2024 11:55:32 +0900 Subject: [PATCH 11/63] add mvtec loco dataset to the sphinx-based docs Signed-off-by: Willy Fitra Hendria --- docs/source/markdown/guides/reference/data/image/index.md | 8 ++++++++ .../markdown/guides/reference/data/image/mvtec_loco.md | 7 +++++++ 2 files changed, 15 insertions(+) create mode 100644 docs/source/markdown/guides/reference/data/image/mvtec_loco.md diff --git a/docs/source/markdown/guides/reference/data/image/index.md b/docs/source/markdown/guides/reference/data/image/index.md index 2525d0d914..ee75fa0daf 100644 --- a/docs/source/markdown/guides/reference/data/image/index.md +++ b/docs/source/markdown/guides/reference/data/image/index.md @@ -30,6 +30,13 @@ Learn more about Kolektor dataset. Learn more about MVTec 2D dataset ::: +:::{grid-item-card} MVTec LOCO +:link: ./mvtec_loco +:link-type: doc + +Learn more about MVTec LOCO dataset +::: + :::{grid-item-card} Visa :link: ./visa :link-type: doc @@ -47,5 +54,6 @@ Learn more about Visa dataset. ./folder ./kolektor ./mvtec +./mvtec_loco ./visa ``` diff --git a/docs/source/markdown/guides/reference/data/image/mvtec_loco.md b/docs/source/markdown/guides/reference/data/image/mvtec_loco.md new file mode 100644 index 0000000000..8854f27b80 --- /dev/null +++ b/docs/source/markdown/guides/reference/data/image/mvtec_loco.md @@ -0,0 +1,7 @@ +# MVTec LOCO Data + +```{eval-rst} +.. automodule:: anomalib.data.image.mvtec_loco + :members: + :show-inheritance: +``` From ca91b8ea1ce97807bffa0e0cf36d393fde53b660 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 14 Jan 2024 11:56:13 +0900 Subject: [PATCH 12/63] fix the malformed table Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index ffe3079236..553541451b 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -134,7 +134,7 @@ def make_mvtec_loco_dataset( +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ | | path | split | label | image_path | mask_path | label_index | +===+===============+=======+=========+===============+=======================================+=============+ - | 0 | datasets/name | test | defect | filename.png | path/to/merged_masks/filename.png | 1 | + | 0 | datasets/name | test | defect | filename.png | path/to/merged_masks/filename.png | 1 | +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ Note: the final image_path is converted to full path by combining it with the path, split, and label columns From 37f4dbce5abafde31c95f961dacdc43cfe6e4ff8 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 15 Jan 2024 08:29:31 +0900 Subject: [PATCH 13/63] binarize the masks and avoid the possibility of the merge_mask is None Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 553541451b..c207505c43 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -100,8 +100,11 @@ def _merge_gt_mask( merged_mask = np.zeros_like(mask) merged_mask = np.maximum(merged_mask, mask) + if merged_mask is None: + continue + # Binarize masks - merged_mask = np.minimum(merged_mask, 255) + merged_mask = np.where(merged_mask > 0, 255, 0) # Define the path for the new merged mask _, anomaly_dir, image_filename = mask_path.parts[-3:] From 3559a7cac8458b972ed835c3f94c6f98049e10f0 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 21 Jan 2024 23:38:06 +0900 Subject: [PATCH 14/63] Merge the masks using sum operation without binarization Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index c207505c43..dfe737cb3c 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -98,14 +98,11 @@ def _merge_gt_mask( mask = cv2.imread(str(mask_file), cv2.IMREAD_UNCHANGED) if merged_mask is None: merged_mask = np.zeros_like(mask) - merged_mask = np.maximum(merged_mask, mask) + merged_mask += mask if merged_mask is None: continue - # Binarize masks - merged_mask = np.where(merged_mask > 0, 255, 0) - # Define the path for the new merged mask _, anomaly_dir, image_filename = mask_path.parts[-3:] new_mask_path = root / gt_merged_dir / anomaly_dir / (image_filename + ".png") From 83539b429479b4e43dff34b754f266a3f68433a6 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:02:41 +0900 Subject: [PATCH 15/63] override getitem method to handle binarization and to add additional 'masks' item Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index dfe737cb3c..cbe4f08798 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -23,6 +23,7 @@ import albumentations as A # noqa: N812 import cv2 import numpy as np +import torch from pandas import DataFrame from anomalib import TaskType @@ -36,6 +37,8 @@ ValSplitMode, download_and_extract, get_transforms, + masks_to_boxes, + read_image, ) logger = logging.getLogger(__name__) @@ -295,6 +298,54 @@ def _setup(self) -> None: gt_merged_dir=GT_MERGED_DIR, ) + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: + """Get dataset item for the index ``index``. + + The implementation of this method is mostly based on the parent's class implementation, with some different as follows: + - Using 'torch.where' to make sure the 'mask' in the return item is binarized + - An additional 'masks' is added to pass the non-binary masks with original size for the sPRO metric calculation + Args: + index (int): Index to get the item. + + Returns: + dict[str, str | torch.Tensor]: Dict of image tensor during training. Otherwise, Dict containing image path, + target path, image tensor, label and transformed bounding box. + """ + image_path = self._samples.iloc[index].image_path + mask_path = self._samples.iloc[index].mask_path + label_index = self._samples.iloc[index].label_index + + image = read_image(image_path) + item = {"image_path": image_path, "label": label_index} + + if self.task == TaskType.CLASSIFICATION: + transformed = self.transform(image=image) + item["image"] = transformed["image"] + elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + # Only Anomalous (1) images have masks in anomaly datasets + # Therefore, create empty mask for Normal (0) images. + mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) + mask = mask.astype(np.single) + + transformed = self.transform(image=image, mask=mask) + + item["image"] = transformed["image"] + item["mask_path"] = mask_path + # transform and binarize the mask + item["mask"] = torch.where(transformed["mask"] > 0, torch.tensor(1.0), torch.tensor(0.0)) + # The non-binary masks with the original size for saturation based metrics calculation + item["masks"] = torch.tensor(mask) + + if self.task == TaskType.DETECTION: + # create boxes from masks for detection task + boxes, _ = masks_to_boxes(item["mask"]) + item["boxes"] = boxes[0] + else: + msg = f"Unknown task type: {self.task}" + raise ValueError(msg) + + return item + class MVTecLoco(AnomalibDataModule): """MVTec LOCO Datamodule. From ddd33c5c4e4a7362e20e4de7b01cad0452f23b26 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:04:38 +0900 Subject: [PATCH 16/63] Add saturation config to the datamodule Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 45 +++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 2 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index cbe4f08798..5444e8cf77 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -16,9 +16,11 @@ in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 """ +import json import logging from collections.abc import Sequence from pathlib import Path +from typing import Any import albumentations as A # noqa: N812 import cv2 @@ -63,6 +65,41 @@ GT_MERGED_DIR = "ground_truth_merged" +SATURATION_CONFIG_FILENAME = "defects_config.json" + + +def load_saturation_config(config_path: str | Path) -> dict[int, Any]: + """Load saturation configurations from a JSON file. + + Args: + config_path (str | Path): Path to the saturation configuration file. + + Returns: + Dict: A dictionary with pixel values as keys and the corresponding configurations as values. + + Example JSON format in the file: + [ + { + "defect_name": "1_additional_pushpin", + "pixel_value": 255, + "saturation_threshold": 6300, + "relative_saturation": false + }, + { + "defect_name": "2_additional_pushpins", + "pixel_value": 254, + "saturation_threshold": 12600, + "relative_saturation": false + }, + ... + ] + """ + with Path.open(Path(config_path)) as file: + configs = json.load(file) + + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + def _merge_gt_mask( root: str | Path, @@ -448,9 +485,10 @@ def __init__( val_split_ratio=val_split_ratio, seed=seed, ) - + self.saturation_config: dict[int, Any] self.root = Path(root) self.category = Path(category) + self.saturation_config = {} transform_train = get_transforms( config=transform_config_train, @@ -533,7 +571,7 @@ def prepare_data(self) -> None: download_and_extract(self.root, DOWNLOAD_INFO) def _setup(self, _stage: str | None = None) -> None: - """Set up the datasets and perform dynamic subset splitting. + """Set up the datasets, configs, and perform dynamic subset splitting. This method overrides the parent class's method to also setup the val dataset. The MVTec LOCO dataset provides an independent validation subset. @@ -548,3 +586,6 @@ def _setup(self, _stage: str | None = None) -> None: self._create_test_split() self._create_val_split() + + saturation_path = self.root / self.category / SATURATION_CONFIG_FILENAME + self.saturation_config = load_saturation_config(saturation_path) From 3af6eebcf89a474a7dfb216e193b671ca77593f0 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:10:50 +0900 Subject: [PATCH 17/63] Update the saturation config on the metrics based on the loaded config from the dataset Signed-off-by: Willy Fitra Hendria --- src/anomalib/callbacks/metrics.py | 8 ++++++-- src/anomalib/metrics/collection.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 572ed73099..fa963a82bb 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -67,8 +67,7 @@ def setup( pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. stage (str | None, optional): fit, validate, test or predict. Defaults to None. """ - del trainer, stage # These variables are not used. - + del stage # this variable is not used. image_metric_names = [] if self.image_metric_names is None else self.image_metric_names if isinstance(image_metric_names, str): image_metric_names = [image_metric_names] @@ -98,6 +97,8 @@ def setup( else: pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") self._set_threshold(pl_module) + if hasattr(trainer.datamodule, "saturation_config"): + self._set_saturation_config(pl_module, trainer.datamodule.saturation_config) def on_validation_epoch_start( self, @@ -172,6 +173,9 @@ def _set_threshold(self, pl_module: AnomalyModule) -> None: pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) + def _set_saturation_config(self, pl_module: AnomalyModule, saturation_config: dict[int, Any]) -> None: + pl_module.pixel_metrics.set_saturation_config(saturation_config) + def _update_metrics( self, image_metric: AnomalibMetricCollection, diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index 020aebd9e9..e44c807b12 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -3,16 +3,22 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import logging + from torchmetrics import MetricCollection +logger = logging.getLogger(__name__) + class AnomalibMetricCollection(MetricCollection): """Extends the MetricCollection class for use in the Anomalib pipeline.""" def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + self._saturation_config: dict self._update_called = False self._threshold = 0.5 + self._saturation_config = {} def set_threshold(self, threshold_value: float) -> None: """Update the threshold value for all metrics that have the threshold attribute.""" @@ -21,6 +27,18 @@ def set_threshold(self, threshold_value: float) -> None: if hasattr(metric, "threshold"): metric.threshold = threshold_value + def set_saturation_config(self, saturation_config: dict) -> None: + """Update the saturation config values for all metrics that have the saturation config attribute.""" + self._saturation_config = saturation_config + for name, metric in self.items(): + if hasattr(metric, "saturation_config"): + metric.saturation_config = saturation_config + else: + logger.warning( + f"Metric {name} may not be suitable for a dataset with the region separated" + "in multiple ground-truth masks.", + ) + def update(self, *args, **kwargs) -> None: """Add data to the metrics.""" super().update(*args, **kwargs) From 53f22971f3848b14140fedc687c7aa8573e39a8b Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:14:15 +0900 Subject: [PATCH 18/63] add masks as a keyword args to the update method of the AnomalibMetricCollection Signed-off-by: Willy Fitra Hendria --- src/anomalib/callbacks/metrics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index fa963a82bb..7dd381102d 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -186,7 +186,11 @@ def _update_metrics( image_metric.update(output["pred_scores"], output["label"].int()) if "mask" in output and "anomaly_maps" in output: pixel_metric.to(self.device) - pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + pixel_metric.update( + torch.squeeze(output["anomaly_maps"]), + torch.squeeze(output["mask"].int()), + masks=torch.squeeze(output["masks"]) if "masks" in output else None, + ) def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: if isinstance(output, dict): From 3d00d69f680cbd7e050cc2d75d6ea02b0c1b79f1 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:48:54 +0900 Subject: [PATCH 19/63] Shorten the comments to solve ruff issues Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 5444e8cf77..329e06333c 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -338,9 +338,9 @@ def _setup(self) -> None: def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: """Get dataset item for the index ``index``. - The implementation of this method is mostly based on the parent's class implementation, with some different as follows: + This method is mostly based on the super class implementation, with some different as follows: - Using 'torch.where' to make sure the 'mask' in the return item is binarized - - An additional 'masks' is added to pass the non-binary masks with original size for the sPRO metric calculation + - An additional 'masks' is added, the non-binary masks with original size for the sPRO metric calculation Args: index (int): Index to get the item. From f024e91c97ecc79fcb3569f3589ae555afb78494 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:51:31 +0900 Subject: [PATCH 20/63] Add sPro metric implementation Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/__init__.py | 2 + src/anomalib/metrics/spro.py | 142 +++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+) create mode 100644 src/anomalib/metrics/spro.py diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 544e6fbf6f..0f43a83ec6 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -19,6 +19,7 @@ from .collection import AnomalibMetricCollection from .min_max import MinMax from .pro import PRO +from .spro import sPRO from .threshold import F1AdaptiveThreshold, ManualThreshold __all__ = [ @@ -30,6 +31,7 @@ "ManualThreshold", "MinMax", "PRO", + "sPRO", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py new file mode 100644 index 0000000000..d8410e57a7 --- /dev/null +++ b/src/anomalib/metrics/spro.py @@ -0,0 +1,142 @@ +"""Implementation of sPRO metric based on TorchMetrics.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging +import torch +from torchmetrics import Metric +from torchmetrics.functional import recall +from torchmetrics.utilities.data import dim_zero_cat + +from anomalib.utils.cv import connected_components_cpu, connected_components_gpu + +logger = logging.getLogger(__name__) + +class sPRO(Metric): + """Saturated Per-Region Overlap (sPRO) Score. + + This metric computes the macro average of the saturated per-region overlap between the + predicted anomaly masks and the ground truth masks. + + Args: + threshold (float): Threshold used to binarize the predictions. + Defaults to ``0.5``. + kwargs: Additional arguments to the TorchMetrics base class. + + Example: + Import the metric from the package: + + >>> import torch + >>> from anomalib.metrics import sPRO + + Create random ``preds`` and ``labels`` tensors: + + >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) + >>> preds = torch.rand_like(labels) + + Compute the sPRO score for labels and preds: + + >>> spro = sPRO(threshold=0.5) + >>> spro.update(preds, _, labels) + >>> spro.compute() + tensor(0.6333) + + .. note:: + Note that the example above shows random predictions and labels. + Therefore, the sPRO score above may not be reproducible. + + """ + + targets: list[torch.Tensor] + preds: list[torch.Tensor] + saturation_config: dict + + def __init__(self, threshold: float = 0.5, **kwargs) -> None: + super().__init__(**kwargs) + self.threshold = threshold + self.saturation_config = {} + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("targets", default=[], dist_reduce_fx="cat") + + def update(self, predictions: torch.Tensor, _: torch.Tensor, masks: torch.Tensor) -> None: + """Compute the sPRO score for the current batch. + + Args: + predictions (torch.Tensor): Predicted anomaly masks + _ (torch.Tensor): Unused argument, but needed for different metrics within the same AnomalibMetricCollection + masks (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width + + Example: + To update the metric state for the current batch, use the ``update`` method: + + >>> spro.update(preds, _, labels) + """ + assert masks is not None + self.targets.append(masks) + self.preds.append(predictions) + + def compute(self) -> torch.Tensor: + """Compute the macro average of the sPRO score across all masks in all batches. + + Example: + To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: + + >>> spro.compute() + tensor(0.5433) + """ + targets = dim_zero_cat(self.targets) + preds = dim_zero_cat(self.preds) + return spro_score(preds, targets, threshold=self.threshold, saturation_config=self.saturation_config) + +def spro_score(predictions: torch.Tensor, targets: torch.Tensor, + threshold: float = 0.5, saturation_config: dict = {}) -> torch.Tensor: + """Calculate the sPRO score for a batch of predictions. + + Args: + predictions (torch.Tensor): Predicted anomaly masks + targets: (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width + threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + saturation_config (dict): Saturations configuration for each label (pixel value) as the keys + + Returns: + torch.Tensor: Scalar value representing the average sPRO score for the input batch. + """ + predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets.shape[1:]) + + # Apply threshold to binary predictions + if predictions.dtype == torch.float: + predictions = predictions > threshold + + score = torch.tensor(0.0) + + # Iterate for each image in the batch + for i, target in enumerate(targets): + unique_labels = torch.unique(target) + + # Iterate for each ground-truth mask per image + for label in unique_labels[1:]: + # Calculate true positive + target_per_label = target == label + true_pos = torch.sum(predictions[i] & target_per_label) + + # Calculate the areas of the ground-truth + defect_areas = torch.sum(target_per_label) + + if len(saturation_config) > 0: + # Adjust saturation threshold based on configuration + saturation_per_label = saturation_config[label.int().item()] + saturation_threshold = torch.minimum(saturation_per_label["saturation_threshold"], defect_areas) + if saturation_per_label["relative_saturation"]: + saturation_threshold *= defect_areas + else: + # Handle case when saturation_config is empty + logger.warning("The saturation_config attribute is empty, the threshold is set to the defect areas." + "This is equivalent to PRO metric but with the 'region' are separated by mask files") + saturation_threshold = defect_areas + + # Update score with minimum of true_pos/saturation_threshold and 1.0 + score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) + + # Calculate the mean score + return torch.mean(score) From 9b8ca3bf3a87af8a1165097f40e1ff8db83182e1 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 00:57:21 +0900 Subject: [PATCH 21/63] Change the saturation threshold to tensor Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/spro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index d8410e57a7..77b7e9b1cb 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -126,7 +126,7 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, if len(saturation_config) > 0: # Adjust saturation threshold based on configuration saturation_per_label = saturation_config[label.int().item()] - saturation_threshold = torch.minimum(saturation_per_label["saturation_threshold"], defect_areas) + saturation_threshold = torch.minimum(torch.tensor(saturation_per_label["saturation_threshold"]), defect_areas) if saturation_per_label["relative_saturation"]: saturation_threshold *= defect_areas else: From 29aaf468d12fc0eea78feb53bd029a3e8368b00a Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Mon, 22 Jan 2024 08:12:31 +0900 Subject: [PATCH 22/63] Handle case with only background/normal images in scoring Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/spro.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 77b7e9b1cb..b383805969 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -109,7 +109,7 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, predictions = predictions > threshold score = torch.tensor(0.0) - + m = 0 # Iterate for each image in the batch for i, target in enumerate(targets): unique_labels = torch.unique(target) @@ -137,6 +137,9 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) + m += 1 - # Calculate the mean score - return torch.mean(score) + # If there are only backgrounds + if m == 0: + return torch.tensor(1.0) + return score / m From f43275387a6c46749400f5f3d4ac0ef33909ed2f Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 07:57:39 +0900 Subject: [PATCH 23/63] rename spro metric and change the default value of saturation_config to None Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 2 +- src/anomalib/metrics/__init__.py | 4 +- src/anomalib/metrics/spro.py | 61 ++++++++++++++++----------- 3 files changed, 40 insertions(+), 27 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 329e06333c..540e5e0c88 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -340,7 +340,7 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: This method is mostly based on the super class implementation, with some different as follows: - Using 'torch.where' to make sure the 'mask' in the return item is binarized - - An additional 'masks' is added, the non-binary masks with original size for the sPRO metric calculation + - An additional 'masks' is added, the non-binary masks with original size for the SPRO metric calculation Args: index (int): Index to get the item. diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 0f43a83ec6..b0106b6064 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -19,7 +19,7 @@ from .collection import AnomalibMetricCollection from .min_max import MinMax from .pro import PRO -from .spro import sPRO +from .spro import SPRO from .threshold import F1AdaptiveThreshold, ManualThreshold __all__ = [ @@ -31,7 +31,7 @@ "ManualThreshold", "MinMax", "PRO", - "sPRO", + "SPRO", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index b383805969..ac8ae32ceb 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -1,20 +1,19 @@ -"""Implementation of sPRO metric based on TorchMetrics.""" +"""Implementation of SPRO metric based on TorchMetrics.""" # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import logging + import torch from torchmetrics import Metric -from torchmetrics.functional import recall from torchmetrics.utilities.data import dim_zero_cat -from anomalib.utils.cv import connected_components_cpu, connected_components_gpu - logger = logging.getLogger(__name__) -class sPRO(Metric): - """Saturated Per-Region Overlap (sPRO) Score. + +class SPRO(Metric): + """Saturated Per-Region Overlap (SPRO) Score. This metric computes the macro average of the saturated per-region overlap between the predicted anomaly masks and the ground truth masks. @@ -22,45 +21,47 @@ class sPRO(Metric): Args: threshold (float): Threshold used to binarize the predictions. Defaults to ``0.5``. + saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. kwargs: Additional arguments to the TorchMetrics base class. Example: Import the metric from the package: >>> import torch - >>> from anomalib.metrics import sPRO + >>> from anomalib.metrics import SPRO Create random ``preds`` and ``labels`` tensors: >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) >>> preds = torch.rand_like(labels) - Compute the sPRO score for labels and preds: + Compute the SPRO score for labels and preds: - >>> spro = sPRO(threshold=0.5) + >>> spro = SPRO(threshold=0.5) >>> spro.update(preds, _, labels) >>> spro.compute() tensor(0.6333) .. note:: Note that the example above shows random predictions and labels. - Therefore, the sPRO score above may not be reproducible. + Therefore, the SPRO score above may not be reproducible. """ targets: list[torch.Tensor] preds: list[torch.Tensor] - saturation_config: dict - def __init__(self, threshold: float = 0.5, **kwargs) -> None: + def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None, **kwargs) -> None: super().__init__(**kwargs) self.threshold = threshold - self.saturation_config = {} + self.saturation_config = saturation_config self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("targets", default=[], dist_reduce_fx="cat") def update(self, predictions: torch.Tensor, _: torch.Tensor, masks: torch.Tensor) -> None: - """Compute the sPRO score for the current batch. + """Compute the SPRO score for the current batch. Args: predictions (torch.Tensor): Predicted anomaly masks @@ -77,7 +78,7 @@ def update(self, predictions: torch.Tensor, _: torch.Tensor, masks: torch.Tensor self.preds.append(predictions) def compute(self) -> torch.Tensor: - """Compute the macro average of the sPRO score across all masks in all batches. + """Compute the macro average of the SPRO score across all masks in all batches. Example: To compute the metric based on the state accumulated from multiple batches, use the ``compute`` method: @@ -89,18 +90,25 @@ def compute(self) -> torch.Tensor: preds = dim_zero_cat(self.preds) return spro_score(preds, targets, threshold=self.threshold, saturation_config=self.saturation_config) -def spro_score(predictions: torch.Tensor, targets: torch.Tensor, - threshold: float = 0.5, saturation_config: dict = {}) -> torch.Tensor: - """Calculate the sPRO score for a batch of predictions. + +def spro_score( + predictions: torch.Tensor, + targets: torch.Tensor, + threshold: float = 0.5, + saturation_config: dict | None = None, +) -> torch.Tensor: + """Calculate the SPRO score for a batch of predictions. Args: predictions (torch.Tensor): Predicted anomaly masks targets: (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. - saturation_config (dict): Saturations configuration for each label (pixel value) as the keys + saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are + separated by mask files. Returns: - torch.Tensor: Scalar value representing the average sPRO score for the input batch. + torch.Tensor: Scalar value representing the average SPRO score for the input batch. """ predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets.shape[1:]) @@ -123,16 +131,21 @@ def spro_score(predictions: torch.Tensor, targets: torch.Tensor, # Calculate the areas of the ground-truth defect_areas = torch.sum(target_per_label) - if len(saturation_config) > 0: + if saturation_config is not None: # Adjust saturation threshold based on configuration saturation_per_label = saturation_config[label.int().item()] - saturation_threshold = torch.minimum(torch.tensor(saturation_per_label["saturation_threshold"]), defect_areas) + saturation_threshold = torch.minimum( + torch.tensor(saturation_per_label["saturation_threshold"]), + defect_areas, + ) if saturation_per_label["relative_saturation"]: saturation_threshold *= defect_areas else: # Handle case when saturation_config is empty - logger.warning("The saturation_config attribute is empty, the threshold is set to the defect areas." - "This is equivalent to PRO metric but with the 'region' are separated by mask files") + logger.warning( + "The saturation_config attribute is empty, the threshold is set to the defect areas." + "This is equivalent to PRO metric but with the 'region' are separated by mask files", + ) saturation_threshold = defect_areas # Update score with minimum of true_pos/saturation_threshold and 1.0 From 7d348d8c6f340bdba85de1d2991afef52dc6b053 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 07:59:01 +0900 Subject: [PATCH 24/63] add unit test for spro metric Signed-off-by: Willy Fitra Hendria --- tests/unit/metrics/test_spro.py | 70 +++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/unit/metrics/test_spro.py diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py new file mode 100644 index 0000000000..1538645ce3 --- /dev/null +++ b/tests/unit/metrics/test_spro.py @@ -0,0 +1,70 @@ +"""Test SPRO metric.""" + +# Copyright (C) 2023-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from anomalib.metrics.spro import SPRO + +def test_spro() -> None: + """Checks if SPRO metric computes the score utilizing the given saturation configs""" + + saturation_config = { + 255: { + 'saturation_threshold': 10, + 'relative_saturation': False + }, + 254: { + 'saturation_threshold': 0.5, + 'relative_saturation': True + } + } + + masks = torch.Tensor( + [ + [ + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + [ + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + ] + + ], + ) + + masks[0] *= 255 + masks[1] *= 254 + # merge the multi-mask and add batch dim + merged_masks = (masks[0] + masks[1]).unsqueeze(0) + + preds = (torch.arange(8) / 10) + 0.05 + # metrics receive squeezed predictions (N, H, W) + preds = preds.unsqueeze(1).repeat(1, 5).view(1, 8, 5) + + thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0] + targets = [1.0, 1.0, 1.0, 0.75, 0.0, 0.0] + targets_wo_saturation = [1.0, 0.625, 0.5, 0.375, 0.0, 0.0] + for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True): + # test using saturation_cofig + spro = SPRO(threshold=threshold, saturation_config=saturation_config) + spro.update(preds, None, merged_masks) + assert spro.compute() == target + + # test without saturation_config + spro_wo_saturaton = SPRO(threshold=threshold) + spro_wo_saturaton.update(preds, None, merged_masks) + assert spro_wo_saturaton.compute() == target_wo_saturation \ No newline at end of file From 5597bfc298b4cec2bf020a24a0dc879c57c9ec32 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 08:47:46 +0900 Subject: [PATCH 25/63] fix pre-commit issues Signed-off-by: Willy Fitra Hendria --- tests/unit/metrics/test_spro.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py index 1538645ce3..36847d9bd0 100644 --- a/tests/unit/metrics/test_spro.py +++ b/tests/unit/metrics/test_spro.py @@ -4,22 +4,23 @@ # SPDX-License-Identifier: Apache-2.0 import torch + from anomalib.metrics.spro import SPRO + def test_spro() -> None: - """Checks if SPRO metric computes the score utilizing the given saturation configs""" - + """Checks if SPRO metric computes the score utilizing the given saturation configs.""" saturation_config = { 255: { - 'saturation_threshold': 10, - 'relative_saturation': False - }, + "saturation_threshold": 10, + "relative_saturation": False, + }, 254: { - 'saturation_threshold': 0.5, - 'relative_saturation': True - } - } - + "saturation_threshold": 0.5, + "relative_saturation": True, + }, + } + masks = torch.Tensor( [ [ @@ -41,8 +42,7 @@ def test_spro() -> None: [0, 0, 0, 0, 0], [1, 1, 1, 1, 1], [0, 0, 0, 0, 0], - ] - + ], ], ) @@ -67,4 +67,4 @@ def test_spro() -> None: # test without saturation_config spro_wo_saturaton = SPRO(threshold=threshold) spro_wo_saturaton.update(preds, None, merged_masks) - assert spro_wo_saturaton.compute() == target_wo_saturation \ No newline at end of file + assert spro_wo_saturaton.compute() == target_wo_saturation From 7b02863cfaa03c17492d64387947c1039ff41ab1 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 08:53:47 +0900 Subject: [PATCH 26/63] handle file not found error when loading saturation config Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 540e5e0c88..60e97f7eac 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -68,14 +68,15 @@ SATURATION_CONFIG_FILENAME = "defects_config.json" -def load_saturation_config(config_path: str | Path) -> dict[int, Any]: +def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: """Load saturation configurations from a JSON file. Args: config_path (str | Path): Path to the saturation configuration file. Returns: - Dict: A dictionary with pixel values as keys and the corresponding configurations as values. + Dict |None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. Example JSON format in the file: [ @@ -94,11 +95,15 @@ def load_saturation_config(config_path: str | Path) -> dict[int, Any]: ... ] """ - with Path.open(Path(config_path)) as file: - configs = json.load(file) + try: + with Path.open(Path(config_path)) as file: + configs = json.load(file) - # Create a dictionary with pixel values as keys - return {conf["pixel_value"]: conf for conf in configs} + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + except FileNotFoundError: + logger.warning("The saturation config file %s does not exist. Returning None.", config_path) + return None def _merge_gt_mask( @@ -485,7 +490,7 @@ def __init__( val_split_ratio=val_split_ratio, seed=seed, ) - self.saturation_config: dict[int, Any] + self.saturation_config: dict[int, Any] | None self.root = Path(root) self.category = Path(category) self.saturation_config = {} From 6048a086302af9de93a7a324499fcd6132a05fe1 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 21:22:46 +0900 Subject: [PATCH 27/63] validate path before processing Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 60e97f7eac..0e9717478f 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -41,6 +41,7 @@ get_transforms, masks_to_boxes, read_image, + validate_path, ) logger = logging.getLogger(__name__) @@ -96,9 +97,9 @@ def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: ] """ try: - with Path.open(Path(config_path)) as file: + config_path = validate_path(config_path) + with Path.open(config_path) as file: configs = json.load(file) - # Create a dictionary with pixel values as keys return {conf["pixel_value"]: conf for conf in configs} except FileNotFoundError: @@ -205,7 +206,7 @@ def make_mvtec_loco_dataset( >>> path = root / category >>> samples = make_mvtec_loco_dataset(path, split='test') """ - root = Path(root) + root = validate_path(root) gt_merged_dir = Path(gt_merged_dir) # assert the directory to store the merged ground-truth masks is different than the original gt directory From e23734785e2c81c50a96597dbe3d05f87c9c1339 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Thu, 1 Feb 2024 21:24:41 +0900 Subject: [PATCH 28/63] update changelog with new PR Signed-off-by: Willy Fitra Hendria --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5c6a86022..5c42cd9e2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added -- Add support for MVTec LOCO AD dataset by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1635 +- Add support for MVTec LOCO AD dataset and sPRO metric by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1686 ### Changed From f9b67b83033b567283b037e486ba8bcb6ea981f2 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Tue, 6 Feb 2024 20:27:46 +0900 Subject: [PATCH 29/63] Update src/anomalib/data/image/mvtec_loco.py Co-authored-by: Samet Akcay Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 0e9717478f..8b71ce5b91 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -76,8 +76,8 @@ def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: config_path (str | Path): Path to the saturation configuration file. Returns: - Dict |None: A dictionary with pixel values as keys and the corresponding configurations as values. - Return None if the config file is not found. + Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. Example JSON format in the file: [ From 63bf8def6a85cd5cf5c57b5d25ed900001564d58 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Tue, 6 Feb 2024 20:37:52 +0900 Subject: [PATCH 30/63] Update src/anomalib/data/image/mvtec_loco.py Co-authored-by: Samet Akcay Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 8b71ce5b91..25373493d9 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -330,7 +330,7 @@ def __init__( ) -> None: super().__init__(task=task, transform=transform) - self.root_category = Path(root) / Path(category) + self.root_category = Path(root) / category self.split = split def _setup(self) -> None: From 0310e1b8f0f418bd44c9892882d2b1225fca8227 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Tue, 6 Feb 2024 20:38:48 +0900 Subject: [PATCH 31/63] Update tests/helpers/data.py Co-authored-by: Samet Akcay Signed-off-by: Willy Fitra Hendria --- tests/helpers/data.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/helpers/data.py b/tests/helpers/data.py index a9a64b05e0..be58c5ba64 100644 --- a/tests/helpers/data.py +++ b/tests/helpers/data.py @@ -409,6 +409,7 @@ def _generate_dummy_mvtec_loco_dataset(self) -> None: num_images = self.num_val else: num_images = self.num_test + for i in range(num_images): label = LabelName.NORMAL image_filename = path / f"{i:03}{extension}" From 4eb2ec3d9a5c3b35b943afdae431e83be7e5a42d Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:11:14 +0900 Subject: [PATCH 32/63] change assert to raise error Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 31 +++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 25373493d9..a8fccb185b 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -257,15 +257,24 @@ def make_mvtec_loco_dataset( "mask_path", ] = mask_samples.image_path.to_numpy() - # assert that the right mask files are associated with the right test images + # validate that the right mask files are associated with the right test images if len(samples.loc[samples.label_index == LabelName.ABNORMAL]): - assert ( - samples.loc[samples.label_index == LabelName.ABNORMAL] - .apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1) - .all() - ), f"Mismatch between anomalous images and ground truth masks. Make sure the mask files in '{gt_merged_dir!s}' \ - folder follow the same naming convention as the anomalous images in the dataset (e.g. image: \ - '000.png', mask: '000.png')." + image_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["image_path"].apply(lambda x: Path(x).stem) + mask_parent_stems = samples.loc[samples.label_index == LabelName.ABNORMAL]["mask_path"].apply( + lambda x: {Path(mask_path).parent.stem for mask_path in x}, + ) + + if not all( + next(iter(mask_stems)) == image_stem + for image_stem, mask_stems in zip(image_stems, mask_parent_stems, strict=True) + ): + error_message = ( + "Mismatch between anomalous images and ground truth masks. " + "Make sure the parent folder of the mask files in 'ground_truth' folder " + "follows the same naming convention as the anomalous images in the dataset " + "(e.g., image: '005.png', mask: '005/000.png')." + ) + raise ValueError(error_message) if split: samples = samples[samples.split == split].reset_index(drop=True) @@ -582,9 +591,9 @@ def _setup(self, _stage: str | None = None) -> None: This method overrides the parent class's method to also setup the val dataset. The MVTec LOCO dataset provides an independent validation subset. """ - assert self.train_data is not None - assert self.val_data is not None - assert self.test_data is not None + if self.train_data is None or self.val_data is None or self.test_data is None: + error_message = "train_data, val_data, and test_data must all be provided" + raise ValueError(error_message) self.train_data.setup() self.val_data.setup() From f3bccb8f21468236315bfb0c6b21198b2f8f0f43 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:15:27 +0900 Subject: [PATCH 33/63] return list of masks instead of merging the multiple masks from the dataset Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/image/mvtec_loco.py | 136 ++++++++------------------ 1 file changed, 42 insertions(+), 94 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index a8fccb185b..65dbbbb726 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -64,8 +64,6 @@ "splicing_connectors", ) -GT_MERGED_DIR = "ground_truth_merged" - SATURATION_CONFIG_FILENAME = "defects_config.json" @@ -107,64 +105,10 @@ def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: return None -def _merge_gt_mask( - root: str | Path, - extensions: Sequence[str] = IMG_EXTENSIONS, - gt_merged_dir: str | Path = GT_MERGED_DIR, -) -> None: - """Merges ground truth masks within specified directories and saves the merged masks. - - Args: - root (str | Path): Root directory containing the 'ground_truth' folder. - extensions (Sequence[str]): Allowed file extensions for ground truth masks. - Default is IMG_EXTENSIONS. - gt_merged_dir (str | Path]): Directory where merged masks will be saved. - Default is GT_MERGED_DIR. - - Returns: - None - - Example: - >>> _merge_gt_mask('path/to/breakfast_box/') - - This function reads ground truth masks from the specified directories, merges them into - a single mask for each corresponding images (e.g. merge 059/000.png and 059/001.png into 059.png), - and saves the merged masks in the default GT_MERGED_DIR directory. - - Note: The merged masks are saved with the same filename structure as the corresponding anomalous image files. - """ - root = Path(root) - gt_mask_paths = {f.parent for f in root.glob("ground_truth/**/*") if f.suffix in extensions} - - for mask_path in gt_mask_paths: - # Merge each mask inside mask_path into a single mask - merged_mask = None - for mask_file in mask_path.glob("*"): - if mask_file.suffix in extensions: - mask = cv2.imread(str(mask_file), cv2.IMREAD_UNCHANGED) - if merged_mask is None: - merged_mask = np.zeros_like(mask) - merged_mask += mask - - if merged_mask is None: - continue - - # Define the path for the new merged mask - _, anomaly_dir, image_filename = mask_path.parts[-3:] - new_mask_path = root / gt_merged_dir / anomaly_dir / (image_filename + ".png") - - # Create the necessary directories if they do not exist - new_mask_path.parent.mkdir(parents=True, exist_ok=True) - - # Save the merged mask - cv2.imwrite(str(new_mask_path), merged_mask) - - def make_mvtec_loco_dataset( root: str | Path, split: str | Split | None = None, extensions: Sequence[str] = IMG_EXTENSIONS, - gt_merged_dir: str | Path = GT_MERGED_DIR, ) -> DataFrame: """Create MVTec LOCO AD samples by parsing the original MVTec LOCO AD data file structure. @@ -174,17 +118,13 @@ def make_mvtec_loco_dataset( where there can be multiple ground-truth masks for the corresponding anomalous images. - This function first merges the multiple ground-truth-masks by executing _merge_gt_mask(), - it then creates a dataframe to store the parsed information based on the following format: + This function creates a dataframe to store the parsed information based on the following format: - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ - | | path | split | label | image_path | mask_path | label_index | + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ + | | path | split | label | image_path | mask_path | label_index | +===+===============+=======+=========+===============+=======================================+=============+ - | 0 | datasets/name | test | defect | filename.png | path/to/merged_masks/filename.png | 1 | - +---+---------------+-------+---------+---------------+---------------------------------------+-------------+ - - Note: the final image_path is converted to full path by combining it with the path, split, and label columns - Example, datasets/name/test/defect/filename.png + | 0 | datasets/name | test | defect | path/to/image/file.png | [path/to/masks/file.png] | 1 | + +---+---------------+-------+---------+-------------------------+-----------------------------+-------------+ Args: root (str | Path): Path to dataset @@ -192,8 +132,6 @@ def make_mvtec_loco_dataset( Defaults to ``None``. extensions (Sequence[str]): List of file extensions to be included in the dataset. Defaults to ``None``. - gt_merged_dir (str | Path]): Directory where merged masks will be saved. - Default is GT_MERGED_DIR. Returns: DataFrame: an output dataframe containing the samples of the dataset. @@ -207,36 +145,27 @@ def make_mvtec_loco_dataset( >>> samples = make_mvtec_loco_dataset(path, split='test') """ root = validate_path(root) - gt_merged_dir = Path(gt_merged_dir) - - # assert the directory to store the merged ground-truth masks is different than the original gt directory - assert gt_merged_dir != "ground_truth" - - # Merge ground-truth masks for each corresponding images and store into the 'gt_merged_dir' folder - if (root / gt_merged_dir).is_dir(): - logger.info(f"Found the directory of the merged ground-truth masks: {root / gt_merged_dir!s}") - else: - logger.info("Merging the multiple ground-truth masks for each corresponding images.") - _merge_gt_mask(root, gt_merged_dir=gt_merged_dir) # Retrieve the image and mask files samples_list = [] for f in root.glob("**/*"): if f.suffix in extensions: parts = f.parts - # Ignore original 'ground_truth' folder because the 'gt_merged_dir' is used instead + # 'ground_truth' and non 'ground_truth' path have a different structure if "ground_truth" not in parts: - split_folder, label_folder, image_path = parts[-3:] - samples_list.append((str(root), split_folder, label_folder, image_path)) + split_folder, label_folder, image_file = parts[-3:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, "", image_path)) + else: + split_folder, label_folder, image_folder, image_file = parts[-4:] + image_path = f"{root}/{split_folder}/{label_folder}/{image_folder}/{image_file}" + samples_list.append((str(root), split_folder, label_folder, image_folder, image_path)) if not samples_list: msg = f"Found 0 images in {root}" raise RuntimeError(msg) - samples = DataFrame(samples_list, columns=["path", "split", "label", "image_path"]) - - # Modify image_path column by converting to absolute path - samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path + samples = DataFrame(samples_list, columns=["path", "split", "label", "image_folder", "image_path"]) # Replace validation to Split.VAL.value in the split column samples["split"] = samples["split"].replace("validation", Split.VAL.value) @@ -246,16 +175,24 @@ def make_mvtec_loco_dataset( samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL samples.label_index = samples.label_index.astype(int) - # separate masks from samples - mask_samples = samples.loc[samples.split == str(gt_merged_dir)].sort_values(by="image_path", ignore_index=True) - samples = samples[samples.split != str(gt_merged_dir)].sort_values(by="image_path", ignore_index=True) + # separate ground-truth masks from samples + mask_samples = samples.loc[samples.split == "ground_truth"].sort_values(by="image_path", ignore_index=True) + samples = samples[samples.split != "ground_truth"].sort_values(by="image_path", ignore_index=True) + + # Group masks and aggregate the path into a list + mask_samples = ( + mask_samples.groupby(["path", "split", "label", "image_folder"])["image_path"] + .agg(list) + .reset_index() + .rename(columns={"image_path": "mask_path"}) + ) # assign mask paths to anomalous test images samples["mask_path"] = "" samples.loc[ (samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), "mask_path", - ] = mask_samples.image_path.to_numpy() + ] = mask_samples.mask_path.to_numpy() # validate that the right mask files are associated with the right test images if len(samples.loc[samples.label_index == LabelName.ABNORMAL]): @@ -347,7 +284,6 @@ def _setup(self) -> None: self.root_category, split=self.split, extensions=IMG_EXTENSIONS, - gt_merged_dir=GT_MERGED_DIR, ) def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: @@ -376,17 +312,29 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): # Only Anomalous (1) images have masks in anomaly datasets # Therefore, create empty mask for Normal (0) images. - mask = np.zeros(shape=image.shape[:2]) if label_index == 0 else cv2.imread(mask_path, flags=0) + if label_index == LabelName.ABNORMAL: + # Read and stack masks + masks = np.stack([cv2.imread(mask_path, flags=0) for mask_path in mask_path]) + + # Merge masks and create binary mask + mask = np.max(masks, axis=0) + mask = np.where(mask > 0, 1, 0) + else: + # create empty mask for Normal (0) images. + mask = np.zeros(shape=image.shape[:2]) + masks = np.expand_dims(mask, axis=0) + mask = mask.astype(np.single) + masks = masks.astype(np.single) transformed = self.transform(image=image, mask=mask) item["image"] = transformed["image"] item["mask_path"] = mask_path # transform and binarize the mask - item["mask"] = torch.where(transformed["mask"] > 0, torch.tensor(1.0), torch.tensor(0.0)) - # The non-binary masks with the original size for saturation based metrics calculation - item["masks"] = torch.tensor(mask) + item["mask"] = transformed["mask"] + # List of masks with the original size for saturation based metrics calculation + item["masks"] = torch.tensor(masks) if self.task == TaskType.DETECTION: # create boxes from masks for detection task From ac1ecb1904f27692903d68f7168dd0dd7534726a Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:16:50 +0900 Subject: [PATCH 34/63] collate masks as a list of tensor to avoid stack error due to unequal number of masks Signed-off-by: Willy Fitra Hendria --- src/anomalib/data/base/datamodule.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 13975b047c..52483d9dad 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -26,7 +26,9 @@ def collate_fn(batch: list) -> dict[str, Any]: """Collate bounding boxes as lists. - Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries. + Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` is exist, + the `mask_path` is also collated as a list since each element in the batch could be unequal. + For all other entries, the default collate function is used. Args: batch (List): list of items in the batch where len(batch) is equal to the batch size. @@ -40,6 +42,10 @@ def collate_fn(batch: list) -> dict[str, Any]: if "boxes" in elem: # collate boxes as list out_dict["boxes"] = [item.pop("boxes") for item in batch] + if "masks" in elem: + # collate masks and mask_path as list + out_dict["masks"] = [item.pop("masks") for item in batch] + out_dict["mask_path"] = [item.pop("mask_path") for item in batch] # collate other data normally out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem}) return out_dict From 3bbc7507ea29f537118728b579e72e04b998f78f Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:20:22 +0900 Subject: [PATCH 35/63] update spro to handle list of masks and remove the _ args Signed-off-by: Willy Fitra Hendria --- src/anomalib/callbacks/metrics.py | 25 ++++++++--- src/anomalib/metrics/collection.py | 7 +++- src/anomalib/metrics/spro.py | 67 +++++++++++++++++------------- 3 files changed, 61 insertions(+), 38 deletions(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 7dd381102d..9b8b0e3bda 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -13,7 +13,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from anomalib import TaskType -from anomalib.metrics import AnomalibMetricCollection, create_metric_collection +from anomalib.metrics import SPRO, AnomalibMetricCollection, create_metric_collection from anomalib.models import AnomalyModule logger = logging.getLogger(__name__) @@ -186,11 +186,10 @@ def _update_metrics( image_metric.update(output["pred_scores"], output["label"].int()) if "mask" in output and "anomaly_maps" in output: pixel_metric.to(self.device) - pixel_metric.update( - torch.squeeze(output["anomaly_maps"]), - torch.squeeze(output["mask"].int()), - masks=torch.squeeze(output["masks"]) if "masks" in output else None, - ) + if "masks" in output: + self._update_pixel_metrics(pixel_metric, output) + else: + pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: if isinstance(output, dict): @@ -198,8 +197,22 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any output[key] = self._outputs_to_device(value) elif isinstance(output, torch.Tensor): output = output.to(self.device) + elif isinstance(output, list): + for i, value in enumerate(output): + output[i] = self._outputs_to_device(value) return output + def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None: + """Handle metric updates when the SPRO metric is used alongside other pixel-level metrics.""" + update = False + for metric in pixel_metric.values(copy_state=False): + if isinstance(metric, SPRO): + metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"]) + else: + metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + update = True + pixel_metric.set_update_called(update) + @staticmethod def _log_metrics(pl_module: AnomalyModule) -> None: """Log computed performance metrics.""" diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index e44c807b12..f05b8ef835 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -18,7 +18,6 @@ def __init__(self, *args, **kwargs) -> None: self._saturation_config: dict self._update_called = False self._threshold = 0.5 - self._saturation_config = {} def set_threshold(self, threshold_value: float) -> None: """Update the threshold value for all metrics that have the threshold attribute.""" @@ -35,10 +34,14 @@ def set_saturation_config(self, saturation_config: dict) -> None: metric.saturation_config = saturation_config else: logger.warning( - f"Metric {name} may not be suitable for a dataset with the region separated" + f"Metric {name} may not be suitable for a dataset with the region separated " "in multiple ground-truth masks.", ) + def set_update_called(self, val: bool) -> None: + """Set the flag indicating whether the update method has been called.""" + self._update_called = val + def update(self, *args, **kwargs) -> None: """Add data to the metrics.""" super().update(*args, **kwargs) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index ac8ae32ceb..2bc6e35ff6 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -7,7 +7,6 @@ import torch from torchmetrics import Metric -from torchmetrics.utilities.data import dim_zero_cat logger = logging.getLogger(__name__) @@ -34,13 +33,14 @@ class SPRO(Metric): Create random ``preds`` and ``labels`` tensors: - >>> labels = torch.randint(low=0, high=2, size=(1, 10, 5), dtype=torch.float32) - >>> preds = torch.rand_like(labels) + >>> labels = torch.randint(low=0, high=2, size=(2, 10, 5), dtype=torch.float32) + >>> labels = [labels] + >>> preds = torch.rand_like(labels[0][:1]) Compute the SPRO score for labels and preds: >>> spro = SPRO(threshold=0.5) - >>> spro.update(preds, _, labels) + >>> spro.update(preds, labels) >>> spro.compute() tensor(0.6333) @@ -57,25 +57,30 @@ def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None super().__init__(**kwargs) self.threshold = threshold self.saturation_config = saturation_config - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("targets", default=[], dist_reduce_fx="cat") + self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - def update(self, predictions: torch.Tensor, _: torch.Tensor, masks: torch.Tensor) -> None: + def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None: """Compute the SPRO score for the current batch. Args: - predictions (torch.Tensor): Predicted anomaly masks - _ (torch.Tensor): Unused argument, but needed for different metrics within the same AnomalibMetricCollection - masks (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width + predictions (torch.Tensor): Predicted anomaly masks. + masks (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. Example: To update the metric state for the current batch, use the ``update`` method: - >>> spro.update(preds, _, labels) + >>> spro.update(preds, labels) """ - assert masks is not None - self.targets.append(masks) - self.preds.append(predictions) + score, total = spro_score( + predictions, + masks, + threshold=self.threshold, + saturation_config=self.saturation_config, + ) + self.score += score + self.total += total def compute(self) -> torch.Tensor: """Compute the macro average of the SPRO score across all masks in all batches. @@ -86,9 +91,9 @@ def compute(self) -> torch.Tensor: >>> spro.compute() tensor(0.5433) """ - targets = dim_zero_cat(self.targets) - preds = dim_zero_cat(self.preds) - return spro_score(preds, targets, threshold=self.threshold, saturation_config=self.saturation_config) + if self.total == 0: # only background/normal images + return torch.Tensor([1.0]) + return self.score / self.total def spro_score( @@ -100,7 +105,7 @@ def spro_score( """Calculate the SPRO score for a batch of predictions. Args: - predictions (torch.Tensor): Predicted anomaly masks + predictions (torch.Tensor): Predicted anomaly masks. targets: (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. @@ -110,22 +115,28 @@ def spro_score( Returns: torch.Tensor: Scalar value representing the average SPRO score for the input batch. """ - predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets.shape[1:]) + # Add batch dim if not exist + if len(predictions.shape) == 2: + predictions = predictions.unsqueeze(0) + + # Resize the prediction to have the same size as the target mask + predictions = torch.nn.functional.interpolate(predictions.unsqueeze(1), targets[0].shape[-2:]) # Apply threshold to binary predictions if predictions.dtype == torch.float: predictions = predictions > threshold score = torch.tensor(0.0) - m = 0 + total = 0 # Iterate for each image in the batch for i, target in enumerate(targets): - unique_labels = torch.unique(target) - # Iterate for each ground-truth mask per image - for label in unique_labels[1:]: + for mask in target: + label = torch.max(mask) + if label == 0: # Skip if only normal/background + continue # Calculate true positive - target_per_label = target == label + target_per_label = mask == label true_pos = torch.sum(predictions[i] & target_per_label) # Calculate the areas of the ground-truth @@ -150,9 +161,5 @@ def spro_score( # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) - m += 1 - - # If there are only backgrounds - if m == 0: - return torch.tensor(1.0) - return score / m + total += 1 + return score, total From e8e7360ebc8d9ed89fcbca8d5e555c66963b95f5 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:21:18 +0900 Subject: [PATCH 36/63] update unit test to use list of masks as the target Signed-off-by: Willy Fitra Hendria --- tests/unit/metrics/test_spro.py | 56 ++++++++++++++++----------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py index 36847d9bd0..f6f5826419 100644 --- a/tests/unit/metrics/test_spro.py +++ b/tests/unit/metrics/test_spro.py @@ -21,35 +21,35 @@ def test_spro() -> None: }, } - masks = torch.Tensor( - [ + masks = [ + torch.Tensor( [ - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], + [ + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + [ + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 0], + ], ], - [ - [1, 1, 1, 1, 1], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [1, 1, 1, 1, 1], - [0, 0, 0, 0, 0], - ], - ], - ) + ), + ] - masks[0] *= 255 - masks[1] *= 254 - # merge the multi-mask and add batch dim - merged_masks = (masks[0] + masks[1]).unsqueeze(0) + masks[0][0] *= 255 + masks[0][1] *= 254 preds = (torch.arange(8) / 10) + 0.05 # metrics receive squeezed predictions (N, H, W) @@ -61,10 +61,10 @@ def test_spro() -> None: for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True): # test using saturation_cofig spro = SPRO(threshold=threshold, saturation_config=saturation_config) - spro.update(preds, None, merged_masks) + spro.update(preds, masks) assert spro.compute() == target # test without saturation_config spro_wo_saturaton = SPRO(threshold=threshold) - spro_wo_saturaton.update(preds, None, merged_masks) + spro_wo_saturaton.update(preds, masks) assert spro_wo_saturaton.compute() == target_wo_saturation From ccef95c9d47c56e28ef55a96ebe7e6ebb8852e17 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:34:40 +0900 Subject: [PATCH 37/63] update type and docstring of spro_score function Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/spro.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 2bc6e35ff6..3434818adc 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -98,7 +98,7 @@ def compute(self) -> torch.Tensor: def spro_score( predictions: torch.Tensor, - targets: torch.Tensor, + targets: list[torch.Tensor], threshold: float = 0.5, saturation_config: dict | None = None, ) -> torch.Tensor: @@ -106,7 +106,8 @@ def spro_score( Args: predictions (torch.Tensor): Predicted anomaly masks. - targets: (torch.Tensor): Ground truth anomaly masks with non-binary values and, original height and width + targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the + list is a tensor list of masks for the corresponding image. threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are From 2cdeb82eb79348a23325fb3d873862775366faa2 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 00:42:37 +0900 Subject: [PATCH 38/63] remove _saturation_config attribute from metric collection module Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/collection.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index f05b8ef835..27399041dc 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -15,7 +15,6 @@ class AnomalibMetricCollection(MetricCollection): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self._saturation_config: dict self._update_called = False self._threshold = 0.5 @@ -28,7 +27,6 @@ def set_threshold(self, threshold_value: float) -> None: def set_saturation_config(self, saturation_config: dict) -> None: """Update the saturation config values for all metrics that have the saturation config attribute.""" - self._saturation_config = saturation_config for name, metric in self.items(): if hasattr(metric, "saturation_config"): metric.saturation_config = saturation_config From d2cbcf39c5625300339ca3a3228e676d5e2a4fd1 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 13:02:27 +0900 Subject: [PATCH 39/63] remove unnecessary lines Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/spro.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 3434818adc..b566635247 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -50,9 +50,6 @@ class SPRO(Metric): """ - targets: list[torch.Tensor] - preds: list[torch.Tensor] - def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None, **kwargs) -> None: super().__init__(**kwargs) self.threshold = threshold From 829289b80a1b34645d8f054965f8be4f88a71fc4 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 13:33:08 +0900 Subject: [PATCH 40/63] add unit test to make sure the `mask` is binary Signed-off-by: Willy Fitra Hendria --- tests/unit/data/image/test_mvtec_loco.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/unit/data/image/test_mvtec_loco.py b/tests/unit/data/image/test_mvtec_loco.py index f58355c8e4..a322fd6672 100644 --- a/tests/unit/data/image/test_mvtec_loco.py +++ b/tests/unit/data/image/test_mvtec_loco.py @@ -30,3 +30,10 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTecLoco: _datamodule.setup() return _datamodule + + def test_mask_is_binary(self, datamodule: MVTecLoco) -> None: + """Test if the mask tensor is binary.""" + if datamodule.test_data.task in (TaskType.DETECTION, TaskType.SEGMENTATION): + mask_tensor = datamodule.test_data[0]["mask"] + is_binary = (mask_tensor.eq(0) | mask_tensor.eq(1)).all() + assert is_binary.item() is True From 53aa07f9209dd19add239db8b147adf70ff880ed Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 14:04:04 +0900 Subject: [PATCH 41/63] add warning when the saturation threshold is larger than the defect area Signed-off-by: Willy Fitra Hendria --- src/anomalib/metrics/spro.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index b566635247..9128ded884 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -137,25 +137,32 @@ def spro_score( target_per_label = mask == label true_pos = torch.sum(predictions[i] & target_per_label) - # Calculate the areas of the ground-truth - defect_areas = torch.sum(target_per_label) + # Calculate the anomalous area of the ground-truth + defect_area = torch.sum(target_per_label) if saturation_config is not None: # Adjust saturation threshold based on configuration saturation_per_label = saturation_config[label.int().item()] - saturation_threshold = torch.minimum( - torch.tensor(saturation_per_label["saturation_threshold"]), - defect_areas, - ) + saturation_threshold = saturation_per_label["saturation_threshold"] + if saturation_per_label["relative_saturation"]: - saturation_threshold *= defect_areas + saturation_threshold *= defect_area + + # Check if threshold is larger than defect area + if saturation_threshold > defect_area: + warning_msg = ( + f"Saturation threshold for label {label.int().item()} is larger than defect area. " + "Setting it to defect area." + ) + logger.warning(warning_msg) + saturation_threshold = defect_area else: # Handle case when saturation_config is empty logger.warning( - "The saturation_config attribute is empty, the threshold is set to the defect areas." + "The saturation_config attribute is empty, the threshold is set to the defect area." "This is equivalent to PRO metric but with the 'region' are separated by mask files", ) - saturation_threshold = defect_areas + saturation_threshold = defect_area # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) From d9a233340061c70334f726869e4189d33e01c061 Mon Sep 17 00:00:00 2001 From: Willy Fitra Hendria Date: Sun, 11 Feb 2024 21:23:32 +0900 Subject: [PATCH 42/63] Move the loading process of saturation config from dataset to metric Signed-off-by: Willy Fitra Hendria --- src/anomalib/callbacks/metrics.py | 13 +++--- src/anomalib/cli/cli.py | 7 +++- src/anomalib/data/image/mvtec_loco.py | 47 ---------------------- src/anomalib/metrics/collection.py | 11 ----- src/anomalib/metrics/spro.py | 58 +++++++++++++++++++++++---- tests/unit/metrics/test_spro.py | 23 ++++++++--- 6 files changed, 81 insertions(+), 78 deletions(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 9b8b0e3bda..9e2bcdb7ec 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -67,7 +67,7 @@ def setup( pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule. stage (str | None, optional): fit, validate, test or predict. Defaults to None. """ - del stage # this variable is not used. + del stage, trainer # this variable is not used. image_metric_names = [] if self.image_metric_names is None else self.image_metric_names if isinstance(image_metric_names, str): image_metric_names = [image_metric_names] @@ -97,8 +97,6 @@ def setup( else: pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") self._set_threshold(pl_module) - if hasattr(trainer.datamodule, "saturation_config"): - self._set_saturation_config(pl_module, trainer.datamodule.saturation_config) def on_validation_epoch_start( self, @@ -173,9 +171,6 @@ def _set_threshold(self, pl_module: AnomalyModule) -> None: pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) - def _set_saturation_config(self, pl_module: AnomalyModule, saturation_config: dict[int, Any]) -> None: - pl_module.pixel_metrics.set_saturation_config(saturation_config) - def _update_metrics( self, image_metric: AnomalibMetricCollection, @@ -205,10 +200,14 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None: """Handle metric updates when the SPRO metric is used alongside other pixel-level metrics.""" update = False - for metric in pixel_metric.values(copy_state=False): + for name, metric in pixel_metric.items(copy_state=False): if isinstance(metric, SPRO): metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"]) else: + logger.warning( + f"Metric {name} may not be suitable for a dataset with the region separated " + "in multiple ground-truth masks.", + ) metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) update = True pixel_metric.set_update_called(update) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index d56e6d578b..736acbd830 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -143,7 +143,12 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.add_argument("--visualization.show", type=bool, default=False) parser.add_argument("--task", type=TaskType, default=TaskType.SEGMENTATION) parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) - parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) + parser.add_argument( + "--metrics.pixel", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=None, + required=False, + ) parser.add_argument("--metrics.threshold", type=BaseThreshold, default="F1AdaptiveThreshold") parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False) if hasattr(parser, "subcommand") and parser.subcommand != "predict": # Predict also accepts str and Path inputs diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 65dbbbb726..8864ada0cd 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -16,11 +16,9 @@ in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 """ -import json import logging from collections.abc import Sequence from pathlib import Path -from typing import Any import albumentations as A # noqa: N812 import cv2 @@ -64,46 +62,6 @@ "splicing_connectors", ) -SATURATION_CONFIG_FILENAME = "defects_config.json" - - -def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: - """Load saturation configurations from a JSON file. - - Args: - config_path (str | Path): Path to the saturation configuration file. - - Returns: - Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. - Return None if the config file is not found. - - Example JSON format in the file: - [ - { - "defect_name": "1_additional_pushpin", - "pixel_value": 255, - "saturation_threshold": 6300, - "relative_saturation": false - }, - { - "defect_name": "2_additional_pushpins", - "pixel_value": 254, - "saturation_threshold": 12600, - "relative_saturation": false - }, - ... - ] - """ - try: - config_path = validate_path(config_path) - with Path.open(config_path) as file: - configs = json.load(file) - # Create a dictionary with pixel values as keys - return {conf["pixel_value"]: conf for conf in configs} - except FileNotFoundError: - logger.warning("The saturation config file %s does not exist. Returning None.", config_path) - return None - def make_mvtec_loco_dataset( root: str | Path, @@ -448,10 +406,8 @@ def __init__( val_split_ratio=val_split_ratio, seed=seed, ) - self.saturation_config: dict[int, Any] | None self.root = Path(root) self.category = Path(category) - self.saturation_config = {} transform_train = get_transforms( config=transform_config_train, @@ -549,6 +505,3 @@ def _setup(self, _stage: str | None = None) -> None: self._create_test_split() self._create_val_split() - - saturation_path = self.root / self.category / SATURATION_CONFIG_FILENAME - self.saturation_config = load_saturation_config(saturation_path) diff --git a/src/anomalib/metrics/collection.py b/src/anomalib/metrics/collection.py index 27399041dc..47c17a3a44 100644 --- a/src/anomalib/metrics/collection.py +++ b/src/anomalib/metrics/collection.py @@ -25,17 +25,6 @@ def set_threshold(self, threshold_value: float) -> None: if hasattr(metric, "threshold"): metric.threshold = threshold_value - def set_saturation_config(self, saturation_config: dict) -> None: - """Update the saturation config values for all metrics that have the saturation config attribute.""" - for name, metric in self.items(): - if hasattr(metric, "saturation_config"): - metric.saturation_config = saturation_config - else: - logger.warning( - f"Metric {name} may not be suitable for a dataset with the region separated " - "in multiple ground-truth masks.", - ) - def set_update_called(self, val: bool) -> None: """Set the flag indicating whether the update method has been called.""" self._update_called = val diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 9128ded884..9269ce5188 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -3,11 +3,16 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import json import logging +from pathlib import Path +from typing import Any import torch from torchmetrics import Metric +from anomalib.data.utils import validate_path + logger = logging.getLogger(__name__) @@ -20,7 +25,7 @@ class SPRO(Metric): Args: threshold (float): Threshold used to binarize the predictions. Defaults to ``0.5``. - saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. + saturation_config (str | Path): Path to the saturation configuration file. Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are separated by mask files. kwargs: Additional arguments to the TorchMetrics base class. @@ -50,10 +55,15 @@ class SPRO(Metric): """ - def __init__(self, threshold: float = 0.5, saturation_config: dict | None = None, **kwargs) -> None: + def __init__(self, threshold: float = 0.5, saturation_config: str | Path | None = None, **kwargs) -> None: super().__init__(**kwargs) self.threshold = threshold - self.saturation_config = saturation_config + self.saturation_config = load_saturation_config(saturation_config) if saturation_config is not None else None + if self.saturation_config is None: + logger.warning( + "The saturation_config attribute is empty, the threshold is set to the defect area." + "This is equivalent to PRO metric but with the 'region' are separated by mask files", + ) self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") @@ -158,13 +168,47 @@ def spro_score( saturation_threshold = defect_area else: # Handle case when saturation_config is empty - logger.warning( - "The saturation_config attribute is empty, the threshold is set to the defect area." - "This is equivalent to PRO metric but with the 'region' are separated by mask files", - ) saturation_threshold = defect_area # Update score with minimum of true_pos/saturation_threshold and 1.0 score += torch.minimum(true_pos / saturation_threshold, torch.tensor(1.0)) total += 1 return score, total + + +def load_saturation_config(config_path: str | Path) -> dict[int, Any] | None: + """Load saturation configurations from a JSON file. + + Args: + config_path (str | Path): Path to the saturation configuration file. + + Returns: + Dict | None: A dictionary with pixel values as keys and the corresponding configurations as values. + Return None if the config file is not found. + + Example JSON format in the config file of MVTec LOCO dataset: + [ + { + "defect_name": "1_additional_pushpin", + "pixel_value": 255, + "saturation_threshold": 6300, + "relative_saturation": false + }, + { + "defect_name": "2_additional_pushpins", + "pixel_value": 254, + "saturation_threshold": 12600, + "relative_saturation": false + }, + ... + ] + """ + try: + config_path = validate_path(config_path) + with Path.open(config_path) as file: + configs = json.load(file) + # Create a dictionary with pixel values as keys + return {conf["pixel_value"]: conf for conf in configs} + except FileNotFoundError: + logger.warning("The saturation config file %s does not exist. Returning None.", config_path) + return None diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py index f6f5826419..37ee536433 100644 --- a/tests/unit/metrics/test_spro.py +++ b/tests/unit/metrics/test_spro.py @@ -3,6 +3,10 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import json +import pathlib +import tempfile + import torch from anomalib.metrics.spro import SPRO @@ -10,16 +14,22 @@ def test_spro() -> None: """Checks if SPRO metric computes the score utilizing the given saturation configs.""" - saturation_config = { - 255: { + saturation_config = [ + { + "pixel_value": 255, "saturation_threshold": 10, "relative_saturation": False, }, - 254: { + { + "pixel_value": 254, "saturation_threshold": 0.5, "relative_saturation": True, }, - } + ] + + with tempfile.NamedTemporaryFile(suffix=".json", mode="w", delete=False) as f: + json.dump(saturation_config, f) + saturation_config_json = f.name masks = [ torch.Tensor( @@ -60,7 +70,7 @@ def test_spro() -> None: targets_wo_saturation = [1.0, 0.625, 0.5, 0.375, 0.0, 0.0] for threshold, target, target_wo_saturation in zip(thresholds, targets, targets_wo_saturation, strict=True): # test using saturation_cofig - spro = SPRO(threshold=threshold, saturation_config=saturation_config) + spro = SPRO(threshold=threshold, saturation_config=saturation_config_json) spro.update(preds, masks) assert spro.compute() == target @@ -68,3 +78,6 @@ def test_spro() -> None: spro_wo_saturaton = SPRO(threshold=threshold) spro_wo_saturaton.update(preds, masks) assert spro_wo_saturaton.compute() == target_wo_saturation + + # Remove the temporary config file + pathlib.Path(saturation_config_json).unlink() From 8785aa2994ba9b8b052b449009ef842a6f7e8e1c Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 11:18:52 +0200 Subject: [PATCH 43/63] merge main --- src/anomalib/data/image/mvtec_loco.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 4b9ca24c9e..44ef680b93 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -275,7 +275,7 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: item["mask_path"] = mask_path # List of masks with the original size for saturation based metrics calculation - item["original_masks"] = mask + item["masks"] = mask if self.task == TaskType.DETECTION: # create boxes from masks for detection task @@ -398,10 +398,6 @@ def _setup(self, _stage: str | None = None) -> None: This method overrides the parent class's method to also setup the val dataset. The MVTec LOCO dataset provides an independent validation subset. """ - if self.train_data is None or self.val_data is None or self.test_data is None: - error_message = "train_data, val_data, and test_data must all be provided" - raise ValueError(error_message) - self.train_data = MVTecLocoDataset( task=self.task, transform=self.train_transform, From 930cfe2f62437e24dc11df9279340ac31803c5e3 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Tue, 9 Apr 2024 12:24:18 +0100 Subject: [PATCH 44/63] Update src/anomalib/metrics/spro.py --- src/anomalib/metrics/spro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 9269ce5188..a18969767b 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -1,6 +1,6 @@ """Implementation of SPRO metric based on TorchMetrics.""" -# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import json From a2c529cf143dbc61f1a9824ad85139fae929eb10 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:02:34 +0200 Subject: [PATCH 45/63] update hashsum --- src/anomalib/data/image/mvtec_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 44ef680b93..a00e19a292 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -49,7 +49,7 @@ name="mvtec_loco", url="https://www.mydrive.ch/shares/48237/1b9106ccdfbb09a0c414bd49fe44a14a/download/430647091-1646842701" "/mvtec_loco_anomaly_detection.tar.xz", - hashsum="d40f092ac6f88433f609583c4a05f56f", + hashsum="9e7c84dba550fd2e59d8e9e231c929c45ba737b6b6a6d3814100f54d63aae687", ) CATEGORIES = ( From 0c68c900468a638ed7fd95a28984cf597a0a7e04 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:04:34 +0200 Subject: [PATCH 46/63] update example --- src/anomalib/data/image/mvtec_loco.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index a00e19a292..e437c6b420 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -194,8 +194,9 @@ class MVTecLocoDataset(AnomalibDataset): from anomalib.data.image.mvtec_loco import MVTecLocoDataset from anomalib.data.utils.transforms import get_transforms + from torchvision.transforms.v2 import Resize - transform = get_transforms(image_size=256) + transform = Resize((256, 256)) dataset = MVTecLocoDataset( task="classification", transform=transform, From d6bd7fbf6c29a6139a6ca908e3e2f025e0a4cbb0 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:09:08 +0200 Subject: [PATCH 47/63] remove duplicate parameter --- src/anomalib/cli/cli.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 3c14a175e9..50763476d9 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -150,7 +150,6 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: default=None, required=False, ) - parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold") parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False) if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"): From 934f75366883a79d74a8f3d438e915494ecf19b8 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:15:21 +0200 Subject: [PATCH 48/63] Update src/anomalib/data/base/datamodule.py Co-authored-by: Samet Akcay --- src/anomalib/data/base/datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 250858b640..787d4173e6 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -28,7 +28,7 @@ def collate_fn(batch: list) -> dict[str, Any]: """Collate bounding boxes as lists. - Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` is exist, + Bounding boxes and `masks` (not `mask`) are collated as a list of tensors. If `masks` exists, the `mask_path` is also collated as a list since each element in the batch could be unequal. For all other entries, the default collate function is used. From 63f210ebe0780f4b5bb23682d9c53868b70438be Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:21:33 +0200 Subject: [PATCH 49/63] move and update loco config --- {src/configs => configs}/data/mvtec_loco.yaml | 5 ----- 1 file changed, 5 deletions(-) rename {src/configs => configs}/data/mvtec_loco.yaml (69%) diff --git a/src/configs/data/mvtec_loco.yaml b/configs/data/mvtec_loco.yaml similarity index 69% rename from src/configs/data/mvtec_loco.yaml rename to configs/data/mvtec_loco.yaml index 2f60f00817..92c04542c0 100644 --- a/src/configs/data/mvtec_loco.yaml +++ b/configs/data/mvtec_loco.yaml @@ -2,15 +2,10 @@ class_path: anomalib.data.MVTecLoco init_args: root: ./datasets/MVTec_LOCO category: breakfast_box - image_size: [256, 256] - center_crop: null - normalization: imagenet train_batch_size: 32 eval_batch_size: 32 num_workers: 8 task: SEGMENTATION - transform_config_train: null - transform_config_eval: null test_split_mode: FROM_DIR test_split_ratio: 0.2 val_split_mode: FROM_DIR From ba371c4ef0bdb30cb2486b47789d8b788e5f8778 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:22:07 +0200 Subject: [PATCH 50/63] Update tests/unit/data/image/test_mvtec_loco.py Co-authored-by: Samet Akcay --- tests/unit/data/image/test_mvtec_loco.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/data/image/test_mvtec_loco.py b/tests/unit/data/image/test_mvtec_loco.py index a322fd6672..64275319fe 100644 --- a/tests/unit/data/image/test_mvtec_loco.py +++ b/tests/unit/data/image/test_mvtec_loco.py @@ -1,6 +1,6 @@ """Unit Tests - MVTecLoco Datamodule.""" -# Copyright (C) 2023 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 from pathlib import Path From 114f8e604d27a7542b6e1f64a94f2e358ecbe0e8 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:22:17 +0200 Subject: [PATCH 51/63] Update tests/unit/metrics/test_spro.py Co-authored-by: Samet Akcay --- tests/unit/metrics/test_spro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/metrics/test_spro.py b/tests/unit/metrics/test_spro.py index 37ee536433..b7dd8e043d 100644 --- a/tests/unit/metrics/test_spro.py +++ b/tests/unit/metrics/test_spro.py @@ -1,6 +1,6 @@ """Test SPRO metric.""" -# Copyright (C) 2023-2024 Intel Corporation +# Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 import json From 7dbb98302039ab91624a7552d91064dd951ab1eb Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:23:00 +0200 Subject: [PATCH 52/63] Update src/anomalib/data/image/mvtec_loco.py Co-authored-by: Samet Akcay --- src/anomalib/data/image/mvtec_loco.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index e437c6b420..19c23178a3 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -16,6 +16,9 @@ in: International Journal of Computer Vision (IJCV) 130, 947-969, 2022, DOI: 10.1007/s11263-022-01578-9 """ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import logging from collections.abc import Sequence from pathlib import Path From 6587cac7b91fb63d2b26a1f2527b14e7360d6736 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:23:33 +0200 Subject: [PATCH 53/63] Update src/anomalib/cli/cli.py Co-authored-by: Ashwin Vaidya --- src/anomalib/cli/cli.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 50763476d9..802dcce35d 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -143,7 +143,11 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: parser.add_function_arguments(get_normalization_callback, "normalization") parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) - parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) + parser.add_argument( + "--metrics.image", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=["F1Score", "AUROC"], + ) parser.add_argument( "--metrics.pixel", type=list[str] | str | dict[str, dict[str, Any]] | None, From 36ebf1c33e17f3bd1e158522c9ba29b615025108 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:27:26 +0200 Subject: [PATCH 54/63] Update src/anomalib/metrics/spro.py Co-authored-by: Samet Akcay --- src/anomalib/metrics/spro.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index a18969767b..0b4eac2d31 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -81,8 +81,8 @@ def update(self, predictions: torch.Tensor, masks: list[torch.Tensor]) -> None: >>> spro.update(preds, labels) """ score, total = spro_score( - predictions, - masks, + predictions=predictions, + targets=masks, threshold=self.threshold, saturation_config=self.saturation_config, ) From b998823e85e84971584051748c77b5b322850033 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 14:27:36 +0200 Subject: [PATCH 55/63] Update src/anomalib/metrics/spro.py Co-authored-by: Samet Akcay --- src/anomalib/metrics/spro.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/anomalib/metrics/spro.py b/src/anomalib/metrics/spro.py index 0b4eac2d31..c59091ee5f 100644 --- a/src/anomalib/metrics/spro.py +++ b/src/anomalib/metrics/spro.py @@ -116,6 +116,7 @@ def spro_score( targets: (list[torch.Tensor]): Ground truth anomaly masks with original height and width. Each element in the list is a tensor list of masks for the corresponding image. threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. + Defaults: ``0.5``. saturation_config (dict): Saturations configuration for each label (pixel value) as the keys. Defaults: ``None`` (which the score is equivalent to PRO metric, but with the 'region' are separated by mask files. From c844c1d4f992d19a0ac9a7f4b1b5950b9c5d64b0 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 16:30:41 +0200 Subject: [PATCH 56/63] fix mask loading --- src/anomalib/data/image/mvtec_loco.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 19c23178a3..fc8771d54d 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -270,11 +270,14 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: elif self.task in (TaskType.DETECTION, TaskType.SEGMENTATION): # Only Anomalous (1) images have masks in anomaly datasets # Therefore, create empty mask for Normal (0) images. + if isinstance(mask_path, str): + mask_path = [mask_path] mask = ( Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) if label_index == LabelName.NORMAL - else read_mask(mask_path, as_tensor=True) + else Mask(torch.stack([read_mask(path, as_tensor=True) for path in mask_path])) ) + mask = Mask(mask.view(-1, *mask.shape[-2:]).any(dim=0).to(torch.uint8)) item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask) item["mask_path"] = mask_path From 58c3b3f9c1ffeaed8e844cb51d4688021e298d98 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 16:50:37 +0200 Subject: [PATCH 57/63] ruff --- src/anomalib/cli/cli.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anomalib/cli/cli.py b/src/anomalib/cli/cli.py index 802dcce35d..ccb8d64748 100644 --- a/src/anomalib/cli/cli.py +++ b/src/anomalib/cli/cli.py @@ -144,10 +144,10 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None: parser.add_function_arguments(get_normalization_callback, "normalization") parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) parser.add_argument( - "--metrics.image", - type=list[str] | str | dict[str, dict[str, Any]] | None, - default=["F1Score", "AUROC"], - ) + "--metrics.image", + type=list[str] | str | dict[str, dict[str, Any]] | None, + default=["F1Score", "AUROC"], + ) parser.add_argument( "--metrics.pixel", type=list[str] | str | dict[str, dict[str, Any]] | None, From 4028b9ab8484bb82fecfa2b4464842686c83d6cb Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 9 Apr 2024 17:58:27 +0200 Subject: [PATCH 58/63] fix multiple mask loading --- src/anomalib/data/image/mvtec_loco.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index fc8771d54d..2ad5221f88 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -272,17 +272,17 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: # Therefore, create empty mask for Normal (0) images. if isinstance(mask_path, str): mask_path = [mask_path] - mask = ( + masks = ( Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) if label_index == LabelName.NORMAL else Mask(torch.stack([read_mask(path, as_tensor=True) for path in mask_path])) ) - mask = Mask(mask.view(-1, *mask.shape[-2:]).any(dim=0).to(torch.uint8)) + mask = Mask(masks.view(-1, *masks.shape[-2:]).any(dim=0).to(torch.uint8)) item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask) item["mask_path"] = mask_path # List of masks with the original size for saturation based metrics calculation - item["masks"] = mask + item["masks"] = masks if self.task == TaskType.DETECTION: # create boxes from masks for detection task From c87f742cdc0e4780c00dd03270316903d3fd1d1f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 10 Apr 2024 13:17:35 +0200 Subject: [PATCH 59/63] masks -> semantic_mask --- src/anomalib/callbacks/metrics.py | 4 ++-- src/anomalib/data/base/datamodule.py | 8 +++++--- src/anomalib/data/image/mvtec_loco.py | 19 +++++++++++++------ 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 24cb8a914f..57ab0b3a27 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -181,7 +181,7 @@ def _update_metrics( image_metric.update(output["pred_scores"], output["label"].int()) if "mask" in output and "anomaly_maps" in output: pixel_metric.to(self.device) - if "masks" in output: + if "semantic_mask" in output: self._update_pixel_metrics(pixel_metric, output) else: pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) @@ -202,7 +202,7 @@ def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: update = False for name, metric in pixel_metric.items(copy_state=False): if isinstance(metric, SPRO): - metric.update(torch.squeeze(output["anomaly_maps"]), output["masks"]) + metric.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) else: logger.warning( f"Metric {name} may not be suitable for a dataset with the region separated " diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 787d4173e6..6415a36a2f 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -44,9 +44,11 @@ def collate_fn(batch: list) -> dict[str, Any]: if "boxes" in elem: # collate boxes as list out_dict["boxes"] = [item.pop("boxes") for item in batch] - if "masks" in elem: - # collate masks and mask_path as list - out_dict["masks"] = [item.pop("masks") for item in batch] + if "semantic_mask" in elem: + # semantic masks have a variable number of channels, so we collate them as a list + out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch] + if "mask_path" in elem and isinstance(elem["mask_path"], list): + # collate mask paths as list out_dict["mask_path"] = [item.pop("mask_path") for item in batch] # collate other data normally out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem}) diff --git a/src/anomalib/data/image/mvtec_loco.py b/src/anomalib/data/image/mvtec_loco.py index 2ad5221f88..4ef2b4ab8e 100644 --- a/src/anomalib/data/image/mvtec_loco.py +++ b/src/anomalib/data/image/mvtec_loco.py @@ -25,7 +25,9 @@ import torch from pandas import DataFrame +from PIL import Image from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2.functional import to_image from torchvision.tv_tensors import Mask from anomalib import TaskType @@ -39,7 +41,6 @@ download_and_extract, masks_to_boxes, read_image, - read_mask, validate_path, ) @@ -245,6 +246,11 @@ def __init__( extensions=IMG_EXTENSIONS, ) + @staticmethod + def _read_mask(mask_path: str | Path) -> Mask: + image = Image.open(mask_path).convert("L") + return Mask(to_image(image).squeeze(), dtype=torch.uint8) + def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: """Get dataset item for the index ``index``. @@ -272,17 +278,18 @@ def __getitem__(self, index: int) -> dict[str, str | torch.Tensor]: # Therefore, create empty mask for Normal (0) images. if isinstance(mask_path, str): mask_path = [mask_path] - masks = ( + semantic_mask = ( Mask(torch.zeros(image.shape[-2:])).to(torch.uint8) if label_index == LabelName.NORMAL - else Mask(torch.stack([read_mask(path, as_tensor=True) for path in mask_path])) + else Mask(torch.stack([self._read_mask(path) for path in mask_path])) ) - mask = Mask(masks.view(-1, *masks.shape[-2:]).any(dim=0).to(torch.uint8)) - item["image"], item["mask"] = self.transform(image, mask) if self.transform else (image, mask) + + binary_mask = Mask(semantic_mask.view(-1, *semantic_mask.shape[-2:]).int().any(dim=0).to(torch.uint8)) + item["image"], item["mask"] = self.transform(image, binary_mask) if self.transform else (image, binary_mask) item["mask_path"] = mask_path # List of masks with the original size for saturation based metrics calculation - item["masks"] = masks + item["semantic_mask"] = semantic_mask if self.task == TaskType.DETECTION: # create boxes from masks for detection task From 15e1e74bf83f91b7cbc482ee6c004f0f70d4fc80 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 1 May 2024 18:17:01 +0200 Subject: [PATCH 60/63] add metric collection for semantic pixel metrics --- src/anomalib/callbacks/metrics.py | 67 ++++++++++--------- .../normalization/min_max_normalization.py | 2 +- .../models/components/base/anomaly_module.py | 1 + 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 57ab0b3a27..f7d93debee 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -13,7 +13,7 @@ from lightning.pytorch.utilities.types import STEP_OUTPUT from anomalib import TaskType -from anomalib.metrics import SPRO, AnomalibMetricCollection, create_metric_collection +from anomalib.metrics import create_metric_collection from anomalib.models import AnomalyModule logger = logging.getLogger(__name__) @@ -84,9 +84,22 @@ def setup( ) else: pixel_metric_names = ( - self.pixel_metric_names if not isinstance(self.pixel_metric_names, str) else [self.pixel_metric_names] + self.pixel_metric_names.copy() + if not isinstance(self.pixel_metric_names, str) + else [self.pixel_metric_names] ) + semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] + if "SPRO" in pixel_metric_names: + if isinstance(pixel_metric_names, list): + pixel_metric_names.remove("SPRO") + semantic_pixel_metric_names = ["SPRO"] + elif isinstance(pixel_metric_names, dict): + spro_metric = pixel_metric_names.pop("SPRO") + semantic_pixel_metric_names = {"SPRO": spro_metric} + else: + logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names)) + if isinstance(pl_module, AnomalyModule): pl_module.image_metrics = create_metric_collection(image_metric_names, "image_") if hasattr(pl_module, "pixel_metrics"): # incase metrics are loaded from model checkpoint @@ -96,6 +109,7 @@ def setup( pl_module.pixel_metrics.add_metrics(new_metrics[name]) else: pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") + pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_") self._set_threshold(pl_module) def on_validation_epoch_start( @@ -107,6 +121,7 @@ def on_validation_epoch_start( pl_module.image_metrics.reset() pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() def on_validation_batch_end( self, @@ -121,7 +136,7 @@ def on_validation_batch_end( if outputs is not None: self._outputs_to_device(outputs) - self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) + self._update_metrics(pl_module, outputs) def on_validation_epoch_end( self, @@ -142,6 +157,7 @@ def on_test_epoch_start( pl_module.image_metrics.reset() pl_module.pixel_metrics.reset() + pl_module.semantic_pixel_metrics.reset() def on_test_batch_end( self, @@ -156,7 +172,7 @@ def on_test_batch_end( if outputs is not None: self._outputs_to_device(outputs) - self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) + self._update_metrics(pl_module, outputs) def on_test_epoch_end( self, @@ -170,21 +186,21 @@ def on_test_epoch_end( def _set_threshold(self, pl_module: AnomalyModule) -> None: pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) + pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) def _update_metrics( self, - image_metric: AnomalibMetricCollection, - pixel_metric: AnomalibMetricCollection, + pl_module: AnomalyModule, output: STEP_OUTPUT, ) -> None: - image_metric.to(self.device) - image_metric.update(output["pred_scores"], output["label"].int()) + pl_module.image_metrics.to(self.device) + pl_module.image_metrics.update(output["pred_scores"], output["label"].int()) if "mask" in output and "anomaly_maps" in output: - pixel_metric.to(self.device) - if "semantic_mask" in output: - self._update_pixel_metrics(pixel_metric, output) - else: - pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + pl_module.pixel_metrics.to(self.device) + pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) + if "semantic_mask" in output and "anomaly_maps" in output: + pl_module.semantic_pixel_metrics.to(self.device) + pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: if isinstance(output, dict): @@ -197,26 +213,11 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any output[i] = self._outputs_to_device(value) return output - def _update_pixel_metrics(self, pixel_metric: AnomalibMetricCollection, output: STEP_OUTPUT) -> None: - """Handle metric updates when the SPRO metric is used alongside other pixel-level metrics.""" - update = False - for name, metric in pixel_metric.items(copy_state=False): - if isinstance(metric, SPRO): - metric.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"]) - else: - logger.warning( - f"Metric {name} may not be suitable for a dataset with the region separated " - "in multiple ground-truth masks.", - ) - metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) - update = True - pixel_metric.set_update_called(update) - @staticmethod def _log_metrics(pl_module: AnomalyModule) -> None: """Log computed performance metrics.""" - if pl_module.pixel_metrics._update_called: # noqa: SLF001 - pl_module.log_dict(pl_module.pixel_metrics, prog_bar=True) - pl_module.log_dict(pl_module.image_metrics, prog_bar=False) - else: - pl_module.log_dict(pl_module.image_metrics, prog_bar=True) + pl_module.log_dict(pl_module.image_metrics, prog_bar=True) + if pl_module.pixel_metrics.update_called: + pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False) + if pl_module.semantic_pixel_metrics.update_called: + pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False) diff --git a/src/anomalib/callbacks/normalization/min_max_normalization.py b/src/anomalib/callbacks/normalization/min_max_normalization.py index 4ff8c9b6e5..f22a36afd0 100644 --- a/src/anomalib/callbacks/normalization/min_max_normalization.py +++ b/src/anomalib/callbacks/normalization/min_max_normalization.py @@ -39,7 +39,7 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None: """Call when the test begins.""" del trainer # `trainer` variable is not used. - for metric in (pl_module.image_metrics, pl_module.pixel_metrics): + for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics): if metric is not None: metric.set_threshold(0.5) diff --git a/src/anomalib/models/components/base/anomaly_module.py b/src/anomalib/models/components/base/anomaly_module.py index 4ae12fb397..1e8d85b1f7 100644 --- a/src/anomalib/models/components/base/anomaly_module.py +++ b/src/anomalib/models/components/base/anomaly_module.py @@ -50,6 +50,7 @@ def __init__(self) -> None: self.image_metrics: AnomalibMetricCollection self.pixel_metrics: AnomalibMetricCollection + self.semantic_pixel_metrics: AnomalibMetricCollection self._transform: Transform | None = None self._input_size: tuple[int, int] | None = None From 69953edc22f6f5a266cd6912840dc9fc203b98a3 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 3 May 2024 08:01:46 +0200 Subject: [PATCH 61/63] add comment --- src/anomalib/callbacks/metrics.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index f7d93debee..98fced8357 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -89,7 +89,10 @@ def setup( else [self.pixel_metric_names] ) + # create a separate metric collection for metrics that operate over the semantic segmentation mask + # (segmentation mask with a separate channel for each defect type) semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] + # currently only SPRO metric is supported as semantic segmentation metric if "SPRO" in pixel_metric_names: if isinstance(pixel_metric_names, list): pixel_metric_names.remove("SPRO") From 3ffee165a2e3d93b1f31031b73a773b42728126d Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 16 May 2024 11:02:16 +0200 Subject: [PATCH 62/63] check if val_data is assigned for val_split_mode from_dir --- src/anomalib/data/base/datamodule.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/anomalib/data/base/datamodule.py b/src/anomalib/data/base/datamodule.py index 208f1370a2..bc5063d8ab 100644 --- a/src/anomalib/data/base/datamodule.py +++ b/src/anomalib/data/base/datamodule.py @@ -223,7 +223,10 @@ def _create_val_split(self) -> None: self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data) elif self.val_split_mode == ValSplitMode.FROM_DIR: # the val_data is prepared in subclass - pass + assert hasattr( + self, + "val_data", + ), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup." elif self.val_split_mode != ValSplitMode.NONE: msg = f"Unknown validation split mode: {self.val_split_mode}" raise ValueError(msg) From b7c0d94b3f804b8dc922b0a7d7da802d069b1ad6 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 16 May 2024 12:09:45 +0200 Subject: [PATCH 63/63] initialize semantic pixel metric names as list --- src/anomalib/callbacks/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/callbacks/metrics.py b/src/anomalib/callbacks/metrics.py index 98fced8357..8c32e6ec40 100644 --- a/src/anomalib/callbacks/metrics.py +++ b/src/anomalib/callbacks/metrics.py @@ -91,7 +91,7 @@ def setup( # create a separate metric collection for metrics that operate over the semantic segmentation mask # (segmentation mask with a separate channel for each defect type) - semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] + semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]] = [] # currently only SPRO metric is supported as semantic segmentation metric if "SPRO" in pixel_metric_names: if isinstance(pixel_metric_names, list):