🚀 Add MVTec LOCO dataset and sPRO metric #1967

Open
wants to merge 72 commits into
base: main
Changes from 69 commits
Commits
c5c6bc7
add FROM_DIR option to `val split mode` to support a provided val dir…
willyfh Jan 13, 2024
4c76092
add a conditional check for the FROM_DIR option of val split mode
willyfh Jan 13, 2024
ecc7169
add the mvtec loco ad dataset classes
willyfh Jan 13, 2024
9d89a68
add the default config file for mvtec loco ad dataset
willyfh Jan 13, 2024
a312be0
update initialization files to include MVTec LOCO dataset
willyfh Jan 13, 2024
37a6f6b
remove unnecessary Path conversion
willyfh Jan 13, 2024
a733bf7
add mvtec_loco.yaml to the readme documentation of configs
willyfh Jan 13, 2024
38ad1f4
add dummy image generation for mvtec loco dataset
willyfh Jan 13, 2024
05e3558
add unit test for mvtec loco dataset
willyfh Jan 13, 2024
519780a
update changelog to include the addition of mvtec loco dataset
willyfh Jan 13, 2024
c3c6a39
add mvtec loco dataset to the sphinx-based docs
willyfh Jan 14, 2024
ca91b8e
fix the malformed table
willyfh Jan 14, 2024
37f4dbc
binarize the masks and avoid the possibility of the merge_mask is None
willyfh Jan 14, 2024
3559a7c
Merge the masks using sum operation without binarization
willyfh Jan 21, 2024
83539b4
override getitem method to handle binarization and to add additional …
willyfh Jan 21, 2024
ddd33c5
Add saturation config to the datamodule
willyfh Jan 21, 2024
3af6eeb
Update the saturation config on the metrics based on the loaded confi…
willyfh Jan 21, 2024
53f2297
add masks as a keyword args to the update method of the AnomalibMetri…
willyfh Jan 21, 2024
3d00d69
Shorten the comments to solve ruff issues
willyfh Jan 21, 2024
f024e91
Add sPro metric implementation
willyfh Jan 21, 2024
9b8ca3b
Change the saturation threshold to tensor
willyfh Jan 21, 2024
29aaf46
Handle case with only background/normal images in scoring
willyfh Jan 21, 2024
f432753
rename spro metric and change the default value of saturation_config …
willyfh Jan 31, 2024
7d348d8
add unit test for spro metric
willyfh Jan 31, 2024
5597bfc
fix pre-commit issues
willyfh Jan 31, 2024
7b02863
handle file not found error when loading saturation config
willyfh Jan 31, 2024
6048a08
validate path before processing
willyfh Feb 1, 2024
e237347
update changelog with new PR
willyfh Feb 1, 2024
f9b67b8
Update src/anomalib/data/image/mvtec_loco.py
willyfh Feb 6, 2024
63bf8de
Update src/anomalib/data/image/mvtec_loco.py
willyfh Feb 6, 2024
0310e1b
Update tests/helpers/data.py
willyfh Feb 6, 2024
4eb2ec3
change assert to raise error
willyfh Feb 10, 2024
f3bccb8
return list of masks instead of merging the multiple masks from the d…
willyfh Feb 10, 2024
ac1ecb1
collate masks as a list of tensor to avoid stack error due to unequal…
willyfh Feb 10, 2024
3bbc750
update spro to handle list of masks and remove the _ args
willyfh Feb 10, 2024
e8e7360
update unit test to use list of masks as the target
willyfh Feb 10, 2024
ccef95c
update type and docstring of spro_score function
willyfh Feb 10, 2024
2cdeb82
remove _saturation_config attribute from metric collection module
willyfh Feb 10, 2024
d2cbcf3
remove unnecessary lines
willyfh Feb 11, 2024
829289b
add unit test to make sure the `mask` is binary
willyfh Feb 11, 2024
53aa07f
add warning when the saturation threshold is larger than the defect area
willyfh Feb 11, 2024
d9a2333
Move the loading process of saturation config from dataset to metric
willyfh Feb 11, 2024
c1b7a28
merge main
djdameln Apr 9, 2024
8785aa2
merge main
djdameln Apr 9, 2024
a767052
Merge branch 'mvtec_loco' into feature/mvtec-loco
djdameln Apr 9, 2024
930cfe2
Update src/anomalib/metrics/spro.py
samet-akcay Apr 9, 2024
a2c529c
update hashsum
djdameln Apr 9, 2024
0c68c90
update example
djdameln Apr 9, 2024
d6bd7fb
remove duplicate parameter
djdameln Apr 9, 2024
4a8e487
Merge branch 'feature/mvtec-loco' of github.com:openvinotoolkit/anoma…
djdameln Apr 9, 2024
934f753
Update src/anomalib/data/base/datamodule.py
djdameln Apr 9, 2024
63f210e
move and update loco config
djdameln Apr 9, 2024
98cd7f6
Merge branch 'feature/mvtec-loco' of github.com:openvinotoolkit/anoma…
djdameln Apr 9, 2024
ba371c4
Update tests/unit/data/image/test_mvtec_loco.py
djdameln Apr 9, 2024
114f8e6
Update tests/unit/metrics/test_spro.py
djdameln Apr 9, 2024
7dbb983
Update src/anomalib/data/image/mvtec_loco.py
djdameln Apr 9, 2024
6587cac
Update src/anomalib/cli/cli.py
djdameln Apr 9, 2024
36ebf1c
Update src/anomalib/metrics/spro.py
djdameln Apr 9, 2024
b998823
Update src/anomalib/metrics/spro.py
djdameln Apr 9, 2024
c844c1d
fix mask loading
djdameln Apr 9, 2024
501c116
Merge branch 'feature/mvtec-loco' of github.com:openvinotoolkit/anoma…
djdameln Apr 9, 2024
58c3b3f
ruff
djdameln Apr 9, 2024
4028b9a
fix multiple mask loading
djdameln Apr 9, 2024
c87f742
masks -> semantic_mask
djdameln Apr 10, 2024
eabb755
Merge branch 'main' into feature/mvtec-loco
djdameln Apr 10, 2024
15e1e74
add metric collection for semantic pixel metrics
djdameln May 1, 2024
69953ed
add comment
djdameln May 3, 2024
38d98c2
merge main
djdameln May 16, 2024
3ffee16
check if val_data is assigned for val_split_mode from_dir
djdameln May 16, 2024
b7c0d94
initialize semantic pixel metric names as list
djdameln May 16, 2024
9da8a68
Merge branch 'main' of github.com:openvinotoolkit/anomalib into featu…
samet-akcay May 17, 2024
1cc19d7
Merge branch 'main' into feature/mvtec-loco
ashwinvaidya17 May 23, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Added

- Add support for MVTec LOCO AD dataset and sPRO metric by @willyfh in https://github.com/openvinotoolkit/anomalib/pull/1686
- 🚀 Update OpenVINO and ONNX export to support fixed input shape by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2006
- Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018

1 change: 1 addition & 0 deletions configs/README.md
@@ -15,6 +15,7 @@ configs/
│   ├── kolektor.yaml
│   ├── mvtec_3d.yaml
│   ├── mvtec.yaml
│   ├── mvtec_loco.yaml
│   ├── shanghaitec.yaml
│   ├── ucsd_ped.yaml
│   └── visa.yaml
13 changes: 13 additions & 0 deletions configs/data/mvtec_loco.yaml
@@ -0,0 +1,13 @@
class_path: anomalib.data.MVTecLoco
init_args:
  root: ./datasets/MVTec_LOCO
  category: breakfast_box
  train_batch_size: 32
  eval_batch_size: 32
  num_workers: 8
  task: SEGMENTATION
  test_split_mode: FROM_DIR
  test_split_ratio: 0.2
  val_split_mode: FROM_DIR
  val_split_ratio: 0.5
  seed: null
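The `class_path` / `init_args` layout in this config follows the jsonargparse convention: a dotted import path plus the keyword arguments for the constructor. A minimal sketch of how such a mapping can be resolved into an object is shown here. The `instantiate` helper is hypothetical (it is not anomalib's or jsonargparse's code), and a standard-library class stands in for `anomalib.data.MVTecLoco` so the sketch runs without anomalib installed.

```python
import importlib
from typing import Any


def instantiate(config: dict[str, Any]) -> Any:
    """Import `class_path` and call it with `init_args` (hypothetical helper)."""
    module_name, _, class_name = config["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("init_args", {}))


# Stand-in config: fractions.Fraction replaces anomalib.data.MVTecLoco.
example = {
    "class_path": "fractions.Fraction",
    "init_args": {"numerator": 1, "denominator": 2},
}
half = instantiate(example)
```

With anomalib installed, the same idea would apply to the YAML in this file once it is loaded into a dictionary, although the real CLI also validates types and fills in defaults.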
8 changes: 8 additions & 0 deletions docs/source/markdown/guides/reference/data/image/index.md
@@ -30,6 +30,13 @@ Learn more about Kolektor dataset.
Learn more about MVTec 2D dataset
:::

:::{grid-item-card} MVTec LOCO
:link: ./mvtec_loco
:link-type: doc

Learn more about MVTec LOCO dataset
:::

:::{grid-item-card} Visa
:link: ./visa
:link-type: doc
@@ -47,5 +54,6 @@ Learn more about Visa dataset.
./folder
./kolektor
./mvtec
./mvtec_loco
./visa
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# MVTec LOCO Data

```{eval-rst}
.. automodule:: anomalib.data.image.mvtec_loco
:members:
:show-inheritance:
```
58 changes: 41 additions & 17 deletions src/anomalib/callbacks/metrics.py
@@ -13,7 +13,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib import TaskType
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
from anomalib.metrics import create_metric_collection
from anomalib.models import AnomalyModule

logger = logging.getLogger(__name__)
@@ -67,8 +67,7 @@ def setup(
pl_module (AnomalyModule): Anomalib Model that inherits pl LightningModule.
stage (str | None, optional): fit, validate, test or predict. Defaults to None.
"""
del trainer, stage # These variables are not used.

del stage, trainer # these variables are not used.
image_metric_names = [] if self.image_metric_names is None else self.image_metric_names
if isinstance(image_metric_names, str):
image_metric_names = [image_metric_names]
@@ -85,9 +84,25 @@ def setup(
)
else:
pixel_metric_names = (
self.pixel_metric_names if not isinstance(self.pixel_metric_names, str) else [self.pixel_metric_names]
self.pixel_metric_names.copy()
if not isinstance(self.pixel_metric_names, str)
else [self.pixel_metric_names]
)

# create a separate metric collection for metrics that operate over the semantic segmentation mask
# (segmentation mask with a separate channel for each defect type)
semantic_pixel_metric_names: list[str] | dict[str, dict[str, Any]]
# currently only SPRO metric is supported as semantic segmentation metric
if "SPRO" in pixel_metric_names:
if isinstance(pixel_metric_names, list):
pixel_metric_names.remove("SPRO")
semantic_pixel_metric_names = ["SPRO"]
elif isinstance(pixel_metric_names, dict):
spro_metric = pixel_metric_names.pop("SPRO")
semantic_pixel_metric_names = {"SPRO": spro_metric}
else:
logger.warning("Unexpected type for pixel_metric_names: %s", type(pixel_metric_names))

if isinstance(pl_module, AnomalyModule):
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
if hasattr(pl_module, "pixel_metrics"): # in case metrics are loaded from model checkpoint
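As far as this 69-commit view shows, `semantic_pixel_metric_names` in the hunk above is only annotated, and it is bound only when `"SPRO"` is present in `pixel_metric_names`. A minimal sketch of the same split as a standalone function that always returns both collections is given here. The name `split_semantic_metrics` and the `SEMANTIC_METRICS` constant are illustrative and not part of the PR.

```python
from __future__ import annotations

from typing import Any

SEMANTIC_METRICS = ("SPRO",)  # per the diff, SPRO is the only semantic metric so far


def split_semantic_metrics(
    pixel_metric_names: list[str] | dict[str, dict[str, Any]],
) -> tuple[list[str] | dict[str, dict[str, Any]], list[str] | dict[str, dict[str, Any]]]:
    """Return (regular, semantic) metric specs without mutating the input."""
    if isinstance(pixel_metric_names, dict):
        regular = {k: v for k, v in pixel_metric_names.items() if k not in SEMANTIC_METRICS}
        semantic = {k: v for k, v in pixel_metric_names.items() if k in SEMANTIC_METRICS}
        return regular, semantic
    if isinstance(pixel_metric_names, list):
        regular_names = [n for n in pixel_metric_names if n not in SEMANTIC_METRICS]
        semantic_names = [n for n in pixel_metric_names if n in SEMANTIC_METRICS]
        return regular_names, semantic_names
    msg = f"Unexpected type for pixel_metric_names: {type(pixel_metric_names)}"
    raise TypeError(msg)
```

Returning an empty list or dict when no semantic metric is requested keeps the later `create_metric_collection(semantic_pixel_metric_names, "pixel_")` call well defined. Raising on an unexpected type, rather than logging a warning, is a design choice of this sketch only.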
@@ -97,6 +112,7 @@ def setup(
pl_module.pixel_metrics.add_metrics(new_metrics[name])
else:
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
pl_module.semantic_pixel_metrics = create_metric_collection(semantic_pixel_metric_names, "pixel_")
self._set_threshold(pl_module)

def on_validation_epoch_start(
@@ -108,6 +124,7 @@ def on_validation_epoch_start(

pl_module.image_metrics.reset()
pl_module.pixel_metrics.reset()
pl_module.semantic_pixel_metrics.reset()

def on_validation_batch_end(
self,
@@ -122,7 +139,7 @@ def on_validation_batch_end(

if outputs is not None:
self._outputs_to_device(outputs)
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)
self._update_metrics(pl_module, outputs)

def on_validation_epoch_end(
self,
@@ -143,6 +160,7 @@ def on_test_epoch_start(

pl_module.image_metrics.reset()
pl_module.pixel_metrics.reset()
pl_module.semantic_pixel_metrics.reset()

def on_test_batch_end(
self,
@@ -157,7 +175,7 @@ def on_test_batch_end(

if outputs is not None:
self._outputs_to_device(outputs)
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)
self._update_metrics(pl_module, outputs)

def on_test_epoch_end(
self,
@@ -171,32 +189,38 @@ def on_test_epoch_end(
def _set_threshold(self, pl_module: AnomalyModule) -> None:
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
pl_module.semantic_pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())

def _update_metrics(
self,
image_metric: AnomalibMetricCollection,
pixel_metric: AnomalibMetricCollection,
pl_module: AnomalyModule,
output: STEP_OUTPUT,
) -> None:
image_metric.to(self.device)
image_metric.update(output["pred_scores"], output["label"].int())
pl_module.image_metrics.to(self.device)
pl_module.image_metrics.update(output["pred_scores"], output["label"].int())
if "mask" in output and "anomaly_maps" in output:
pixel_metric.to(self.device)
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
pl_module.pixel_metrics.to(self.device)
pl_module.pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
if "semantic_mask" in output and "anomaly_maps" in output:
pl_module.semantic_pixel_metrics.to(self.device)
pl_module.semantic_pixel_metrics.update(torch.squeeze(output["anomaly_maps"]), output["semantic_mask"])

def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
if isinstance(output, dict):
for key, value in output.items():
output[key] = self._outputs_to_device(value)
elif isinstance(output, torch.Tensor):
output = output.to(self.device)
elif isinstance(output, list):
for i, value in enumerate(output):
output[i] = self._outputs_to_device(value)
return output

@staticmethod
def _log_metrics(pl_module: AnomalyModule) -> None:
"""Log computed performance metrics."""
if pl_module.pixel_metrics._update_called: # noqa: SLF001
pl_module.log_dict(pl_module.pixel_metrics, prog_bar=True)
pl_module.log_dict(pl_module.image_metrics, prog_bar=False)
else:
pl_module.log_dict(pl_module.image_metrics, prog_bar=True)
pl_module.log_dict(pl_module.image_metrics, prog_bar=True)
if pl_module.pixel_metrics.update_called:
pl_module.log_dict(pl_module.pixel_metrics, prog_bar=False)
if pl_module.semantic_pixel_metrics.update_called:
pl_module.log_dict(pl_module.semantic_pixel_metrics, prog_bar=False)
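The new `list` branch in `_outputs_to_device` matters because `semantic_mask` is collated as a list of tensors rather than one stacked tensor. A minimal sketch of the same recursion follows. `FakeTensor` is a hypothetical stand-in for `torch.Tensor` so that the example runs without PyTorch, and the free function `outputs_to_device` mirrors the method only in structure.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only tracks a device name."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def outputs_to_device(output: Any, device: str) -> Any:
    """Recursively move tensors nested in dicts and lists; leave other values untouched."""
    if isinstance(output, dict):
        for key, value in output.items():
            output[key] = outputs_to_device(value, device)
    elif isinstance(output, list):
        for i, value in enumerate(output):
            output[i] = outputs_to_device(value, device)
    elif isinstance(output, FakeTensor):
        output = output.to(device)
    return output


batch = {
    "anomaly_maps": FakeTensor(),
    "semantic_mask": [FakeTensor(), FakeTensor()],
    "image_path": ["a.png"],
}
moved = outputs_to_device(batch, "cuda:0")
```

Non-tensor leaves such as the path strings pass through unchanged, which is why the recursion can be applied to the whole step output.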
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
"""Call when the test begins."""
del trainer # `trainer` variable is not used.

for metric in (pl_module.image_metrics, pl_module.pixel_metrics):
for metric in (pl_module.image_metrics, pl_module.pixel_metrics, pl_module.semantic_pixel_metrics):
if metric is not None:
metric.set_threshold(0.5)

13 changes: 11 additions & 2 deletions src/anomalib/cli/cli.py
@@ -141,8 +141,17 @@ def add_arguments_to_parser(self, parser: ArgumentParser) -> None:

parser.add_function_arguments(get_normalization_callback, "normalization")
parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION)
parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
parser.add_argument(
"--metrics.image",
type=list[str] | str | dict[str, dict[str, Any]] | None,
default=["F1Score", "AUROC"],
)
parser.add_argument(
"--metrics.pixel",
type=list[str] | str | dict[str, dict[str, Any]] | None,
default=None,
required=False,
)
parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
parser.add_argument("--logging.log_graph", type=bool, help="Log the model to the logger", default=False)
if hasattr(parser, "subcommand") and parser.subcommand not in ("export", "predict"):
3 changes: 2 additions & 1 deletion src/anomalib/data/__init__.py
@@ -15,7 +15,7 @@

from .base import AnomalibDataModule, AnomalibDataset
from .depth import DepthDataFormat, Folder3D, MVTec3D
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, Visa
from .image import BTech, Folder, ImageDataFormat, Kolektor, MVTec, MVTecLoco, Visa
from .predict import PredictDataset
from .utils import LabelName
from .video import Avenue, ShanghaiTech, UCSDped, VideoDataFormat
@@ -63,6 +63,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
"Kolektor",
"MVTec",
"MVTec3D",
"MVTecLoco",
"Avenue",
"UCSDped",
"ShanghaiTech",
16 changes: 15 additions & 1 deletion src/anomalib/data/base/datamodule.py
@@ -28,7 +28,9 @@
def collate_fn(batch: list) -> dict[str, Any]:
"""Collate bounding boxes as lists.

Bounding boxes are collated as a list of tensors, while the default collate function is used for all other entries.
Bounding boxes and `semantic_mask` (not `mask`) are collated as a list of tensors. If `mask_path` is a list,
it is also collated as a list, since the number of masks can differ between the elements of a batch.
For all other entries, the default collate function is used.

Args:
batch (List): list of items in the batch where len(batch) is equal to the batch size.
@@ -42,6 +44,12 @@ def collate_fn(batch: list) -> dict[str, Any]:
if "boxes" in elem:
# collate boxes as list
out_dict["boxes"] = [item.pop("boxes") for item in batch]
if "semantic_mask" in elem:
# semantic masks have a variable number of channels, so we collate them as a list
out_dict["semantic_mask"] = [item.pop("semantic_mask") for item in batch]
if "mask_path" in elem and isinstance(elem["mask_path"], list):
# collate mask paths as list
out_dict["mask_path"] = [item.pop("mask_path") for item in batch]
# collate other data normally
out_dict.update({key: default_collate([item[key] for item in batch]) for key in elem})
return out_dict
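The pop-then-collate pattern in `collate_fn` can be sketched without PyTorch. In this illustration, `collate_with_lists` and `LIST_KEYS` are hypothetical names, and plain per-key lists stand in for `default_collate`, which would stack tensors instead.

```python
from typing import Any

LIST_KEYS = ("boxes", "semantic_mask")  # entries whose per-sample shapes may differ


def collate_with_lists(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Keep variable-shape entries as per-sample lists; group the rest by key."""
    first = batch[0]
    out: dict[str, Any] = {}
    for key in LIST_KEYS:
        if key in first:
            # Popping removes the key from every item, so it is skipped below.
            out[key] = [item.pop(key) for item in batch]
    # Stand-in for torch's default_collate: gather the remaining values per key.
    out.update({key: [item[key] for item in batch] for key in first})
    return out


batch = [
    {"label": 1, "semantic_mask": [[1, 0]]},
    {"label": 0, "semantic_mask": [[0, 1], [1, 1]]},
]
result = collate_with_lists(batch)
```

The two samples carry one and two mask channels respectively, which is the case that would make stacking fail and that the list collation avoids. As in the original, the items in `batch` are mutated by the `pop` calls.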
@@ -213,6 +221,12 @@ def _create_val_split(self) -> None:
# converted from random training sample
self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
elif self.val_split_mode == ValSplitMode.FROM_DIR:
# the val_data is prepared in subclass
assert hasattr(
self,
"val_data",
), f"FROM_DIR is not supported for {self.__class__.__name__} which does not assign val_data in _setup."
elif self.val_split_mode != ValSplitMode.NONE:
msg = f"Unknown validation split mode: {self.val_split_mode}"
raise ValueError(msg)
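The `FROM_DIR` branch relies on `assert`, which Python skips when run with `-O`. A possible alternative that raises explicitly is sketched here. `DataModuleSketch` and the two-member `ValSplitMode` are stand-ins written for this example (the enum values are assumed), not anomalib classes.

```python
from enum import Enum


class ValSplitMode(str, Enum):
    """Stand-in enum with only the members used below (values are assumed)."""

    NONE = "none"
    FROM_DIR = "from_dir"


class DataModuleSketch:
    """Hypothetical datamodule that validates its validation split mode."""

    def __init__(self, val_split_mode: ValSplitMode) -> None:
        self.val_split_mode = val_split_mode

    def create_val_split(self) -> None:
        if self.val_split_mode == ValSplitMode.FROM_DIR:
            # A subclass is expected to have assigned `val_data` during setup.
            if not hasattr(self, "val_data"):
                msg = f"FROM_DIR is not supported for {type(self).__name__}: val_data was not assigned."
                raise ValueError(msg)
        elif self.val_split_mode != ValSplitMode.NONE:
            msg = f"Unknown validation split mode: {self.val_split_mode}"
            raise ValueError(msg)


datamodule = DataModuleSketch(ValSplitMode.FROM_DIR)
try:
    datamodule.create_val_split()
    raised = False
except ValueError:
    raised = True
```

Raising `ValueError` matches the error type already used for unknown split modes in the surrounding code.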
4 changes: 3 additions & 1 deletion src/anomalib/data/image/__init__.py
@@ -13,6 +13,7 @@
from .folder import Folder
from .kolektor import Kolektor
from .mvtec import MVTec
from .mvtec_loco import MVTecLoco
from .visa import Visa


@@ -21,11 +22,12 @@ class ImageDataFormat(str, Enum):

MVTEC = "mvtec"
MVTEC_3D = "mvtec_3d"
MVTEC_LOCO = "mvtec_loco"
BTECH = "btech"
KOLEKTOR = "kolektor"
FOLDER = "folder"
FOLDER_3D = "folder_3d"
VISA = "visa"


__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "Visa"]
__all__ = ["BTech", "Folder", "Kolektor", "MVTec", "MVTecLoco", "Visa"]