diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index fb8fea71d7..105d6052ea 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -249,7 +249,7 @@ def test_restore_from_latest_empty_dir(self) -> None: self.assertEqual( log.output, [ - f"WARNING:torchtnt.framework.callbacks._checkpoint_utils:Input dirpath doesn't contain any subdirectories: {temp_dir}" + f"WARNING:torchtnt.utils.checkpoint:Input dirpath doesn't contain any subdirectories: {temp_dir}" ], ) self.assertFalse(restored) diff --git a/tests/framework/callbacks/test_checkpoint_utils.py b/tests/framework/callbacks/test_checkpoint_utils.py index ac1a019a98..f917fcd942 100644 --- a/tests/framework/callbacks/test_checkpoint_utils.py +++ b/tests/framework/callbacks/test_checkpoint_utils.py @@ -6,411 +6,16 @@ # pyre-strict -import os -import shutil -import tempfile import unittest -import torch -import torch.distributed as dist -from torch import nn -from torchsnapshot import Snapshot -from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME from torchtnt.framework._test_utils import DummyTrainUnit, get_dummy_train_state from torchtnt.framework.callbacks._checkpoint_utils import ( - _delete_checkpoint, - _metadata_exists, _prepare_app_state_for_checkpoint, - _retrieve_checkpoint_dirpaths, - _sort_by_metric_value, - _sort_by_recency, - get_best_checkpoint_path, - get_checkpoint_dirpaths, - get_latest_checkpoint_path, - rank_zero_read_and_broadcast, ) -from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process -from torchtnt.utils.env import init_from_env -from torchtnt.utils.fsspec import get_filesystem -from torchtnt.utils.test_utils import skip_if_not_distributed - -METADATA_FNAME: str = ".metadata" class CheckpointUtilsTest(unittest.TestCase): - @staticmethod - def _create_snapshot_metadata(output_dir: str) -> None: - path = os.path.join(output_dir, METADATA_FNAME) - with open(path, "w"): - pass - - def test_latest_checkpoint_path(self) -> None: - with tempfile.TemporaryDirectory() as temp_dir: - self.assertIsNone(get_latest_checkpoint_path(temp_dir)) - - with tempfile.TemporaryDirectory() as temp_dir: - latest_path = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(latest_path) - self.assertEqual( - get_latest_checkpoint_path(temp_dir), - latest_path, - ) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), - None, - ) - self._create_snapshot_metadata(latest_path) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), - latest_path, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - path_1 = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path_1) - self._create_snapshot_metadata(path_1) - path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002") - os.mkdir(path_2) - self._create_snapshot_metadata(path_2) - - # Missing metadata file - path_3 = os.path.join(temp_dir, "epoch_1_step_100") - os.mkdir(path_3) - - # Ill-formatted name - path_4 = os.path.join(temp_dir, "epoch_700") - os.mkdir(path_4) - self.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 - ) - - @skip_if_not_distributed - def test_latest_checkpoint_path_distributed(self) -> None: - spawn_multi_process( - 2, - "gloo", - self._latest_checkpoint_path_distributed, - ) - - @staticmethod - def _latest_checkpoint_path_distributed() -> None: - tc = unittest.TestCase() - is_rank0 = get_global_rank() == 0 - - if is_rank0: - temp_dir = tempfile.mkdtemp() - else: - temp_dir = "" - tc.assertIsNone(get_latest_checkpoint_path(temp_dir)) - if is_rank0: - shutil.rmtree(temp_dir) # delete temp directory - - if is_rank0: - temp_dir = tempfile.mkdtemp() - path_1 = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path_1) - CheckpointUtilsTest._create_snapshot_metadata(path_1) - path_2 = os.path.join(temp_dir, "epoch_0_step_100") - os.mkdir(path_2) - CheckpointUtilsTest._create_snapshot_metadata(path_2) - - # Missing metadata file - path_3 = os.path.join(temp_dir, "epoch_1_step_100") - os.mkdir(path_3) - - # Ill-formatted name - path_4 = os.path.join(temp_dir, "epoch_700") - os.mkdir(path_4) - else: - temp_dir = "" - path_2 = "" - - pg = PGWrapper(dist.group.WORLD) - path_container = [path_2] if is_rank0 else [None] - pg.broadcast_object_list(path_container, 0) - expected_path = path_container[0] - tc.assertIsNotNone(expected_path) - tc.assertEqual( - get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path - ) - - if is_rank0: - shutil.rmtree(temp_dir) # delete temp directory - - def test_best_checkpoint_path(self) -> None: - with tempfile.TemporaryDirectory() as temp_dir: - self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) - - # no checkpoint w/ metric value - path = os.path.join(temp_dir, "epoch_0_step_0") - os.mkdir(path) - self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) - - with tempfile.TemporaryDirectory() as temp_dir: - best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01") - os.mkdir(best_path) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min"), - best_path, - ) - self.assertIsNone( - get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), - None, - ) - self._create_snapshot_metadata(best_path) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), - best_path, - ) - - # handle negative values - best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") - os.mkdir(best_path_2) - self.assertEqual( - get_best_checkpoint_path(temp_dir, "val_loss", "min"), - best_path_2, - ) - - # handle "max" mode correctly - best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1") - os.mkdir(best_path_3) - self.assertEqual( - get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), - best_path_3, - ) - - # handle different metric correctly - best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2") - os.mkdir(best_path_4) - self.assertEqual( - get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), - best_path_3, - ) - self.assertEqual( - get_best_checkpoint_path( - temp_dir, metric_name="train_loss", mode="max" - ), - best_path_4, - ) - - def test_retrieve_checkpoint_dirpaths(self) -> None: - """ - Tests retrieving checkpoint directories from a given root directory - """ - with tempfile.TemporaryDirectory() as temp_dir: - paths = [ - "epoch_0_step_10", - "epoch_1_step_10", - "epoch_2_step_10", - "epoch_0_step_5", - "epoch_0_step_6", - "epoch_0_step_3", - ] - for path in paths[:-1]: - os.mkdir(os.path.join(temp_dir, path)) - # make last path a file instead of a directory - with open(os.path.join(temp_dir, paths[-1]), "w"): - pass - - # compares set equality since order of returned dirpaths is not guaranteed - # in _retrieve_checkpoint_dirpaths - self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), - {os.path.join(temp_dir, path) for path in paths[:-1]}, - ) - self.assertEqual( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), - [], - ) - - # check metadata file is correct filtered for - # by creating metadata for 3rd path in list - with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"): - pass - - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata") - ), - {os.path.join(temp_dir, paths[2])}, - ) - - def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: - """ - Tests retrieving checkpoint (w/ metrics) directories from a given root directory - """ - with tempfile.TemporaryDirectory() as temp_dir: - paths = [ - "epoch_0_step_10_val_loss=10", - "epoch_1_step_10_val_loss=5", - "epoch_2_step_10", - "epoch_0_step_5", - "epoch_0_step_6_train_loss=13", - ] - for path in paths: - os.mkdir(os.path.join(temp_dir, path)) - # make last path a file instead of a directory - with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"): - pass - - # compares set equality since order of returned dirpaths is not guaranteed - # in _retrieve_checkpoint_dirpaths - self.assertEqual( - set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), - {os.path.join(temp_dir, path) for path in paths}, - ) - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( - temp_dir, metadata_fname=None, metric_name="val_loss" - ) - ), - { - os.path.join(temp_dir, path) for path in paths[:2] - }, # since last path is a file - ) - self.assertEqual( - _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), - [], - ) - - # check metadata file is correct filtered for - # by creating metadata for 3rd path in list - with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"): - pass - - self.assertEqual( - set( - _retrieve_checkpoint_dirpaths( - temp_dir, metadata_fname=".metadata", metric_name="val_loss" - ) - ), - {os.path.join(temp_dir, paths[1])}, - ) - - @skip_if_not_distributed - def test_distributed_get_checkpoint_dirpaths(self) -> None: - spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths) - - @staticmethod - def _distributed_get_checkpoint_dirpaths() -> None: - """ - Tests that existing checkpoint directories are read and - properly registered on all ranks - """ - - @rank_zero_read_and_broadcast - def create_tmp_dir() -> str: - return tempfile.mkdtemp() - - init_from_env() - - temp_dir = create_tmp_dir() - try: - path1 = os.path.join(temp_dir, "epoch_0_step_10") - path2 = os.path.join(temp_dir, "epoch_1_step_20") - if get_global_rank() == 0: - os.mkdir(path1) - os.mkdir(path2) - torch.distributed.barrier() - - ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir) - tc = unittest.TestCase() - tc.assertEqual(set(ckpt_dirpaths), {path1, path2}) - - tc.assertEqual( - get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), [] - ) - finally: - if get_global_rank() == 0: - shutil.rmtree(temp_dir) # delete temp directory - - def test_get_checkpoint_dirpaths(self) -> None: - """ - Tests that `get_checkpoint_dirpaths` returns - the sorted checkpoint directories correctly - """ - with tempfile.TemporaryDirectory() as temp_dir: - path1 = os.path.join(temp_dir, "epoch_1_step_20") - path2 = os.path.join(temp_dir, "epoch_4_step_130") - path3 = os.path.join(temp_dir, "epoch_0_step_10") - os.mkdir(path1) - os.mkdir(path2) - os.mkdir(path3) - - self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir)), - {path1, path2, path3}, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01") - path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2") - path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12") - os.mkdir(path1) - os.mkdir(path2) - os.mkdir(path3) - - self.assertEqual( - set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")), - {path1, path2, path3}, - ) - - with tempfile.TemporaryDirectory() as temp_dir: - self.assertEqual( - get_checkpoint_dirpaths(temp_dir), - [], - ) - - def test_checkpoint_sorting_utils(self) -> None: - """ - Tests the sort utilities - """ - paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"] - self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]]) - - paths = [ - "epoch_1_step_20_val_loss=0.09", - "epoch_4_step_130_val_loss=29", - "epoch_0_step_10_val_loss=10", - ] - self.assertEqual( - _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]] - ) - self.assertEqual( - _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]] - ) - - def test_delete_checkpoint(self) -> None: - """ - Tests removing checkpoint directories - """ - app_state = {"module": nn.Linear(2, 2)} - with tempfile.TemporaryDirectory() as temp_dir: - dirpath = os.path.join(temp_dir, "checkpoint") - Snapshot.take(dirpath, app_state=app_state) - self.assertTrue(os.path.exists(dirpath)) - # check that error is thrown if .snapshot_metadata is not found in the directory when deleting - os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - with self.assertRaisesRegex( - RuntimeError, f"{temp_dir} does not contain .snapshot_metadata" - ): - _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME) - _delete_checkpoint(dirpath) - self.assertFalse(os.path.exists(dirpath)) - - def test_metadata_exists(self) -> None: - app_state = {"module": nn.Linear(2, 2)} - with tempfile.TemporaryDirectory() as temp_dir: - dirpath = os.path.join(temp_dir, "checkpoint") - Snapshot.take(dirpath, app_state=app_state) - - fs = get_filesystem(dirpath) - self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) - - os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) - self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) def test_get_app_state(self) -> None: my_unit = DummyTrainUnit(input_dim=2) diff --git a/tests/utils/test_checkpoint.py b/tests/utils/test_checkpoint.py new file mode 100644 index 0000000000..2257e683c2 --- /dev/null +++ b/tests/utils/test_checkpoint.py @@ -0,0 +1,549 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import os +import shutil +import tempfile +import unittest + +import torch + +import torch.distributed as dist +from torch import nn +from torchsnapshot import Snapshot +from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME +from torchtnt.utils import get_global_rank, init_from_env + +from torchtnt.utils.checkpoint import ( + _delete_checkpoint, + _metadata_exists, + _retrieve_checkpoint_dirpaths, + _sort_by_metric_value, + _sort_by_recency, + CheckpointPath, + get_best_checkpoint_path, + get_checkpoint_dirpaths, + get_latest_checkpoint_path, + MetricData, +) +from torchtnt.utils.distributed import ( + PGWrapper, + rank_zero_read_and_broadcast, + spawn_multi_process, +) +from torchtnt.utils.fsspec import get_filesystem +from torchtnt.utils.test_utils import skip_if_not_distributed + +METADATA_FNAME: str = ".metadata" + + +class CheckpointPathTest(unittest.TestCase): + def test_from_str(self) -> None: + # invalid paths + malformed_paths = [ + "foo/step_20", + "foo/epoch_50", + "epoch_30", + "foo/epoch_20_step", + "foo/epoch_20_step_30_val_loss=1a", + "foo/epoch_2_step_15_mean=hello", + "foo/epoch_2.6_step_23", + ] + for path in malformed_paths: + with self.assertRaisesRegex( + ValueError, f"Attempted to parse malformed checkpoint path: {path}" + ): + CheckpointPath.from_str(path) + + # valid paths + valid_paths = [ + ("foo/epoch_0_step_1", CheckpointPath("foo", epoch=0, step=1)), + ( + "foo/epoch_14_step_3_mean=15.0", + CheckpointPath( + "foo", epoch=14, step=3, metric_data=MetricData("mean", 15.0) + ), + ), + ( + "foo/epoch_14_step_3_loss=-27.35", + CheckpointPath( + "foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35) + ), + ), + ( + "/foo/epoch_14_step_3_loss=-27.35", + CheckpointPath( + "/foo", epoch=14, step=3, metric_data=MetricData("loss", -27.35) + ), + ), + ( + "foo/bar/epoch_23_step_31_mean_loss_squared=0.0", + CheckpointPath( + "foo/bar/", + epoch=23, + step=31, + metric_data=MetricData("mean_loss_squared", 0.0), + ), + ), + ( + "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61/epoch_2_step_1_acc=0.98", + CheckpointPath( + "oss://some/path/checkpoints/0b20e70f-9ad2-4904-b7d6-e8da48087d61", + epoch=2, + step=1, + metric_data=MetricData("acc", 0.98), + ), + ), + ] + for path, expected_ckpt in valid_paths: + parsed_ckpt = CheckpointPath.from_str(path) + self.assertEqual(parsed_ckpt, expected_ckpt) + self.assertEqual(parsed_ckpt.path, path) + + # with a trailing slash + ckpt = CheckpointPath.from_str("foo/epoch_0_step_1/") + self.assertEqual(ckpt, CheckpointPath("foo", epoch=0, step=1)) + self.assertEqual(ckpt.path, "foo/epoch_0_step_1") + + def test_compare_by_recency(self) -> None: + old = CheckpointPath("foo", epoch=0, step=1) + new = CheckpointPath("foo", epoch=1, step=1) + self.assertTrue(new.newer_than(old)) + self.assertFalse(old.newer_than(new)) + self.assertFalse(new == old) + + old = CheckpointPath("foo", epoch=3, step=5) + new = CheckpointPath("foo", epoch=3, step=9) + self.assertTrue(new.newer_than(old)) + self.assertFalse(old.newer_than(new)) + self.assertFalse(new == old) + + twin1 = CheckpointPath( + "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0) + ) + almost_twin = CheckpointPath( + "foo", epoch=2, step=5, metric_data=MetricData("bar", 2.0) + ) + + self.assertFalse(twin1.newer_than(almost_twin)) + self.assertFalse(almost_twin.newer_than(twin1)) + self.assertFalse(twin1 == almost_twin) + + twin2 = CheckpointPath( + "foo", epoch=2, step=5, metric_data=MetricData("foo", 1.0) + ) + self.assertTrue(twin1 == twin2) + + def test_compare_by_optimality(self) -> None: + # not both metric aware + ckpt1 = CheckpointPath("foo", epoch=0, step=1) + ckpt2 = CheckpointPath("foo", epoch=1, step=1) + ckpt3 = CheckpointPath( + "foo", epoch=1, step=1, metric_data=MetricData("bar", 1.0) + ) + for ckpt in [ckpt2, ckpt3]: + with self.assertRaisesRegex( + AssertionError, + "Attempted to compare optimality of non metric-aware checkpoints", + ): + ckpt1.more_optimal_than(ckpt, mode="min") + + # tracking different metrics + ckpt4 = CheckpointPath( + "foo", epoch=1, step=1, metric_data=MetricData("baz", 1.0) + ) + with self.assertRaisesRegex( + AssertionError, + "Attempted to compare optimality of checkpoints tracking different metrics", + ): + ckpt3.more_optimal_than(ckpt4, mode="min") + + smaller = CheckpointPath( + "foo", epoch=0, step=1, metric_data=MetricData("foo", 1.0) + ) + larger = CheckpointPath( + "foo", epoch=0, step=1, metric_data=MetricData("foo", 2.0) + ) + self.assertTrue(larger.more_optimal_than(smaller, mode="max")) + self.assertFalse(smaller.more_optimal_than(larger, mode="max")) + self.assertTrue(smaller.more_optimal_than(larger, mode="min")) + self.assertFalse(larger.more_optimal_than(smaller, mode="min")) + + +class CheckpointUtilsTest(unittest.TestCase): + @staticmethod + def _create_snapshot_metadata(output_dir: str) -> None: + path = os.path.join(output_dir, METADATA_FNAME) + with open(path, "w"): + pass + + def test_latest_checkpoint_path(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + self.assertIsNone(get_latest_checkpoint_path(temp_dir)) + + with tempfile.TemporaryDirectory() as temp_dir: + latest_path = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(latest_path) + self.assertEqual( + get_latest_checkpoint_path(temp_dir), + latest_path, + ) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + None, + ) + self._create_snapshot_metadata(latest_path) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), + latest_path, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + path_1 = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path_1) + self._create_snapshot_metadata(path_1) + path_2 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.002") + os.mkdir(path_2) + self._create_snapshot_metadata(path_2) + + # Missing metadata file + path_3 = os.path.join(temp_dir, "epoch_1_step_100") + os.mkdir(path_3) + + # Ill-formatted name + path_4 = os.path.join(temp_dir, "epoch_700") + os.mkdir(path_4) + self.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), path_2 + ) + + @skip_if_not_distributed + def test_latest_checkpoint_path_distributed(self) -> None: + spawn_multi_process( + 2, + "gloo", + self._latest_checkpoint_path_distributed, + ) + + @staticmethod + def _latest_checkpoint_path_distributed() -> None: + tc = unittest.TestCase() + is_rank0 = get_global_rank() == 0 + + if is_rank0: + temp_dir = tempfile.mkdtemp() + else: + temp_dir = "" + tc.assertIsNone(get_latest_checkpoint_path(temp_dir)) + if is_rank0: + shutil.rmtree(temp_dir) # delete temp directory + + if is_rank0: + temp_dir = tempfile.mkdtemp() + path_1 = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path_1) + CheckpointUtilsTest._create_snapshot_metadata(path_1) + path_2 = os.path.join(temp_dir, "epoch_0_step_100") + os.mkdir(path_2) + CheckpointUtilsTest._create_snapshot_metadata(path_2) + + # Missing metadata file + path_3 = os.path.join(temp_dir, "epoch_1_step_100") + os.mkdir(path_3) + + # Ill-formatted name + path_4 = os.path.join(temp_dir, "epoch_700") + os.mkdir(path_4) + else: + temp_dir = "" + path_2 = "" + + pg = PGWrapper(dist.group.WORLD) + path_container = [path_2] if is_rank0 else [None] + pg.broadcast_object_list(path_container, 0) + expected_path = path_container[0] + tc.assertIsNotNone(expected_path) + tc.assertEqual( + get_latest_checkpoint_path(temp_dir, METADATA_FNAME), expected_path + ) + + if is_rank0: + shutil.rmtree(temp_dir) # delete temp directory + + def test_best_checkpoint_path(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) + + # no checkpoint w/ metric value + path = os.path.join(temp_dir, "epoch_0_step_0") + os.mkdir(path) + self.assertIsNone(get_best_checkpoint_path(temp_dir, "val_loss", "min")) + + with tempfile.TemporaryDirectory() as temp_dir: + best_path = os.path.join(temp_dir, "epoch_0_step_0_val_loss=0.01") + os.mkdir(best_path) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min"), + best_path, + ) + self.assertIsNone( + get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), + None, + ) + self._create_snapshot_metadata(best_path) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min", METADATA_FNAME), + best_path, + ) + + # handle negative values + best_path_2 = os.path.join(temp_dir, "epoch_0_step_0_val_loss=-0.01") + os.mkdir(best_path_2) + self.assertEqual( + get_best_checkpoint_path(temp_dir, "val_loss", "min"), + best_path_2, + ) + + # handle "max" mode correctly + best_path_3 = os.path.join(temp_dir, "epoch_0_step_100_val_loss=0.1") + os.mkdir(best_path_3) + self.assertEqual( + get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), + best_path_3, + ) + + # handle different metric correctly + best_path_4 = os.path.join(temp_dir, "epoch_0_step_100_train_loss=0.2") + os.mkdir(best_path_4) + self.assertEqual( + get_best_checkpoint_path(temp_dir, metric_name="val_loss", mode="max"), + best_path_3, + ) + self.assertEqual( + get_best_checkpoint_path( + temp_dir, metric_name="train_loss", mode="max" + ), + best_path_4, + ) + + def test_retrieve_checkpoint_dirpaths(self) -> None: + """ + Tests retrieving checkpoint directories from a given root directory + """ + with tempfile.TemporaryDirectory() as temp_dir: + paths = [ + "epoch_0_step_10", + "epoch_1_step_10", + "epoch_2_step_10", + "epoch_0_step_5", + "epoch_0_step_6", + "epoch_0_step_3", + ] + for path in paths[:-1]: + os.mkdir(os.path.join(temp_dir, path)) + # make last path a file instead of a directory + with open(os.path.join(temp_dir, paths[-1]), "w"): + pass + + # compares set equality since order of returned dirpaths is not guaranteed + # in _retrieve_checkpoint_dirpaths + self.assertEqual( + set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + {os.path.join(temp_dir, path) for path in paths[:-1]}, + ) + self.assertEqual( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), + [], + ) + + # check metadata file is correct filtered for + # by creating metadata for 3rd path in list + with open(os.path.join(temp_dir, paths[2], ".metadata"), "w"): + pass + + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata") + ), + {os.path.join(temp_dir, paths[2])}, + ) + + def test_retrieve_checkpoint_dirpaths_with_metrics(self) -> None: + """ + Tests retrieving checkpoint (w/ metrics) directories from a given root directory + """ + with tempfile.TemporaryDirectory() as temp_dir: + paths = [ + "epoch_0_step_10_val_loss=10", + "epoch_1_step_10_val_loss=5", + "epoch_2_step_10", + "epoch_0_step_5", + "epoch_0_step_6_train_loss=13", + ] + for path in paths: + os.mkdir(os.path.join(temp_dir, path)) + # make last path a file instead of a directory + with open(os.path.join(temp_dir, "epoch_0_step_3_val_loss=3"), "w"): + pass + + # compares set equality since order of returned dirpaths is not guaranteed + # in _retrieve_checkpoint_dirpaths + self.assertEqual( + set(_retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=None)), + {os.path.join(temp_dir, path) for path in paths}, + ) + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=None, metric_name="val_loss" + ) + ), + { + os.path.join(temp_dir, path) for path in paths[:2] + }, # since last path is a file + ) + self.assertEqual( + _retrieve_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), + [], + ) + + # check metadata file is correct filtered for + # by creating metadata for 3rd path in list + with open(os.path.join(temp_dir, paths[1], ".metadata"), "w"): + pass + + self.assertEqual( + set( + _retrieve_checkpoint_dirpaths( + temp_dir, metadata_fname=".metadata", metric_name="val_loss" + ) + ), + {os.path.join(temp_dir, paths[1])}, + ) + + @skip_if_not_distributed + def test_distributed_get_checkpoint_dirpaths(self) -> None: + spawn_multi_process(2, "gloo", self._distributed_get_checkpoint_dirpaths) + + @staticmethod + def _distributed_get_checkpoint_dirpaths() -> None: + """ + Tests that existing checkpoint directories are read and + properly registered on all ranks + """ + + @rank_zero_read_and_broadcast + def create_tmp_dir() -> str: + return tempfile.mkdtemp() + + init_from_env() + + temp_dir = create_tmp_dir() + try: + path1 = os.path.join(temp_dir, "epoch_0_step_10") + path2 = os.path.join(temp_dir, "epoch_1_step_20") + if get_global_rank() == 0: + os.mkdir(path1) + os.mkdir(path2) + torch.distributed.barrier() + + ckpt_dirpaths = get_checkpoint_dirpaths(temp_dir) + tc = unittest.TestCase() + tc.assertEqual(set(ckpt_dirpaths), {path1, path2}) + + tc.assertEqual( + get_checkpoint_dirpaths(temp_dir, metadata_fname=".metadata"), [] + ) + finally: + if get_global_rank() == 0: + shutil.rmtree(temp_dir) # delete temp directory + + def test_get_checkpoint_dirpaths(self) -> None: + """ + Tests that `get_checkpoint_dirpaths` returns + the sorted checkpoint directories correctly + """ + with tempfile.TemporaryDirectory() as temp_dir: + path1 = os.path.join(temp_dir, "epoch_1_step_20") + path2 = os.path.join(temp_dir, "epoch_4_step_130") + path3 = os.path.join(temp_dir, "epoch_0_step_10") + os.mkdir(path1) + os.mkdir(path2) + os.mkdir(path3) + + self.assertEqual( + set(get_checkpoint_dirpaths(temp_dir)), + {path1, path2, path3}, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + path1 = os.path.join(temp_dir, "epoch_1_step_20_val_loss=0.01") + path2 = os.path.join(temp_dir, "epoch_4_step_130_val_loss=-0.2") + path3 = os.path.join(temp_dir, "epoch_0_step_10_val_loss=0.12") + os.mkdir(path1) + os.mkdir(path2) + os.mkdir(path3) + + self.assertEqual( + set(get_checkpoint_dirpaths(temp_dir, metric_name="val_loss")), + {path1, path2, path3}, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + self.assertEqual( + get_checkpoint_dirpaths(temp_dir), + [], + ) + + def test_checkpoint_sorting_utils(self) -> None: + """ + Tests the sort utilities + """ + paths = ["epoch_1_step_20", "epoch_4_step_130", "epoch_0_step_10_val_loss=10"] + self.assertEqual(_sort_by_recency(paths), [paths[2], paths[0], paths[1]]) + + paths = [ + "epoch_1_step_20_val_loss=0.09", + "epoch_4_step_130_val_loss=29", + "epoch_0_step_10_val_loss=10", + ] + self.assertEqual( + _sort_by_metric_value(paths, mode="min"), [paths[1], paths[2], paths[0]] + ) + self.assertEqual( + _sort_by_metric_value(paths, mode="max"), [paths[0], paths[2], paths[1]] + ) + + def test_delete_checkpoint(self) -> None: + """ + Tests removing checkpoint directories + """ + app_state = {"module": nn.Linear(2, 2)} + with tempfile.TemporaryDirectory() as temp_dir: + dirpath = os.path.join(temp_dir, "checkpoint") + Snapshot.take(dirpath, app_state=app_state) + self.assertTrue(os.path.exists(dirpath)) + # check that error is thrown if .snapshot_metadata is not found in the directory when deleting + os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) + with self.assertRaisesRegex( + RuntimeError, f"{temp_dir} does not contain .snapshot_metadata" + ): + _delete_checkpoint(temp_dir, SNAPSHOT_METADATA_FNAME) + _delete_checkpoint(dirpath) + self.assertFalse(os.path.exists(dirpath)) + + def test_metadata_exists(self) -> None: + app_state = {"module": nn.Linear(2, 2)} + with tempfile.TemporaryDirectory() as temp_dir: + dirpath = os.path.join(temp_dir, "checkpoint") + Snapshot.take(dirpath, app_state=app_state) + + fs = get_filesystem(dirpath) + self.assertTrue(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) + + os.remove(os.path.join(dirpath, SNAPSHOT_METADATA_FNAME)) + self.assertFalse(_metadata_exists(fs, dirpath, SNAPSHOT_METADATA_FNAME)) diff --git a/torchtnt/framework/callbacks/_checkpoint_utils.py b/torchtnt/framework/callbacks/_checkpoint_utils.py index 087c15c15b..674eb18fe5 100644 --- a/torchtnt/framework/callbacks/_checkpoint_utils.py +++ b/torchtnt/framework/callbacks/_checkpoint_utils.py @@ -6,268 +6,16 @@ # pyre-strict -import logging -import os -import re -from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar - -import fsspec +from typing import Any, Dict from pyre_extensions import none_throws -from torch import distributed as dist from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions from torchtnt.framework.state import State from torchtnt.framework.unit import AppStateMixin -from torchtnt.utils.distributed import rank_zero_read_and_broadcast -from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.stateful import Stateful -logger: logging.Logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -@rank_zero_read_and_broadcast -def get_latest_checkpoint_path( - dirpath: str, - metadata_fname: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> Optional[str]: - """ - Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory. - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - - Raises: - AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. - """ - - return _latest_checkpoint_path(dirpath, metadata_fname) - - -def _latest_checkpoint_path( - dirpath: str, metadata_fname: Optional[str] -) -> Optional[str]: - candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) - - # Initialize variables to store the largest epoch and step numbers - largest_subdirectory = None - largest_epoch = -1 - largest_step = -1 - - # Iterate through all files and directories in the specified directory - for candidate in candidate_dirpaths: - # Extract the epoch and step numbers from the directory name - dirname = os.path.basename(candidate) - - # dirname will be of the format epoch_N_step_M - # where N is the epoch number and M is the step number as integers - split = dirname.split("_") - if len(split) < 4: - raise AssertionError( - f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})" - ) - - epoch_num, step_num = int(split[1]), int(split[3]) - # Check if the current epoch and step numbers are larger than the largest ones found so far - if epoch_num > largest_epoch: - largest_epoch = epoch_num - largest_step = step_num - largest_subdirectory = dirname - elif largest_epoch == epoch_num and step_num > largest_step: - largest_step = step_num - largest_subdirectory = dirname - - if largest_subdirectory is None: - return None - - # Rejoin with the parent directory path and return the largest subdirectory - return os.path.join(dirpath, none_throws(largest_subdirectory)) - - -@rank_zero_read_and_broadcast -def get_best_checkpoint_path( - dirpath: str, - metric_name: str, - mode: Literal["min", "max"], - metadata_fname: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> Optional[str]: - """ - Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory. - - Args: - dirpath: parent directory where checkpoints are saved. - metric_name: Name of the metric to use to find the best checkpoint. - mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - """ - - dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) - if len(dirpaths) == 0: - # no checkpoints found - return None - - best_checkpoint_path = None - best_metric_value = float("inf") if mode == "min" else float("-inf") - for dirpath in dirpaths: - dirname = os.path.basename(dirpath) - metric_value = float(dirname.split("=")[-1]) - - if mode == "min": - if metric_value < best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath - else: - if metric_value > best_metric_value: - best_metric_value = metric_value - best_checkpoint_path = dirpath - - return best_checkpoint_path - - -@rank_zero_read_and_broadcast -def get_checkpoint_dirpaths( - dirpath: str, - metadata_fname: Optional[str] = None, - metric_name: Optional[str] = None, - process_group: Optional[dist.ProcessGroup] = None, -) -> List[str]: - """ - Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. - The order of the checkpoints is not guarenteed. - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - metric_name: fetches all the checkpoint directories containing the metric name only. - process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) - """ - - return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) - - -def _sort_by_recency(dirpaths: List[str]) -> List[str]: - """ - Sorts the given list of directories by oldest to newest. - - Args: - dirpaths: A list of directory paths. - - Returns: - A sorted list of directory paths, sorted by recency. - """ - - def sort_fn(path: str) -> Tuple[int, int]: - x = os.path.basename(path) - return (int(x.split("_")[1]), int(x.split("_")[3])) - - return sorted(dirpaths, key=sort_fn) - - -def _sort_by_metric_value( - dirpaths: List[str], mode: Literal["min", "max"] -) -> List[str]: - """ - Sorts the given list of directories by the metric values. - - Args: - dirpaths: A list of directory paths. - mode: Either 'min' or 'max'. If 'min', sorts in descending order. If 'max', sorts in ascending order - - Returns: - A sorted list of directory paths, sorted by the metric values. - """ - - def sort_metric_fn(path: str) -> float: - x = os.path.basename(path) - metric_val = float(x.split("=")[-1]) - return metric_val - - return sorted( - dirpaths, - key=sort_metric_fn, - # sort descending if min, placing worst metric at top of list - reverse=(mode == "min"), - ) - - -def _retrieve_checkpoint_dirpaths( - dirpath: str, - metadata_fname: Optional[str], - metric_name: Optional[str] = None, -) -> List[str]: - """ - Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories - - Args: - dirpath: parent directory where checkpoints are saved. - metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. - metric_name: Name of the metric that must exist in checkpoint name. - """ - - if dirpath[-1] == "/": - # removes trailing forward slash if present - # required for regex search to work - dirpath = dirpath[:-1] - - fs = get_filesystem(dirpath) - - if not fs.exists(dirpath): - logger.warning(f"Input dirpath doesn't exist: {dirpath}") - return [] - - contents = fs.ls(dirpath, detail=True) - contents = [item["name"] for item in contents if item["type"] == "directory"] - if len(contents) == 0: - logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}") - return [] - - # Define the regex pattern to match the directory names - pattern = rf"^{dirpath}/epoch_\d+_step_\d+" - if metric_name: - # inject metric name in regex search - pattern += rf"_{metric_name}=" - snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern) - candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents)) - - if not metadata_fname: - # return early as we don't need to filter out any paths - return candidate_dirpaths - - # Iterate through all files and directories in the specified directory - # and check if metedata is present or not - valid_ckpt_dirpaths = [] - for candidate in candidate_dirpaths: - if not _metadata_exists(fs, candidate, metadata_fname): - logger.warning( - f"Snapshot metadata is missing from {candidate}! Skipping this path" - ) - continue - - valid_ckpt_dirpaths.append(candidate) - - return valid_ckpt_dirpaths - - -def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None: - fs = get_filesystem(dirpath) - if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname): - raise RuntimeError(f"{dirpath} does not contain {metadata_fname}") - fs.rm(dirpath, recursive=True) - - -def _metadata_exists( - fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str -) -> bool: - return fs.exists(os.path.join(dirpath, metadata_fname)) - # keys for use when checkpointing _TRAIN_PROGRESS_STATE_KEY = "train_progress" diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index f5f56fdf25..ded257c412 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -16,7 +16,14 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks._checkpoint_utils import ( +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + RestoreOptions, +) +from torchtnt.framework.state import EntryPoint, State +from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit +from torchtnt.framework.utils import get_timing_context +from torchtnt.utils.checkpoint import ( _delete_checkpoint, _metadata_exists, _sort_by_metric_value, @@ -25,13 +32,6 @@ get_checkpoint_dirpaths, get_latest_checkpoint_path, ) -from torchtnt.framework.callbacks.checkpointer_types import ( - BestCheckpointConfig, - RestoreOptions, -) -from torchtnt.framework.state import EntryPoint, State -from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit -from torchtnt.framework.utils import get_timing_context from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast from torchtnt.utils.fsspec import get_filesystem from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index cb973c13a6..06cb6d33da 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -6,6 +6,13 @@ # pyre-strict +from .checkpoint import ( + CheckpointPath, + get_best_checkpoint_path, + get_checkpoint_dirpaths, + get_latest_checkpoint_path, + MetricData, +) from .device import ( copy_data_to_device, CPUStats, @@ -148,4 +155,9 @@ "is_windows", "get_pet_launch_config", "spawn_multi_process", + "CheckpointPath", + "MetricData", + "get_best_checkpoint_path", + "get_checkpoint_dirpaths", + "get_latest_checkpoint_path", ] diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py new file mode 100644 index 0000000000..d5ddc4b2f6 --- /dev/null +++ b/torchtnt/utils/checkpoint.py @@ -0,0 +1,425 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import logging +import os +import re +from dataclasses import dataclass +from functools import total_ordering +from typing import List, Literal, Optional, Pattern, Tuple + +import fsspec +import torch.distributed as dist +from fsspec.core import url_to_fs +from pyre_extensions import none_throws +from torchtnt.utils.distributed import rank_zero_read_and_broadcast + +logger: logging.Logger = logging.getLogger(__name__) + + +@dataclass +class MetricData: + """ + Representation of a metric instance. Should provide both a metric name and it's value. + """ + + name: str + value: float + + +@total_ordering +class CheckpointPath: + """ + Representation of a checkpoint path. Handles parsing and serialization of the specific path format. + Currently, the basic compliant path format is: /epoch__step_ + If a metric is being tracked, it's added to the name: /epoch__step__= + + This class is well-ordered by checkpoint recency, so any comparisons will operate using the epoch + step. Sorting by + metric can be done by extracting the metric value from the metric_data attribute. + """ + + PATH_REGEX: Pattern = re.compile( + r"^(.+)epoch_(\d+)_step_(\d+)(?:_(.+)=(-?\d+\.?\d*))?\/?$" + ) + + def __init__( + self, + dirpath: str, + epoch: int, + step: int, + metric_data: Optional[MetricData] = None, + ) -> None: + """ + Args: + dirpath: The base directory path that checkpoints are saved in. + epoch: The epoch number of this checkpoint. + step: The step number of this checkpoint. + metric_data: Optional data about the metric being tracked. Should contain both metric name and value. + """ + self.dirpath: str = dirpath.rstrip("/") + self.epoch = epoch + self.step = step + self.metric_data = metric_data + + @classmethod + def from_str(cls, checkpoint_path: str) -> "CheckpointPath": + """ + Given a directory path, try to parse it and extract the checkpoint data. + The expected format is: /epoch__step__=, + where the metric name and value are optional. + + Args: + checkpoint_path: The path to the checkpoint directory. + + Returns: + A CheckpointPath instance if the path is valid, otherwise None. + + Raises: + ValueError: If the path is malformed and can't be parsed. + """ + path_match = cls.PATH_REGEX.match(checkpoint_path) + if not path_match: + raise ValueError( + f"Attempted to parse malformed checkpoint path: {checkpoint_path}." + ) + + dirpath, epoch, step, metric_name, metric_value = path_match.groups() + try: + metric_data: Optional[MetricData] = None + if metric_name: + metric_value_f = float(metric_value) + metric_data = MetricData(name=metric_name, value=metric_value_f) + + return CheckpointPath( + dirpath=dirpath, + epoch=int(epoch), + step=int(step), + metric_data=metric_data, + ) + + except ValueError: + # Should never happen since path matches regex + raise ValueError( + f"Invalid data types found in checkpoint path: {checkpoint_path}." + ) + + @property + def path(self) -> str: + """ + Returns: + The full path to the checkpoint directory. + """ + name = f"epoch_{self.epoch}_step_{self.step}" + if self.metric_data: + name += f"_{self.metric_data.name}={self.metric_data.value}" + + return os.path.join(self.dirpath, name) + + def newer_than(self, other: "CheckpointPath") -> bool: + """ + Given another CheckpointPath instance, determine if this checkpoint is strictly newer than the other. + + Returns: + True if this checkpoint is newer than the other, otherwise False. + """ + if self.epoch != other.epoch: + return self.epoch > other.epoch + + return self.step > other.step + + def more_optimal_than( + self, other: "CheckpointPath", mode: Literal["min", "max"] + ) -> bool: + """ + Given another CheckpointPath instance, determine if this checkpoint is strictly more optimal than the other. + Optimality is determined by comparing the metric value of the two checkpoints. The mode indicates if the + metric value should be minimized or maximized. This only works for metric-aware checkpoints. + + Args: + other: The other checkpoint path to compare against. + mode: The mode to use for comparison. + + Returns: + True if this checkpoint is more optimal than the other, otherwise False. + + Note: This expects that both checkpoints are metric-aware, and that they are tracking the same metric. + """ + + assert ( + self.metric_data and other.metric_data + ), f"Attempted to compare optimality of non metric-aware checkpoints: {self} and {other}" + + assert ( + self.metric_data.name == other.metric_data.name + ), f"Attempted to compare optimality of checkpoints tracking different metrics: {self} and {other}" + + if mode == "min": + return ( + none_throws(self.metric_data).value + < none_throws(other.metric_data).value + ) + + return ( + none_throws(self.metric_data).value > none_throws(other.metric_data).value + ) + + def __str__(self) -> str: + return self.path + + def __repr__(self) -> str: + return f"CheckpointPath(dirpath={self.dirpath}, epoch={self.epoch}, step={self.step}, metric_data={self.metric_data})" + + def __eq__(self, other: "CheckpointPath") -> bool: + return ( + self.dirpath == other.dirpath + and self.epoch == other.epoch + and self.step == other.step + and self.metric_data == other.metric_data + ) + + def __gt__(self, other: "CheckpointPath") -> bool: + return self.newer_than(other) + + +@rank_zero_read_and_broadcast +def get_latest_checkpoint_path( + dirpath: str, + metadata_fname: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Optional[str]: + """ + Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory. + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + + Raises: + AssertionError if the checkpoint subdirectories are not named in the format epoch_{epoch}_step_{step}. + """ + + return _latest_checkpoint_path(dirpath, metadata_fname) + + +def _latest_checkpoint_path( + dirpath: str, metadata_fname: Optional[str] +) -> Optional[str]: + candidate_dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname) + + # Initialize variables to store the largest epoch and step numbers + largest_subdirectory = None + largest_epoch = -1 + largest_step = -1 + + # Iterate through all files and directories in the specified directory + for candidate in candidate_dirpaths: + # Extract the epoch and step numbers from the directory name + dirname = os.path.basename(candidate) + + # dirname will be of the format epoch_N_step_M + # where N is the epoch number and M is the step number as integers + split = dirname.split("_") + if len(split) < 4: + raise AssertionError( + f"Expected 4 or more elements for pattern of epoch_N_step_M, but received {split})" + ) + + epoch_num, step_num = int(split[1]), int(split[3]) + # Check if the current epoch and step numbers are larger than the largest ones found so far + if epoch_num > largest_epoch: + largest_epoch = epoch_num + largest_step = step_num + largest_subdirectory = dirname + elif largest_epoch == epoch_num and step_num > largest_step: + largest_step = step_num + largest_subdirectory = dirname + + if largest_subdirectory is None: + return None + + # Rejoin with the parent directory path and return the largest subdirectory + return os.path.join(dirpath, none_throws(largest_subdirectory)) + + +@rank_zero_read_and_broadcast +def get_best_checkpoint_path( + dirpath: str, + metric_name: str, + mode: Literal["min", "max"], + metadata_fname: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> Optional[str]: + """ + Given a parent directory where checkpoints are saved, return the best checkpoint subdirectory. + + Args: + dirpath: parent directory where checkpoints are saved. + metric_name: Name of the metric to use to find the best checkpoint. + mode: Either 'min' or 'max'. If 'min', finds and loads the lowest value metric checkpoint. If 'max', finds and loads the largest. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + """ + + dirpaths = _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + if len(dirpaths) == 0: + # no checkpoints found + return None + + best_checkpoint_path = None + best_metric_value = float("inf") if mode == "min" else float("-inf") + for dirpath in dirpaths: + dirname = os.path.basename(dirpath) + metric_value = float(dirname.split("=")[-1]) + + if mode == "min": + if metric_value < best_metric_value: + best_metric_value = metric_value + best_checkpoint_path = dirpath + else: + if metric_value > best_metric_value: + best_metric_value = metric_value + best_checkpoint_path = dirpath + + return best_checkpoint_path + + +@rank_zero_read_and_broadcast +def get_checkpoint_dirpaths( + dirpath: str, + metadata_fname: Optional[str] = None, + metric_name: Optional[str] = None, + process_group: Optional[dist.ProcessGroup] = None, +) -> List[str]: + """ + Given a parent directory where checkpoints are saved, returns the checkpoint subdirectories. + The order of the checkpoints is not guarenteed. + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + metric_name: fetches all the checkpoint directories containing the metric name only. + process_group: the process group on which the ranks will communicate on. default: ``None`` (the entire world) + """ + + return _retrieve_checkpoint_dirpaths(dirpath, metadata_fname, metric_name) + + +def _sort_by_recency(dirpaths: List[str]) -> List[str]: + """ + Sorts the given list of directories by oldest to newest. + + Args: + dirpaths: A list of directory paths. + + Returns: + A sorted list of directory paths, sorted by recency. + """ + + def sort_fn(path: str) -> Tuple[int, int]: + x = os.path.basename(path) + return (int(x.split("_")[1]), int(x.split("_")[3])) + + return sorted(dirpaths, key=sort_fn) + + +def _sort_by_metric_value( + dirpaths: List[str], mode: Literal["min", "max"] +) -> List[str]: + """ + Sorts the given list of directories by the metric values. + + Args: + dirpaths: A list of directory paths. + mode: Either 'min' or 'max'. If 'min', sorts in descending order. If 'max', sorts in ascending order + + Returns: + A sorted list of directory paths, sorted by the metric values. + """ + + def sort_metric_fn(path: str) -> float: + x = os.path.basename(path) + metric_val = float(x.split("=")[-1]) + return metric_val + + return sorted( + dirpaths, + key=sort_metric_fn, + # sort descending if min, placing worst metric at top of list + reverse=(mode == "min"), + ) + + +def _retrieve_checkpoint_dirpaths( + dirpath: str, + metadata_fname: Optional[str], + metric_name: Optional[str] = None, +) -> List[str]: + """ + Given a parent directory where checkpoints are saved, return the unsorted checkpoint subdirectories + + Args: + dirpath: parent directory where checkpoints are saved. + metadata_fname: Checks if metadata file is present in checkpoint, disregards if it does not exist. + metric_name: Name of the metric that must exist in checkpoint name. + """ + + if dirpath[-1] == "/": + # removes trailing forward slash if present + # required for regex search to work + dirpath = dirpath[:-1] + + fs, _ = url_to_fs(dirpath) + + if not fs.exists(dirpath): + logger.warning(f"Input dirpath doesn't exist: {dirpath}") + return [] + + contents = fs.ls(dirpath, detail=True) + contents = [item["name"] for item in contents if item["type"] == "directory"] + if len(contents) == 0: + logger.warning(f"Input dirpath doesn't contain any subdirectories: {dirpath}") + return [] + + # Define the regex pattern to match the directory names + pattern = rf"^{dirpath}/epoch_\d+_step_\d+" + if metric_name: + # inject metric name in regex search + pattern += rf"_{metric_name}=" + snapshot_dirpath_pattern: Pattern[str] = re.compile(pattern) + candidate_dirpaths = list(filter(snapshot_dirpath_pattern.match, contents)) + + if not metadata_fname: + # return early as we don't need to filter out any paths + return candidate_dirpaths + + # Iterate through all files and directories in the specified directory + # and check if metedata is present or not + valid_ckpt_dirpaths = [] + for candidate in candidate_dirpaths: + if not _metadata_exists(fs, candidate, metadata_fname): + logger.warning( + f"Snapshot metadata is missing from {candidate}! Skipping this path" + ) + continue + + valid_ckpt_dirpaths.append(candidate) + + return valid_ckpt_dirpaths + + +def _delete_checkpoint(dirpath: str, metadata_fname: Optional[str] = None) -> None: + fs, _ = url_to_fs(dirpath) + if metadata_fname and not _metadata_exists(fs, dirpath, metadata_fname): + raise RuntimeError(f"{dirpath} does not contain {metadata_fname}") + fs.rm(dirpath, recursive=True) + + +def _metadata_exists( + fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str +) -> bool: + return fs.exists(os.path.join(dirpath, metadata_fname))