diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index 105d6052ea..fe97de1d99 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -31,16 +31,14 @@ from torchtnt.framework.callbacks.base_checkpointer import ( BaseCheckpointer as BaseCheckpointer, ) -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - RestoreOptions, -) +from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.callbacks.lambda_callback import Lambda from torchtnt.framework.fit import fit from torchtnt.framework.state import State from torchtnt.framework.train import train from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData +from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.distributed import get_global_rank, spawn_multi_process from torchtnt.utils.env import init_from_env from torchtnt.utils.test_utils import skip_if_not_distributed diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index ded257c412..3cf4cd1746 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -16,10 +16,7 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - RestoreOptions, -) +from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.state import EntryPoint, State from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit from torchtnt.framework.utils import get_timing_context @@ -28,6 +25,7 @@ _metadata_exists, _sort_by_metric_value, _sort_by_recency, + BestCheckpointConfig, get_best_checkpoint_path, get_checkpoint_dirpaths, get_latest_checkpoint_path, diff --git a/torchtnt/framework/callbacks/checkpointer_types.py b/torchtnt/framework/callbacks/checkpointer_types.py index 5ccdf7862d..d7ab2693cf 100644 --- a/torchtnt/framework/callbacks/checkpointer_types.py +++ b/torchtnt/framework/callbacks/checkpointer_types.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Literal, Optional +from typing import Optional # TODO: eventually support overriding all knobs @@ -39,17 +39,3 @@ class RestoreOptions: restore_eval_progress: bool = True restore_optimizers: bool = True restore_lr_schedulers: bool = True - - -@dataclass -class BestCheckpointConfig: - """ - Config for saving the best checkpoints. - - Args: - monitored_metric: Metric to monitor for saving best checkpoints. Must be an numerical or tensor attribute on the unit. - mode: One of `min` or `max`. The save file is overwritten based the max or min of the monitored metric. - """ - - monitored_metric: str - mode: Literal["min", "max"] = "min" diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 3984d6fb86..bd36cd7f05 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -23,11 +23,7 @@ ) from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - KnobOptions, - RestoreOptions, -) +from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, @@ -38,6 +34,7 @@ ) from torchtnt.framework.utils import get_timing_context from torchtnt.utils.optimizer import init_optim_state +from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import MultiStateful, Stateful diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index b5138cd681..ab2a0515d6 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -22,11 +22,7 @@ ) from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - KnobOptions, - RestoreOptions, -) +from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, @@ -36,6 +32,7 @@ TTrainUnit, ) from torchtnt.framework.utils import get_timing_context +from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.optimizer import init_optim_state from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import Stateful diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index 06cb6d33da..09252013f0 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -7,6 +7,7 @@ # pyre-strict from .checkpoint import ( + BestCheckpointConfig, CheckpointPath, get_best_checkpoint_path, get_checkpoint_dirpaths, @@ -160,4 +161,5 @@ "get_best_checkpoint_path", "get_checkpoint_dirpaths", "get_latest_checkpoint_path", + "BestCheckpointConfig", ] diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index d5ddc4b2f6..a677da03ff 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -31,6 +31,20 @@ class MetricData: value: float +@dataclass +class BestCheckpointConfig: + """ + Config for saving the best checkpoints. + + Args: + monitored_metric: Metric to monitor for saving best checkpoints. Must be an numerical or tensor attribute on the unit. + mode: One of `min` or `max`. The save file is overwritten based the max or min of the monitored metric. + """ + + monitored_metric: str + mode: Literal["min", "max"] = "min" + + @total_ordering class CheckpointPath: """