Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸ”¨ Refactor Visualisation #1693

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/anomalib/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,10 @@
from .model_loader import LoadModelCallback
from .tiler_configuration import TilerConfigurationCallback
from .timer import TimerCallback
from .visualizer import ImageVisualizerCallback, MetricVisualizerCallback, get_visualization_callbacks

__all__ = [
"get_visualization_callbacks",
"GraphLogger",
"ImageVisualizerCallback",
"LoadModelCallback",
"MetricVisualizerCallback",
"TilerConfigurationCallback",
"TimerCallback",
]
Expand Down
171 changes: 171 additions & 0 deletions src/anomalib/callbacks/visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Visualizer Callback.

This is assigned by Anomalib Engine internally.
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from pathlib import Path
from typing import Any, cast

from lightning.pytorch import Callback, Trainer
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.data.utils.image import save_image, show_image
from anomalib.loggers import AnomalibWandbLogger
from anomalib.loggers.base import ImageLoggerBase
from anomalib.models import AnomalyModule
from anomalib.utils.visualization import (
BaseVisualizer,
GeneratorResult,
VisualizationStep,
)

logger = logging.getLogger(__name__)


class _VisualizationCallback(Callback):
"""Callback for visualization that is used internally by the Engine.

Args:
visualizers (BaseVisualizer | list[BaseVisualizer]):
Visualizer objects that are used for computing the visualizations. Defaults to None.
save (bool, optional): Save the image. Defaults to False.
root (Path | None, optional): The path to save the images. Defaults to None.
log (bool, optional): Log the images into the loggers. Defaults to False.
show (bool, optional): Show the images. Defaults to False.

Example:
>>> visualizers = [ImageVisualizer(), MetricsVisualizer()]
>>> visualization_callback = _VisualizationCallback(
... visualizers=visualizers,
... save=True,
... root="results/images"
... )

CLI
$ anomalib train --model Padim --data MVTec \
--visualization.visualizers ImageVisualizer \
--visualization.visualizers+=MetricsVisualizer
or
$ anomalib train --model Padim --data MVTec \
--visualization.visualizers '[ImageVisualizer, MetricsVisualizer]'

samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
Raises:
ValueError: Incase `root` is None and `save` is True.
"""

def __init__(
self,
visualizers: BaseVisualizer | list[BaseVisualizer],
save: bool = False,
root: Path | None = None,
log: bool = False,
show: bool = False,
) -> None:
self.save = save
if save and root is None:
msg = "`root` must be provided if save is True"
raise ValueError(msg)
self.root: Path = root if root is not None else Path() # need this check for mypy
self.log = log
self.show = show
self.generators = visualizers if isinstance(visualizers, list) else [visualizers]

def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
for generator in self.generators:
if generator.visualize_on == VisualizationStep.BATCH:
for result in generator(
trainer=trainer,
pl_module=pl_module,
outputs=outputs,
batch=batch,
batch_idx=batch_idx,
dataloader_idx=dataloader_idx,
):
if self.save:
if result.file_name is None:
msg = "``save`` is set to ``True`` but file name is ``None``"
raise ValueError(msg)
save_image(image=result.image, root=self.root, filename=result.file_name)
if self.show:
show_image(image=result.image, title=str(result.file_name))
if self.log:
self._add_to_logger(result, pl_module, trainer)

def on_test_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
for generator in self.generators:
if generator.visualize_on == VisualizationStep.STAGE_END:
for result in generator(trainer=trainer, pl_module=pl_module):
if self.save:
if result.file_name is None:
msg = "``save`` is set to ``True`` but file name is ``None``"
raise ValueError(msg)
save_image(image=result.image, root=self.root, filename=result.file_name)
if self.show:
show_image(image=result.image, title=str(result.file_name))
if self.log:
self._add_to_logger(result, pl_module, trainer)

for logger in trainer.loggers:
if isinstance(logger, AnomalibWandbLogger):
logger.save()

def on_predict_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
return self.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

def on_predict_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
return self.on_test_end(trainer, pl_module)

def _add_to_logger(
self,
result: GeneratorResult,
module: AnomalyModule,
trainer: Trainer,
) -> None:
"""Add image to logger.

Args:
result (GeneratorResult): Output from the generators.
module (AnomalyModule): LightningModule from which the global step is extracted.
trainer (Trainer): Trainer object.
"""
# Store names of logger and the logger in a dict
available_loggers = {
type(logger).__name__.lower().replace("logger", "").replace("anomalib", ""): logger
for logger in trainer.loggers
}
# save image to respective logger
if result.file_name is None:
msg = "File name is None"
raise ValueError(msg)
filename = result.file_name
image = result.image
for log_to in available_loggers:
# check if logger object is same as the requested object
if isinstance(available_loggers[log_to], ImageLoggerBase):
logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to]) # placate mypy
_name = filename.parent.name + "_" + filename.name if isinstance(filename, Path) else filename
logger.add_image(
image=image,
name=_name,
global_step=module.global_step,
)
16 changes: 0 additions & 16 deletions src/anomalib/callbacks/visualizer/__init__.py

This file was deleted.

48 changes: 0 additions & 48 deletions src/anomalib/callbacks/visualizer/utils.py

This file was deleted.

107 changes: 0 additions & 107 deletions src/anomalib/callbacks/visualizer/visualizer_base.py

This file was deleted.

Loading
Loading