diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py
index c17b874c21..29468e66c3 100644
--- a/composer/callbacks/checkpoint_saver.py
+++ b/composer/callbacks/checkpoint_saver.py
@@ -20,6 +20,8 @@
     FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
     FORMAT_NAME_WITH_DIST_TABLE,
     PartialFilePath,
+    RemoteFilesExistingCheckStatus,
+    RemoteUploader,
     checkpoint,
     create_interval_scheduler,
     create_symlink_file,
@@ -28,6 +30,7 @@
     format_name_with_dist,
     format_name_with_dist_and_time,
     is_model_deepspeed,
+    parse_uri,
     partial_format,
 )
 from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
@@ -287,8 +290,13 @@ def __init__(
         num_checkpoints_to_keep: int = -1,
         weights_only: bool = False,
         ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
+        num_concurrent_uploads: int = 1,
+        upload_timeout_in_seconds: int = 3600,
     ):
-        folder = str(folder)
+        backend, _, local_folder = parse_uri(str(folder))
+        if local_folder == '':
+            local_folder = '.'
+
         filename = str(filename)
         remote_file_name = str(remote_file_name) if remote_file_name is not None else None
         latest_filename = str(latest_filename) if latest_filename is not None else None
@@ -304,10 +312,10 @@ def __init__(
         self.save_interval = save_interval
         self.last_checkpoint_batch: Optional[Time] = None

-        self.folder = folder
+        self.folder = local_folder

-        self.filename = PartialFilePath(filename.lstrip('/'), folder)
-        self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), folder) if latest_filename else None
+        self.filename = PartialFilePath(filename.lstrip('/'), local_folder)
+        self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), local_folder) if latest_filename else None
         self.remote_file_name = PartialFilePath(remote_file_name) if remote_file_name else None
         self.latest_remote_file_name = PartialFilePath(latest_remote_file_name) if latest_remote_file_name else None

@@ -320,6 +328,23 @@ def __init__(

         self.start_batch = None

+        self.remote_uploader = None
+        self.rank_saves_symlinks: bool = False
+        self.tmp_dir_for_symlink = tempfile.TemporaryDirectory()
+        self.num_concurrent_uploads = num_concurrent_uploads
+        self.upload_timeout_in_seconds = upload_timeout_in_seconds
+        # Allow unit tests to override this to speed them up
+        self._symlink_upload_wait_before_next_try_in_seconds = 30.0
+        self.pid = os.getpid()
+        self.symlink_count = 0
+        self.symlink_upload_tasks = []
+
+        if backend != '':
+            self.remote_uploader = RemoteUploader(
+                remote_folder=str(folder),
+                num_concurrent_uploads=self.num_concurrent_uploads,
+            )
+
     def init(self, state: State, logger: Logger) -> None:
         # If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
         # Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
@@ -346,9 +371,10 @@ def init(self, state: State, logger: Logger) -> None:
                     self.latest_remote_file_name.filename,
                     **mlflow_format_kwargs,
                 )
-
                 break

+        if self.remote_uploader is not None:
+            self.remote_uploader.init()
         folder = format_name_with_dist(self.folder, state.run_name)
         os.makedirs(folder, exist_ok=True)
@@ -410,6 +436,27 @@ def load_state_dict(self, state: dict[str, Any]):
                 load_timestamp.load_state_dict(timestamp_state)
                 self.all_saved_checkpoints_to_timestamp[save_filename] = load_timestamp

+    def _upload_checkpoint(
+        self,
+        remote_file_name: str,
+        local_file_name: str,
+        local_remote_file_names: list[str],
+        logger: Logger,
+    ):
+        if self.remote_uploader is not None:
+            self.remote_uploader.upload_file_async(
+                remote_file_name=remote_file_name,
+                file_path=pathlib.Path(local_file_name),
+                overwrite=self.overwrite,
+            )
+            local_remote_file_names.append(remote_file_name)
+        else:
+            logger.upload_file(
+                remote_file_name=remote_file_name,
+                file_path=local_file_name,
+                overwrite=self.overwrite,
+            )
+
     def _save_checkpoint(self, state: State, logger: Logger):
         self.last_checkpoint_batch = state.timestamp.batch

@@ -432,7 +479,14 @@ def _save_checkpoint(self, state: State, logger: Logger):
         )
         log.debug(f'Checkpoint locally saved to {saved_path}')

+        self.symlink_count += 1
+        # Remote checkpoint file names on this rank
+        local_remote_file_names = []
+        all_remote_filenames = []
+
         if not saved_path:  # not all ranks save
+            if self.remote_file_name is not None and self.remote_uploader is not None:
+                all_remote_filenames = dist.all_gather_object(local_remote_file_names)
             return

         metadata_local_file_path = None
@@ -443,6 +497,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 state.timestamp,
             )

+        self.rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
         if self.latest_filename is not None and self.num_checkpoints_to_keep != 0:
             symlink = self.latest_filename.format(state, is_deepspeed)
             os.makedirs(os.path.dirname(symlink), exist_ok=True)
@@ -455,8 +510,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
                     src_path = str(pathlib.Path(saved_path).parent)
                 else:
                     src_path = saved_path
-                this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
-                if this_rank_saves_symlinks:
+                if self.rank_saves_symlinks:
                     os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink)

         # if remote file name provided, upload the checkpoint
@@ -482,10 +536,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
                         state.timestamp,
                     )
                     assert metadata_local_file_path is not None
-                    logger.upload_file(
+                    self._upload_checkpoint(
                         remote_file_name=metadata_remote_file_name,
-                        file_path=metadata_local_file_path,
-                        overwrite=self.overwrite,
+                        local_file_name=metadata_local_file_path,
+                        local_remote_file_names=local_remote_file_names,
+                        logger=logger,
                     )
             else:
                 remote_file_name = self.remote_file_name.format(
@@ -495,12 +550,20 @@ def _save_checkpoint(self, state: State, logger: Logger):

                 log.debug(f'Uploading checkpoint to {remote_file_name}')
                 try:
-                    logger.upload_file(remote_file_name=remote_file_name, file_path=saved_path, overwrite=self.overwrite)
+                    self._upload_checkpoint(
+                        remote_file_name=remote_file_name,
+                        local_file_name=saved_path,
+                        local_remote_file_names=local_remote_file_names,
+                        logger=logger,
+                    )
                 except FileExistsError as e:
                     raise FileExistsError(
                         f'Uploading checkpoint failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite checkpoints with Trainer, set save_overwrite to True.',
                     ) from e

+        if self.remote_uploader is not None:
+            all_remote_filenames = dist.all_gather_object(local_remote_file_names)
+
         # symlinks stay the same with sharded checkpointing
         if self.latest_remote_file_name is not None:
             symlink_name = self.latest_remote_file_name.format(
@@ -509,17 +572,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 state,
                 is_deepspeed,
             ).lstrip('/') + '.symlink'

             # create and upload a symlink file
-            with tempfile.TemporaryDirectory() as tmpdir:
-                symlink_filename = os.path.join(tmpdir, 'latest.symlink')
-                # Sharded checkpoints for torch >2.0 use directories not files for load_paths
-                if state.fsdp_sharded_state_dict_enabled:
-                    src_path = str(pathlib.Path(remote_file_name).parent)
+            symlink_filename = os.path.join(
+                self.tmp_dir_for_symlink.name,
+                f'latest.{self.symlink_count}.symlink',
+            )
+            # Sharded checkpoints for torch >2.0 use directories not files for load_paths
+            if state.fsdp_sharded_state_dict_enabled:
+                src_path = str(pathlib.Path(remote_file_name).parent)
+            else:
+                src_path = remote_file_name
+            log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
+            if self.rank_saves_symlinks:
+                create_symlink_file(src_path, symlink_filename)
+            if self.remote_uploader is not None:
+                remote_checkpoint_file_names = []
+                for file_names in all_remote_filenames:
+                    remote_checkpoint_file_names += file_names
+                check_remote_files_exist_future = self.remote_uploader.check_remote_files_exist_async(
+                    remote_checkpoint_file_names=remote_checkpoint_file_names,
+                    max_wait_time_in_seconds=self.upload_timeout_in_seconds,
+                    wait_before_next_try_in_seconds=self._symlink_upload_wait_before_next_try_in_seconds,
+                )
+                self.symlink_upload_tasks.append(
+                    (check_remote_files_exist_future, symlink_filename, symlink_name),
+                )
             else:
-                    src_path = remote_file_name
-                log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
-                this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
-                if this_rank_saves_symlinks:
-                    create_symlink_file(src_path, symlink_filename)
                 logger.upload_file(
                     remote_file_name=symlink_name,
                     file_path=symlink_filename,
@@ -532,7 +609,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
             self._rotate_checkpoints(sharding_enabled=state.fsdp_sharded_state_dict_enabled)

     def _rotate_checkpoints(self, sharding_enabled: bool = False):
-
         while len(self.saved_checkpoints) > self.num_checkpoints_to_keep:
             prefix_dir = None
             checkpoint_to_delete = self.saved_checkpoints.pop(0)
@@ -542,3 +618,62 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
             else:
                 if dist.get_global_rank() == 0:
                     shutil.rmtree(prefix_dir)
+
+    def batch_end(self, state: State, logger: Logger) -> None:
+        del state, logger  # unused
+        if self.remote_uploader is None:
+            return
+        self.remote_uploader.check_workers()
+        if not self.rank_saves_symlinks:
+            return
+        undone_symlink_upload_tasks = []
+        for (check_remote_files_exist_future, local_symlink_file,
+             remote_symlink_file) in reversed(self.symlink_upload_tasks):
+            if not check_remote_files_exist_future.done():
+                undone_symlink_upload_tasks.insert(
+                    0,
+                    (check_remote_files_exist_future, local_symlink_file, remote_symlink_file),
+                )
+                continue
+            if check_remote_files_exist_future.done():
+                result = check_remote_files_exist_future.result()
+                if result == RemoteFilesExistingCheckStatus.EXIST:
+                    self.remote_uploader.upload_file_async(
+                        remote_file_name=remote_symlink_file,
+                        file_path=local_symlink_file,
+                        overwrite=True,
+                    )
+                    break
+                else:
+                    raise RuntimeError(f'Failed to confirm that checkpoint files finished uploading: {result}')
+        self.symlink_upload_tasks = undone_symlink_upload_tasks
+
+    def fit_end(self, state: State, logger: Logger) -> None:
+        del state, logger  # unused
+        if self.remote_uploader is None:
+            return
+        log.info('Waiting for checkpoint uploading to finish')
+        self.remote_uploader.wait()
+        if self.rank_saves_symlinks and len(self.symlink_upload_tasks) > 0:
+            log.debug('Uploading symlink to the latest checkpoint')
+            # We only need a symlink pointing to the latest checkpoint files, so older symlink tasks can be skipped.
+            check_remote_files_exist_future, local_symlink_file, remote_symlink_file = self.symlink_upload_tasks[-1]
+            result = check_remote_files_exist_future.result()
+            if result == RemoteFilesExistingCheckStatus.EXIST:
+                symlink_upload_future = self.remote_uploader.upload_file_async(
+                    remote_file_name=remote_symlink_file,
+                    file_path=local_symlink_file,
+                    overwrite=True,
+                )
+                symlink_upload_future.result()
+            else:
+                raise RuntimeError(f'Failed to confirm that checkpoint files finished uploading: {result}')
+        log.info('Checkpoint uploading finished!')
+
+    def post_close(self):
+        if self.remote_uploader is not None:
+            # Wait for the symlink file upload to finish and close the remote uploader
+            try:
+                self.remote_uploader.wait_and_close()
+            except Exception as e:
+                log.error(f'RemoteUploader ran into an exception: {e}')
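Reviewer note: a minimal sketch of the new save flow in this callback, assuming a hypothetical bucket name; only the commented internals are taken from the diff above.

    from composer.callbacks import CheckpointSaver

    saver = CheckpointSaver(
        folder='s3://my-bucket/checkpoints',  # hypothetical URI; a plain local path still works
        save_interval='1ep',
        num_concurrent_uploads=1,
        upload_timeout_in_seconds=3600,
    )
    # Per the diff: parse_uri(folder) -> ('s3', 'my-bucket', 'checkpoints'); the path part
    # ('.' when the URI has no path) becomes the local folder, and a non-empty backend
    # triggers construction of a RemoteUploader for the full URI. Checkpoints then go
    # through _upload_checkpoint() -> RemoteUploader.upload_file_async(), and the latest
    # symlink is only uploaded after check_remote_files_exist_async() confirms that every
    # rank's files exist remotely (polled in batch_end, drained in fit_end).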
diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py
index 981cc4c650..9378d5a8d4 100644
--- a/composer/loggers/remote_uploader_downloader.py
+++ b/composer/loggers/remote_uploader_downloader.py
@@ -25,19 +25,15 @@
 from composer.loggers import Logger, MosaicMLLogger
 from composer.loggers.logger_destination import LoggerDestination
 from composer.utils import (
-    GCSObjectStore,
-    LibcloudObjectStore,
     MLFlowObjectStore,
     ObjectStore,
     ObjectStoreTransientError,
-    OCIObjectStore,
-    S3ObjectStore,
-    SFTPObjectStore,
-    UCObjectStore,
+    build_remote_backend,
     dist,
     format_name_with_dist,
     get_file,
     retry,
+    validate_credentials,
 )
 from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

@@ -50,37 +46,6 @@
 __all__ = ['RemoteUploaderDownloader']


-def _build_remote_backend(remote_backend_name: str, backend_kwargs: dict[str, Any]):
-    remote_backend_cls = None
-    remote_backend_name_to_cls = {
-        's3': S3ObjectStore,
-        'oci': OCIObjectStore,
-        'sftp': SFTPObjectStore,
-        'libcloud': LibcloudObjectStore,
-        'gs': GCSObjectStore,
-    }
-
-    # Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore`
-    # or :class:`.MLFlowObjectStore`.
-    if remote_backend_name == 'dbfs':
-        path = backend_kwargs['path']
-        if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
-            remote_backend_cls = MLFlowObjectStore
-        else:
-            # Validate if the path conforms to the requirements for UC volume paths
-            UCObjectStore.validate_path(path)
-            remote_backend_cls = UCObjectStore
-    else:
-        remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
-        if remote_backend_cls is None:
-            supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs']
-            raise ValueError(
-                f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})',
-            )
-
-    return remote_backend_cls(**backend_kwargs)
-
-
 class RemoteUploaderDownloader(LoggerDestination):
     r"""Logger destination that uploads (downloads) files to (from) a remote backend.
@@ -339,7 +304,7 @@ def __init__(
     def remote_backend(self) -> ObjectStore:
         """The :class:`.ObjectStore` instance for the main thread."""
         if self._remote_backend is None:
-            self._remote_backend = _build_remote_backend(self.remote_backend_name, self.backend_kwargs)
+            self._remote_backend = build_remote_backend(self.remote_backend_name, self.backend_kwargs)
         return self._remote_backend

     def init(self, state: State, logger: Logger) -> None:
@@ -359,7 +324,7 @@ def init(self, state: State, logger: Logger) -> None:
                 retry(
                     ObjectStoreTransientError,
                     self.num_attempts,
-                )(lambda: _validate_credentials(self.remote_backend, file_name_to_test))()
+                )(lambda: validate_credentials(self.remote_backend, file_name_to_test))()

         # If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be
         # updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information
@@ -635,20 +600,6 @@ def _remote_file_name(self, remote_file_name: str):
         return key_name


-def _validate_credentials(
-    remote_backend: ObjectStore,
-    remote_file_name_to_test: str,
-) -> None:
-    # Validates the credentials by attempting to touch a file in the bucket
-    # raises an error if there was a credentials failure.
-    with tempfile.NamedTemporaryFile('wb') as f:
-        f.write(b'credentials_validated_successfully')
-        remote_backend.upload_object(
-            object_name=remote_file_name_to_test,
-            filename=f.name,
-        )
-
-
 def _upload_worker(
     file_queue: Union[queue.Queue[tuple[str, str, bool]], multiprocessing.JoinableQueue[tuple[str, str, bool]]],
     completed_queue: Union[queue.Queue[str], multiprocessing.JoinableQueue[str]],
@@ -663,7 +614,7 @@ def _upload_worker(
     The worker will continuously poll ``file_queue`` for files to upload. Once ``is_finished`` is set, the worker will
     exit once ``file_queue`` is empty.
     """
-    remote_backend = _build_remote_backend(remote_backend_name, backend_kwargs)
+    remote_backend = build_remote_backend(remote_backend_name, backend_kwargs)
     while True:
         try:
             file_path_to_upload, remote_file_name, overwrite = file_queue.get(block=True, timeout=0.5)
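Reviewer note: the logger's module-private helpers are removed rather than changed; both the main thread and the upload workers now build their backends from the shared helper. A before/after import sketch, with nothing beyond what the diff shows:

    # Before: module-private helpers in composer/loggers/remote_uploader_downloader.py
    #     _build_remote_backend(...), _validate_credentials(...)
    # After: shared, public equivalents re-exported from composer.utils
    from composer.utils import build_remote_backend, validate_credentials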
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index f5a6b57d77..c752187ba6 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1387,16 +1387,6 @@ def __init__(
             mosaicml_logger = MosaicMLLogger()
             loggers.append(mosaicml_logger)

-        # Remote Uploader Downloader
-        # Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before
-        # the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a
-        # run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize
-        # the active MLFlow run.
-        if save_folder is not None:
-            remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
-            if remote_ud is not None:
-                loggers.append(remote_ud)
-
         # Logger
         self.logger = Logger(state=self.state, destinations=loggers)

@@ -1451,14 +1441,12 @@ def __init__(
             # path then we assume they just want their checkpoints saved directly in their
             # bucket.
             if parsed_save_folder == '':
-                folder = '.'
                 remote_file_name = save_filename
                 latest_remote_file_name = save_latest_filename

             # If they actually specify a path, then we use that for their local save path
             # and we prefix save_filename with that path for remote_file_name.
             else:
-                folder = parsed_save_folder
                 remote_file_name = str(Path(parsed_save_folder) / Path(save_filename))
                 if save_latest_filename is not None:
                     latest_remote_file_name = str(Path(parsed_save_folder) / Path(save_latest_filename))
@@ -1466,7 +1454,7 @@ def __init__(
                     latest_remote_file_name = None

             self._checkpoint_saver = CheckpointSaver(
-                folder=folder,
+                folder=save_folder,
                 filename=save_filename,
                 remote_file_name=remote_file_name,
                 latest_filename=save_latest_filename,
@@ -1889,14 +1877,17 @@ def _try_checkpoint_download(
         self,
         latest_checkpoint_path: str,
         save_latest_remote_file_name: str,
-        loggers: Sequence[LoggerDestination],
+        loggers: Sequence[Union[LoggerDestination, ObjectStore]],
         load_progress_bar: bool,
     ) -> None:
         """Attempts to download the checkpoint from the logger destinations."""
         log.debug(
             f'Trying to download {save_latest_remote_file_name} to {latest_checkpoint_path} on rank {dist.get_global_rank()}',
         )
-        for logger in loggers:
+        remote_destination = list(loggers)
+        if self._checkpoint_saver is not None and self._checkpoint_saver.remote_uploader is not None:
+            remote_destination.append(self._checkpoint_saver.remote_uploader.remote_backend)
+        for logger in remote_destination:
             try:
                 # Fetch from logger. If it succeeds, stop trying the rest of the loggers
                 get_file(
@@ -1938,7 +1929,7 @@ def _get_autoresume_checkpoint(
             f'Looking for autoresume checkpoint: {save_latest_remote_file_name} (remote), {latest_checkpoint_path} (local)',
         )

-        if self.state.deepspeed_enabled or self.state.fsdp_sharded_state_dict_enabled:
+        if self.state.deepspeed_enabled:
             # If latest checkpoint is not saved locally, try to fetch from loggers
             if not os.path.exists(latest_checkpoint_path):
                 log.debug(f'Attempting to download the checkpoint on to rank {dist.get_global_rank()}')
diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py
index f04da5c0e8..0850fd2bdd 100644
--- a/composer/utils/__init__.py
+++ b/composer/utils/__init__.py
@@ -44,6 +44,7 @@
     maybe_create_object_store_from_uri,
     maybe_create_remote_uploader_downloader_from_uri,
     parse_uri,
+    validate_credentials,
 )
 from composer.utils.import_helpers import MissingConditionalImportError, import_object
 from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
@@ -72,8 +73,10 @@
     S3ObjectStore,
     SFTPObjectStore,
     UCObjectStore,
+    build_remote_backend,
 )
 from composer.utils.parallelism import FSDPConfig, ParallelismConfig, TPConfig, create_fsdp_config
+from composer.utils.remote_uploader import RemoteFilesExistingCheckStatus, RemoteUploader
 from composer.utils.retrying import retry
 from composer.utils.string_enum import StringEnum
 from composer.utils.warnings import VersionedDeprecationWarning
@@ -155,4 +158,8 @@
     'ParallelismConfig',
     'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
     'MLFLOW_RUN_ID_FORMAT_KEY',
+    'RemoteUploader',
+    'validate_credentials',
+    'build_remote_backend',
+    'RemoteFilesExistingCheckStatus',
 ]
diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py
index 2d14cc27ea..11d10328ea 100644
--- a/composer/utils/file_helpers.py
+++ b/composer/utils/file_helpers.py
@@ -49,6 +49,7 @@
     'maybe_create_object_store_from_uri',
     'maybe_create_remote_uploader_downloader_from_uri',
     'parse_uri',
+    'validate_credentials',
 ]

@@ -737,3 +738,18 @@ def create_symlink_file(
         raise ValueError('The symlink filename must end with .symlink.')
     with open(destination_filename, 'x') as f:
         f.write(existing_path)
+
+
+def validate_credentials(
+    remote_backend: ObjectStore,
+    remote_file_name_to_test: str,
+):
+    """Upload a tiny text file to check that the credentials are set up correctly."""
+    # Validates the credentials by attempting to touch a file in the bucket;
+    # raises an error if there was a credentials failure.
+    with tempfile.NamedTemporaryFile('wb') as f:
+        f.write(b'credentials_validated_successfully')
+        remote_backend.upload_object(
+            object_name=remote_file_name_to_test,
+            filename=f.name,
+        )
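Reviewer note: a usage sketch for the newly public validate_credentials, mirroring the retry wrapping at its two call sites in this diff; the bucket name is a placeholder.

    from composer.utils import ObjectStoreTransientError, S3ObjectStore, retry, validate_credentials

    backend = S3ObjectStore(bucket='my-bucket')  # placeholder bucket

    @retry(ObjectStoreTransientError, num_attempts=3)
    def _validate_with_retry():
        # Touches a tiny marker file in the bucket; raises if credentials are bad.
        validate_credentials(backend, '.credentials_validated_successfully')

    _validate_with_retry()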
diff --git a/composer/utils/object_store/__init__.py b/composer/utils/object_store/__init__.py
index 3c70257e08..6171013c2c 100644
--- a/composer/utils/object_store/__init__.py
+++ b/composer/utils/object_store/__init__.py
@@ -15,6 +15,7 @@
 from composer.utils.object_store.s3_object_store import S3ObjectStore
 from composer.utils.object_store.sftp_object_store import SFTPObjectStore
 from composer.utils.object_store.uc_object_store import UCObjectStore
+from composer.utils.object_store.utils import build_remote_backend

 __all__ = [
     'ObjectStore',
@@ -28,4 +29,5 @@
     'UCObjectStore',
     'MLFLOW_EXPERIMENT_ID_FORMAT_KEY',
     'MLFLOW_RUN_ID_FORMAT_KEY',
+    'build_remote_backend',
 ]
diff --git a/composer/utils/object_store/utils.py b/composer/utils/object_store/utils.py
new file mode 100644
index 0000000000..0d33774bc7
--- /dev/null
+++ b/composer/utils/object_store/utils.py
@@ -0,0 +1,48 @@
+# Copyright 2024 MosaicML Composer authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helpers for working with object stores."""
+
+from typing import Any
+
+from composer.utils.object_store.gcs_object_store import GCSObjectStore
+from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore
+from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX, MLFlowObjectStore
+from composer.utils.object_store.oci_object_store import OCIObjectStore
+from composer.utils.object_store.s3_object_store import S3ObjectStore
+from composer.utils.object_store.sftp_object_store import SFTPObjectStore
+from composer.utils.object_store.uc_object_store import UCObjectStore
+
+__all__ = ['build_remote_backend']
+
+
+def build_remote_backend(remote_backend_name: str, backend_kwargs: dict[str, Any]):
+    """Build an object store given the backend name and kwargs."""
+    remote_backend_cls = None
+    remote_backend_name_to_cls = {
+        's3': S3ObjectStore,
+        'oci': OCIObjectStore,
+        'sftp': SFTPObjectStore,
+        'libcloud': LibcloudObjectStore,
+        'gs': GCSObjectStore,
+    }
+
+    # Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore`
+    # or :class:`.MLFlowObjectStore`.
+    if remote_backend_name == 'dbfs':
+        path = backend_kwargs['path']
+        if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            remote_backend_cls = MLFlowObjectStore
+        else:
+            # Validate that the path conforms to the requirements for UC volume paths
+            UCObjectStore.validate_path(path)
+            remote_backend_cls = UCObjectStore
+    else:
+        remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
+        if remote_backend_cls is None:
+            supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs']
+            raise ValueError(
+                f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})',
+            )
+
+    return remote_backend_cls(**backend_kwargs)
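Reviewer note: a quick sketch of the relocated build_remote_backend dispatch; the bucket and path values are illustrative only.

    from composer.utils.object_store import build_remote_backend

    # Simple name -> ObjectStore subclass mapping ('s3', 'oci', 'sftp', 'libcloud', 'gs')
    s3_store = build_remote_backend('s3', {'bucket': 'my-bucket'})

    # 'dbfs' dispatches on the path: MLflow-style paths map to MLFlowObjectStore,
    # anything else is validated as a UC volumes path and maps to UCObjectStore.
    uc_store = build_remote_backend('dbfs', {'path': '/Volumes/catalog/schema/volume/ckpts'})

    # Any other name raises a ValueError listing the supported backends.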
diff --git a/composer/utils/remote_uploader.py b/composer/utils/remote_uploader.py
index c26c73a319..33793e7c91 100644
--- a/composer/utils/remote_uploader.py
+++ b/composer/utils/remote_uploader.py
@@ -12,13 +12,20 @@
 import time
 import uuid
 from concurrent.futures import Future, ProcessPoolExecutor
-from typing import List
+from enum import Enum
+from typing import Any, Optional

-from composer.utils.dist import get_local_rank
+from composer.utils.dist import broadcast_object_list, get_global_rank, get_local_rank
 from composer.utils.file_helpers import (
-    maybe_create_object_store_from_uri,
+    parse_uri,
+    validate_credentials,
 )
-from composer.utils.object_store.object_store import ObjectStore, ObjectStoreTransientError
+from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX, MLFlowObjectStore
+from composer.utils.object_store.object_store import (
+    ObjectStore,
+    ObjectStoreTransientError,
+)
+from composer.utils.object_store.utils import build_remote_backend
 from composer.utils.retrying import retry

 log = logging.getLogger(__name__)
@@ -26,16 +33,55 @@
 __all__ = ['RemoteUploader']


+class RemoteFilesExistingCheckStatus(Enum):
+    EXIST = 1
+    TIMEOUT = 2
+    ERROR = 3
+
+
+def _check_remote_files_exist(
+    remote_backend_name: str,
+    backend_kwargs: dict[str, Any],
+    remote_checkpoint_file_names: list[str],
+    main_process_pid: int,
+    is_remote_upload_failed: multiprocessing.Event,  # pyright: ignore[reportGeneralTypeIssues]
+    max_wait_time_in_seconds: int = 3600,
+    wait_before_next_try_in_seconds: float = 30,
+):
+    start_time = time.time()
+    object_store = build_remote_backend(remote_backend_name, backend_kwargs)
+
+    for remote_file_name in remote_checkpoint_file_names:
+        while True:
+            if is_remote_upload_failed.is_set():
+                log.debug('Stopping symlink upload because checkpoint file upload failed')
+                return RemoteFilesExistingCheckStatus.ERROR
+            # Return if the parent process exits
+            try:
+                os.kill(main_process_pid, 0)
+            except OSError:
+                return RemoteFilesExistingCheckStatus.ERROR
+            try:
+                object_store.get_object_size(remote_file_name)
+                break
+            except Exception as e:
+                if not isinstance(e, FileNotFoundError):
+                    log.debug(f'Got exception {type(e)}: {str(e)} when accessing remote file {remote_file_name}')
+                time.sleep(wait_before_next_try_in_seconds)
+            if time.time() - start_time > max_wait_time_in_seconds:
+                return RemoteFilesExistingCheckStatus.TIMEOUT
+    return RemoteFilesExistingCheckStatus.EXIST
+
+
 def _upload_file_to_object_store(
-    remote_folder: str,
+    remote_backend_name: str,
+    backend_kwargs: dict[str, Any],
     remote_file_name: str,
     local_file_path: str,
     overwrite: bool,
     num_attempts: int,
 ) -> int:
-    object_store: ObjectStore = maybe_create_object_store_from_uri(
-        remote_folder,
-    )  # pyright: ignore[reportGeneralTypeIssues]
+    object_store = build_remote_backend(remote_backend_name, backend_kwargs)

     @retry(ObjectStoreTransientError, num_attempts=num_attempts)
     def upload_file(retry_index: int = 0):
@@ -72,6 +118,7 @@ class RemoteUploader:
     def __init__(
         self,
         remote_folder: str,
+        backend_kwargs: Optional[dict[str, Any]] = None,
         num_concurrent_uploads: int = 2,
         num_attempts: int = 3,
     ):
@@ -84,18 +131,80 @@ def __init__(

         # A folder to use for staging uploads
         self._tempdir = tempfile.TemporaryDirectory()
         self._upload_staging_folder = self._tempdir.name
+        self.remote_backend_name, self.remote_bucket_name, self.path = parse_uri(remote_folder)

-        self.num_attempts = num_attempts
+        self.backend_kwargs: dict[str, Any] = backend_kwargs if backend_kwargs is not None else {}
+        if self.remote_backend_name in ['s3', 'oci', 'gs'] and 'bucket' not in self.backend_kwargs:
+            self.backend_kwargs['bucket'] = self.remote_bucket_name
+        elif self.remote_backend_name == 'libcloud':
+            if 'container' not in self.backend_kwargs:
+                self.backend_kwargs['container'] = self.remote_bucket_name
+        elif self.remote_backend_name == 'azure':
+            self.remote_backend_name = 'libcloud'
+            self.backend_kwargs = {
+                'provider': 'AZURE_BLOBS',
+                'container': self.remote_bucket_name,
+                'key_environ': 'AZURE_ACCOUNT_NAME',
+                'secret_environ': 'AZURE_ACCOUNT_ACCESS_KEY',
+            }
+        elif self.remote_backend_name == 'dbfs':
+            self.backend_kwargs['path'] = self.path
+        elif self.remote_backend_name == 'wandb':
+            raise NotImplementedError(
+                'There is no implementation for WandB via URI. Please use '
+                'WandBLogger with log_artifacts set to True.',
+            )
+        else:
+            raise NotImplementedError(
+                f'There is no implementation for the cloud backend {self.remote_backend_name} via URI. Please use '
+                'one of the supported object stores (s3, oci, gs, azure, dbfs).',
+            )

-        self.executor = ProcessPoolExecutor(
+        self.num_attempts = num_attempts
+        self._remote_backend: Optional[ObjectStore] = None
+        mp_context = multiprocessing.get_context('spawn')
+        self.upload_executor = ProcessPoolExecutor(
             max_workers=num_concurrent_uploads,
-            mp_context=multiprocessing.get_context('spawn'),
+            mp_context=mp_context,
         )
+        self.check_remote_files_exist_executor = ProcessPoolExecutor(
+            max_workers=2,
+            mp_context=mp_context,
+        )
+        self.is_remote_upload_failed = mp_context.Manager().Event()

         # Used internally to track the future status.
         # If a future completed successfully, we'll remove it from this list
         # when check_workers() or wait() is called
-        self.futures: List[Future] = []
+        self.futures: list[Future] = []
+
+        self.pid = os.getpid()
+
+    @property
+    def remote_backend(self) -> ObjectStore:
+        if self._remote_backend is None:
+            self._remote_backend = build_remote_backend(self.remote_backend_name, self.backend_kwargs)
+        return self._remote_backend
+
+    def init(self):
+        # If it's a dbfs path like dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/,
+        # we need to fill out the experiment_id and run_id
+        if get_global_rank() == 0:
+
+            @retry(ObjectStoreTransientError, num_attempts=self.num_attempts)
+            def _validate_credential_with_retry():
+                validate_credentials(self.remote_backend, '.credentials_validated_successfully')
+
+            _validate_credential_with_retry()
+        if self.path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            if get_global_rank() == 0:
+                assert isinstance(self.remote_backend, MLFlowObjectStore)
+                self.path = self.remote_backend.get_dbfs_path(self.path)
+            path_list = [self.path]
+            broadcast_object_list(path_list, src=0)
+            self.path = path_list[0]
+            self.backend_kwargs['path'] = self.path

     def upload_file_async(
         self,
@@ -114,9 +223,10 @@ def upload_file_async(
         shutil.copy2(file_path, copied_path)

         # Async upload file
-        future = self.executor.submit(
+        future = self.upload_executor.submit(
             _upload_file_to_object_store,
-            remote_folder=self.remote_folder,
+            remote_backend_name=self.remote_backend_name,
+            backend_kwargs=self.backend_kwargs,
             remote_file_name=remote_file_name,
             local_file_path=copied_path,
             overwrite=overwrite,
@@ -132,12 +242,13 @@ def check_workers(self):
         1. if it completed with exception, raise that exception
         2. if it completed without exception, remove it from self.futures
         """
-        done_futures: List[Future] = []
+        done_futures: list[Future] = []
         for future in self.futures:
             if future.done():
                 # future.exception is a blocking call
                 exception_or_none = future.exception()
                 if exception_or_none is not None:
+                    self.is_remote_upload_failed.set()
                     raise exception_or_none
                 else:
                     done_futures.append(future)
@@ -153,6 +264,7 @@ def wait(self):
         for future in self.futures:
             exception_or_none = future.exception()
             if exception_or_none is not None:
+                self.is_remote_upload_failed.set()
                 raise exception_or_none
         self.futures = []

@@ -165,4 +277,25 @@ def wait_and_close(self):
         """
         # make sure all workers are either running, or completed successfully
         self.wait()
-        self.executor.shutdown(wait=True)
+        self.upload_executor.shutdown(wait=True)
+        self.check_remote_files_exist_executor.shutdown(wait=True)
+        log.debug('Finished all uploading tasks, closing RemoteUploader')
+
+    def check_remote_files_exist_async(
+        self,
+        remote_checkpoint_file_names: list[str],
+        max_wait_time_in_seconds: int = 3600,
+        wait_before_next_try_in_seconds: float = 30,
+    ):
+        future = self.check_remote_files_exist_executor.submit(
+            _check_remote_files_exist,
+            remote_backend_name=self.remote_backend_name,
+            backend_kwargs=self.backend_kwargs,
+            remote_checkpoint_file_names=remote_checkpoint_file_names,
+            main_process_pid=self.pid,
+            is_remote_upload_failed=self.is_remote_upload_failed,
+            max_wait_time_in_seconds=max_wait_time_in_seconds,
+            wait_before_next_try_in_seconds=wait_before_next_try_in_seconds,
+        )
+        self.futures.append(future)
+        return future
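Reviewer note: an end-to-end sketch of the reworked RemoteUploader API in this file; the URI and file names are invented.

    import pathlib
    from composer.utils import RemoteFilesExistingCheckStatus, RemoteUploader

    uploader = RemoteUploader(remote_folder='s3://my-bucket/run-1')  # invented URI
    uploader.init()  # rank 0 validates credentials; dbfs placeholder paths get resolved

    # Uploads run in a spawn-based ProcessPoolExecutor; each worker rebuilds the backend
    # from (remote_backend_name, backend_kwargs), which is picklable, instead of trying
    # to ship an ObjectStore instance across processes.
    uploader.upload_file_async(
        remote_file_name='run-1/ep0-ba1-rank0.pt',
        file_path=pathlib.Path('ep0-ba1-rank0.pt'),
        overwrite=True,
    )

    # A second, two-worker executor polls get_object_size() until the files exist,
    # the timeout lapses, or the parent process dies.
    future = uploader.check_remote_files_exist_async(
        remote_checkpoint_file_names=['run-1/ep0-ba1-rank0.pt'],
        max_wait_time_in_seconds=60,
        wait_before_next_try_in_seconds=1,
    )
    assert future.result() == RemoteFilesExistingCheckStatus.EXIST
    uploader.wait_and_close()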
diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py
index 553d8d9b60..f54d1f69e1 100644
--- a/docs/source/doctest_fixtures.py
+++ b/docs/source/doctest_fixtures.py
@@ -54,7 +54,7 @@
 from composer.loggers import RemoteUploaderDownloader
 from composer.models import ComposerModel as ComposerModel
 from composer.optim.scheduler import ConstantScheduler
-from composer.utils import LibcloudObjectStore
+from composer.utils import LibcloudObjectStore, RemoteUploader
 from composer.utils import ensure_tuple as ensure_tuple

 try:
@@ -246,6 +246,29 @@ def _new_RemoteUploaderDownloader_init(self, fake_ellipses: None = None, **kwarg
 RemoteUploaderDownloader.__init__ = _new_RemoteUploaderDownloader_init  # type: ignore

+# Patch RemoteUploader __init__ function to replace arguments while preserving type
+_original_RemoteUploader_init = RemoteUploader.__init__
+
+
+def _new_RemoteUploader_init(self, fake_ellipses: None = None, **kwargs: Any):
+    os.makedirs('./object_store', exist_ok=True)
+    kwargs.update(
+        num_concurrent_uploads=1,
+        remote_folder='libcloud://.',
+        backend_kwargs={
+            'provider': 'local',
+            'container': '.',
+            'provider_kwargs': {
+                'key': os.path.abspath('./object_store'),
+            },
+        },
+        num_attempts=1,
+    )
+    _original_RemoteUploader_init(self, **kwargs)
+
+
+RemoteUploader.__init__ = _new_RemoteUploader_init
+
 # Patch ObjectStore __init__ function to replace arguments while preserving type
 _original_libcloudObjectStore_init = LibcloudObjectStore.__init__
diff --git a/tests/loggers/test_remote_uploader_downloader.py b/tests/loggers/test_remote_uploader_downloader.py
index 1f877d2dd9..b25e23a717 100644
--- a/tests/loggers/test_remote_uploader_downloader.py
+++ b/tests/loggers/test_remote_uploader_downloader.py
@@ -77,7 +77,7 @@ def object_store_test_helper(
         # Patching does not work when using multiprocessing with spawn, so we also
         # patch to use fork
         fork_context = multiprocessing.get_context('fork')
-        with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore):
+        with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
             with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context):
                 remote_uploader_downloader = RemoteUploaderDownloader(
                     bucket_uri='s3://{remote_dir}',
@@ -227,7 +227,7 @@ def get_object_size(self, object_name: str) -> int:
             return super().get_object_size(object_name)

     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', RetryDummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', RetryDummyObjectStore):
         with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context):
             remote_uploader_downloader = RemoteUploaderDownloader(
                 bucket_uri=f"s3://{tmp_path}/'object_store_backend",
@@ -263,7 +263,7 @@ def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_stat
     # Patching does not work when using multiprocessing with spawn, so we also
     # patch to use fork
     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         with patch('composer.loggers.remote_uploader_downloader.multiprocessing.get_context', lambda _: fork_context):
             # Create the object store logger
             remote_uploader_downloader = RemoteUploaderDownloader(
@@ -307,7 +307,7 @@ def test_race_with_overwrite(tmp_path: pathlib.Path, use_procs: bool, dummy_stat

 def test_close_on_failure(tmp_path: pathlib.Path, dummy_state: State):
     """Test that .close() and .post_close() does not hang even when a worker crashes."""
-    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         # Create the object store logger
         remote_uploader_downloader = RemoteUploaderDownloader(
             bucket_uri=f"s3://{tmp_path}/'object_store_backend",
@@ -355,9 +355,9 @@ def test_close_on_failure(tmp_path: pathlib.Path, dummy_state: State):

 def test_valid_backend_names():
     valid_backend_names = ['s3', 'libcloud', 'sftp']
-    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore') as _, \
-            patch('composer.loggers.remote_uploader_downloader.SFTPObjectStore') as _, \
-            patch('composer.loggers.remote_uploader_downloader.LibcloudObjectStore') as _:
+    with patch('composer.utils.object_store.utils.S3ObjectStore') as _, \
+            patch('composer.utils.object_store.utils.SFTPObjectStore') as _, \
+            patch('composer.utils.object_store.utils.LibcloudObjectStore') as _:
         for name in valid_backend_names:
             remote_uploader_downloader = RemoteUploaderDownloader(bucket_uri=f'{name}://not-a-real-bucket')
             # Access the remote_backend property so that it is built
@@ -374,7 +374,7 @@ def test_exception_queue_works(tmp_path: pathlib.Path, dummy_state: State):
     """Test that exceptions get put on the exception queue and get thrown"""
-    with patch('composer.loggers.remote_uploader_downloader.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         # Create the object store logger
         remote_uploader_downloader = RemoteUploaderDownloader(
             bucket_uri=f"s3://{tmp_path}/'object_store_backend",
diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py
index 9912563eb8..ede864d13b 100644
--- a/tests/trainer/test_checkpoint.py
+++ b/tests/trainer/test_checkpoint.py
@@ -4,6 +4,7 @@
 import contextlib
 import copy
 import io
+import multiprocessing
 import os
 import pathlib
 import re
@@ -25,12 +26,11 @@
 from composer.algorithms import NoOpModel
 from composer.callbacks import CheckpointSaver
 from composer.core import Callback, Time, TimeUnit
-from composer.loggers import RemoteUploaderDownloader, remote_uploader_downloader
 from composer.metrics import MAP
 from composer.optim import ExponentialScheduler
 from composer.trainer import trainer
 from composer.trainer.trainer import Trainer
-from composer.utils import dist, is_tar, reproducibility
+from composer.utils import dist, is_tar, remote_uploader, reproducibility
 from composer.utils.checkpoint import (
     _COMPOSER_STATES_FILENAME,
     PartialFilePath,
@@ -52,6 +52,7 @@
     device,
 )
 from tests.common.markers import world_size
+from tests.utils.test_remote_uploader import DummyObjectStore


 class DummyStatefulCallback(Callback):
@@ -309,30 +310,6 @@ def get_trainer(self, **kwargs):
         model = SimpleConvModel()
         return Trainer(model=model, **kwargs)

-    @pytest.mark.parametrize('add_remote_ud', [True, False])
-    def test_s3_uri_creates_remote_ud(self, add_remote_ud: bool, monkeypatch: MonkeyPatch):
-        mock_validate_credentials = MagicMock()
-        monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
-        if add_remote_ud:
-            with pytest.warns(UserWarning):
-                trainer = self.get_trainer(
-                    save_folder='s3://bucket_name/{run_name}/checkpoints',
-                    loggers=[
-                        RemoteUploaderDownloader('s3://bucket_name', file_path_format_string='{remote_file_name}'),
-                    ],
-                )
-        else:
-            trainer = self.get_trainer(save_folder='s3://bucket_name/{run_name}/checkpoints')
-
-        remote_uds = [
-            logger_dest for logger_dest in trainer.logger.destinations
-            if isinstance(logger_dest, RemoteUploaderDownloader)
-        ]
-        assert len(remote_uds) == 1
-        remote_ud = remote_uds[0]
-        assert remote_ud.remote_backend_name == 's3'
-        assert remote_ud.remote_bucket_name == 'bucket_name'
-
     @pytest.mark.parametrize('uri', ['wandb://foo/bar', 'gcs://foo/bar', 'sftp://foo/bar"'])
     def test_other_uris_error_out(self, uri: str):
         with pytest.raises(NotImplementedError):
@@ -394,7 +371,7 @@ def test_checkpoint_saver_properly_constructed(
         monkeypatch: MonkeyPatch,
     ):
         mock_validate_credentials = MagicMock()
-        monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
+        monkeypatch.setattr(remote_uploader, 'validate_credentials', mock_validate_credentials)

         trainer = self.get_trainer(save_folder=save_folder)

@@ -646,6 +623,71 @@ def test_checkpoint_multiple_callbacks(
         assert id(trainer._checkpoint_saver) == id(checkpoint_savers[0])
         assert len([cb for cb in trainer.state.callbacks if isinstance(cb, CheckpointSaver)]) == len(checkpoint_savers)

+    @pytest.mark.parametrize(('upload_success'), [True, False])
+    def test_checkpoint_remote_symlink(
+        self,
+        upload_success: bool,
+    ):
+        import multiprocessing
+        fork_context = multiprocessing.get_context('fork')
+        tmp_dir = tempfile.TemporaryDirectory()
+
+        def _get_tmp_dir(self):
+            return tmp_dir
+
+        class _AlwaysFailDummyObjectStore(DummyObjectStore):
+
+            def upload_object(self, object_name, filename, callback=None):
+                # Only allow symlink uploads, to simulate the situation
+                # where checkpoint file uploads fail
+                if 'symlink' in object_name or 'credentials_validated_successfully' in object_name:
+                    return super().upload_object(object_name, filename, callback)
+                raise RuntimeError('Raise Error intentionally')
+
+        if upload_success:
+            MockObjectStore = DummyObjectStore
+        else:
+            MockObjectStore = _AlwaysFailDummyObjectStore
+
+        with patch('composer.utils.object_store.utils.S3ObjectStore', MockObjectStore):
+            with patch('tests.utils.test_remote_uploader.DummyObjectStore.get_tmp_dir', _get_tmp_dir):
+                with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
+                    train_dataset = RandomClassificationDataset(size=10)
+                    train_dataloader = DataLoader(
+                        dataset=train_dataset,
+                        batch_size=2,
+                        sampler=dist.get_sampler(train_dataset),
+                    )
+
+                    trainer = Trainer(
+                        model=SimpleModel(),
+                        train_dataloader=train_dataloader,
+                        save_interval='1ba',
+                        max_duration='1ba',
+                        save_folder='S3://whatever/',
+                    )
+                    symlink_filepath = os.path.join(tmp_dir.name, 'latest-rank0.pt.symlink')
+                    if upload_success:
+                        trainer.fit()
+                        with open(symlink_filepath, 'r') as f:
+                            assert f.read() == 'ep0-ba1-rank0.pt'
+                    else:
+                        assert trainer._checkpoint_saver is not None
+                        trainer._checkpoint_saver._symlink_upload_wait_before_next_try_in_seconds = 0.01
+                        trainer._checkpoint_saver.upload_timeout_in_seconds = 1
+                        with pytest.raises(RuntimeError, match='Raise Error intentionally'):
+                            trainer.fit()
+                        assert not os.path.exists(symlink_filepath)
+
+                        def post_close(self):
+                            return
+
+                        assert trainer._checkpoint_saver is not None
+                        trainer._checkpoint_saver.post_close = post_close.__get__(
+                            trainer._checkpoint_saver,
+                            CheckpointSaver,
+                        )
+

 class TestCheckpointLoading:
@@ -709,25 +751,6 @@ def get_trainer(
             **kwargs,
         )

-    def get_logger(self, tmp_path: pathlib.Path):
-        """Returns an object store logger that saves locally."""
-        remote_dir = str(tmp_path / 'object_store')
-        os.makedirs(remote_dir, exist_ok=True)
-
-        return RemoteUploaderDownloader(
-            bucket_uri='libcloud://.',
-            backend_kwargs={
-                'provider': 'local',
-                'container': '.',
-                'provider_kwargs': {
-                    'key': remote_dir,
-                },
-            },
-            num_concurrent_uploads=1,
-            use_procs=False,
-            upload_staging_folder=str(tmp_path / 'staging_folder'),
-        )
-
     @world_size(1, 2)
     @device('cpu', 'gpu')
     @pytest.mark.parametrize('use_object_store', [True, False])
@@ -758,9 +781,6 @@ def test_autoresume(
         if delete_local and not use_object_store:
             pytest.skip('Invalid test setting.')

-        if use_object_store:
-            pytest.importorskip('libcloud')
-
         latest_filename = 'latest-rank{rank}' + file_extension
         if test_slashed:
             latest_filename = 'testdir/' + latest_filename
@@ -768,51 +788,68 @@ def test_autoresume(
         if is_compressed_pt(latest_filename) and not get_compressor(latest_filename).exists:
             pytest.skip(reason=f'compressor not found for {latest_filename}')

-        trainer_1 = self.get_trainer(
-            latest_filename=latest_filename,
-            file_extension=file_extension,
-            save_folder='first',
-            device=device,
-            run_name='big-chungus',
-            autoresume=True,
-            loggers=[self.get_logger(tmp_path)] if use_object_store else [],
-            save_metrics=save_metrics,
-        )
-
-        # trains the model, saving the checkpoint files
-        trainer_1.fit()
-        trainer_1.close()
-
-        if delete_local:
-            # delete files locally, forcing trainer to look in object store
-            shutil.rmtree('first')
-
-        trainer_2 = self.get_trainer(
-            latest_filename=latest_filename,
-            save_folder='first',
-            device=device,
-            run_name='big-chungus',
-            autoresume=True,
-            load_path='ignore_me.pt',  # this should be ignored
-            load_ignore_keys=['*'],  # this should be ignored
-            save_overwrite=save_overwrite,
-            loggers=[self.get_logger(tmp_path)] if use_object_store else [],
-        )
-
-        self._assert_weights_equivalent(
-            trainer_1.state.model,
-            trainer_2.state.model,
-        )
-
-        if save_metrics:
-            assert self._metrics_equal(
-                trainer_1.state.train_metrics,
-                trainer_2.state.train_metrics,
-                trainer_1.state.eval_metrics,
-                trainer_2.state.eval_metrics,
-            ), 'Original metrics do not equal metrics from loaded checkpoint.'
-
-        assert trainer_1.state.run_name == trainer_2.state.run_name
+        if use_object_store:
+            save_folder = 's3://bucket_name/first'
+        else:
+            save_folder = 'first'
+
+        # Mock the S3 object store
+        fork_context = multiprocessing.get_context('fork')
+        tmp_dir = tempfile.TemporaryDirectory()
+
+        def _get_tmp_dir(self):
+            return tmp_dir
+
+        with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
+            with patch('tests.utils.test_remote_uploader.DummyObjectStore.get_tmp_dir', _get_tmp_dir):
+                with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
+
+                    trainer_1 = self.get_trainer(
+                        latest_filename=latest_filename,
+                        file_extension=file_extension,
+                        save_folder=save_folder,
+                        device=device,
+                        run_name='big-chungus',
+                        autoresume=True,
+                        save_metrics=save_metrics,
+                    )
+                    if use_object_store:
+                        assert trainer_1._checkpoint_saver is not None
+                        trainer_1._checkpoint_saver._symlink_upload_wait_before_next_try_in_seconds = 0.01
+
+                    # trains the model, saving the checkpoint files
+                    trainer_1.fit()
+                    trainer_1.close()
+
+                    if delete_local:
+                        # delete files locally, forcing trainer to look in object store
+                        shutil.rmtree('first')
+
+                    trainer_2 = self.get_trainer(
+                        latest_filename=latest_filename,
+                        save_folder=save_folder,
+                        device=device,
+                        run_name='big-chungus',
+                        autoresume=True,
+                        load_path='ignore_me.pt',  # this should be ignored
+                        load_ignore_keys=['*'],  # this should be ignored
+                        save_overwrite=save_overwrite,
+                    )
+
+                    self._assert_weights_equivalent(
+                        trainer_1.state.model,
+                        trainer_2.state.model,
+                    )
+
+                    if save_metrics:
+                        assert self._metrics_equal(
+                            trainer_1.state.train_metrics,
+                            trainer_2.state.train_metrics,
+                            trainer_1.state.eval_metrics,
+                            trainer_2.state.eval_metrics,
+                        ), 'Original metrics do not equal metrics from loaded checkpoint.'
+
+                    assert trainer_1.state.run_name == trainer_2.state.run_name

     @pytest.mark.parametrize(('save_folder'), [None, 'first'])
     def test_autoresume_from_callback(
@@ -862,7 +899,7 @@ def test_autoresume_from_callback(
     def test_load_from_uri(self, load_path: str, load_object_store: Optional[ObjectStore], monkeypatch: MonkeyPatch):

         mock_validate_credentials = MagicMock()
-        monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
+        monkeypatch.setattr(remote_uploader, 'validate_credentials', mock_validate_credentials)
         mock_load_checkpoint = MagicMock()
         monkeypatch.setattr(trainer.checkpoint, 'load_checkpoint', mock_load_checkpoint)
         self.get_trainer(load_path=load_path, load_object_store=load_object_store)
@@ -882,7 +919,7 @@ def test_load_from_uri(self, load_path: str, load_object_store: Optional[ObjectS
     )
     def test_other_backends_error(self, load_path: str, monkeypatch: MonkeyPatch):
         mock_validate_credentials = MagicMock()
-        monkeypatch.setattr(remote_uploader_downloader, '_validate_credentials', mock_validate_credentials)
+        monkeypatch.setattr(remote_uploader, 'validate_credentials', mock_validate_credentials)
         with pytest.raises(NotImplementedError):
             self.get_trainer(load_path=load_path)
@@ -1197,29 +1234,37 @@ def _stateful_callbacks_equal(self, callbacks1, callbacks2):
         return cb1.random_value == cb2.random_value

     def test_load_weights_object_store(self, tmp_path):
-
-        pytest.importorskip('libcloud')
-
-        trainer_1 = self.get_trainer(
-            save_folder='{run_name}/checkpoints',
-            loggers=[self.get_logger(tmp_path)],
-            run_name='electric-zebra',
-        )
-        trainer_1.fit()
-        trainer_1.close()
-
-        trainer_2 = self.get_trainer(
-            loggers=[self.get_logger(tmp_path)],
-            run_name='electric-zebra',
-            load_path='electric-zebra/checkpoints/latest-rank0.pt',
-            load_object_store=self.get_logger(tmp_path),
-        )
-
-        # check weights loaded properly
-        self._assert_weights_equivalent(
-            trainer_1.state.model,
-            trainer_2.state.model,
-        )
+        # Mock the S3 object store
+        fork_context = multiprocessing.get_context('fork')
+        tmp_dir = tempfile.TemporaryDirectory()
+
+        def _get_tmp_dir(self):
+            return tmp_dir
+
+        with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
+            with patch('tests.utils.test_remote_uploader.DummyObjectStore.get_tmp_dir', _get_tmp_dir):
+                with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
+                    save_folder = 's3://my_bucket/{run_name}/checkpoints'
+                    trainer_1 = self.get_trainer(
+                        save_folder=save_folder,
+                        run_name='electric-zebra',
+                    )
+                    assert trainer_1._checkpoint_saver is not None
+                    trainer_1._checkpoint_saver._symlink_upload_wait_before_next_try_in_seconds = 0.01
+                    trainer_1.fit()
+                    trainer_1.close()
+
+                    trainer_2 = self.get_trainer(
+                        run_name='electric-zebra',
+                        load_path='electric-zebra/checkpoints/latest-rank0.pt',
+                        load_object_store=DummyObjectStore(),
+                    )
+
+                    # check weights loaded properly
+                    self._assert_weights_equivalent(
+                        trainer_1.state.model,
+                        trainer_2.state.model,
+                    )

     @pytest.mark.parametrize(
         'run_name,save_folder,latest_filename',
diff --git a/tests/utils/test_remote_uploader.py b/tests/utils/test_remote_uploader.py
index 847abb369d..2e41e91d18 100644
--- a/tests/utils/test_remote_uploader.py
+++ b/tests/utils/test_remote_uploader.py
@@ -20,7 +20,7 @@ class DummyObjectStore(ObjectStore):
     """Dummy ObjectStore implementation that is backed by a local directory."""

     def __init__(self, **kwargs: Dict[str, Any]) -> None:
-        self.tmp_dir = tempfile.TemporaryDirectory()
+        self.tmp_dir = self.get_tmp_dir()
         self.root = self.tmp_dir.name
         self.sleep_sec = 0
         self.dest_filename = ''
@@ -28,6 +28,9 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None:
     def raise_error(self):
         return False

+    def get_tmp_dir(self):
+        return tempfile.TemporaryDirectory()
+
     def upload_object(
         self,
         object_name: str,
@@ -38,6 +41,7 @@ def upload_object(
             raise RuntimeError('Raise Error intentionally')
         time.sleep(self.sleep_sec)
         dest_filename = pathlib.Path(self.root) / object_name
+        os.makedirs(os.path.dirname(dest_filename), exist_ok=True)
         shutil.copy2(filename, dest_filename)
         self.dest_filename = dest_filename

@@ -46,6 +50,16 @@ def get_object_size(self, object_name: str) -> int:
         size = os.stat(object_path).st_size
         return size

+    def download_object(
+        self,
+        object_name: str,
+        filename: Union[str, pathlib.Path],
+        overwrite: bool = False,
+        callback: Optional[Callable[[int, int], None]] = None,
+    ):
+        object_path = pathlib.Path(self.root) / object_name
+        shutil.copy2(object_path, filename)
+

 def test_upload_mutliple_files():
     fork_context = multiprocessing.get_context('fork')
@@ -54,7 +68,7 @@ def test_upload_mutliple_files():
     def _get_tmp_dir():
         return tmp_dir

-    with patch('composer.utils.file_helpers.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         with patch('tempfile.TemporaryDirectory', _get_tmp_dir):
             with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
                 remote_uploader = RemoteUploader(
@@ -99,7 +113,7 @@ def _get_tmp_dir():
         return remote_tmp_dir

     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.utils.file_helpers.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         with patch('tempfile.TemporaryDirectory', _get_tmp_dir):
             with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
                 remote_uploader = RemoteUploader(remote_folder='S3://whatever/path',)
@@ -145,7 +159,7 @@ def raise_error(self):
         return True

     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.utils.file_helpers.S3ObjectStore', AlwaysFailDummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', AlwaysFailDummyObjectStore):
         with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
             remote_uploader = RemoteUploader(remote_folder='S3://whatever/path',)
             tmp_dir = tempfile.TemporaryDirectory()
@@ -168,7 +182,7 @@ def raise_error(self):

 def test_wait():
     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.utils.file_helpers.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
             remote_uploader = RemoteUploader(
                 remote_folder='S3://whatever/path',
@@ -197,7 +211,7 @@ def test_wait():

 def test_wait_and_close():
     fork_context = multiprocessing.get_context('fork')
-    with patch('composer.utils.file_helpers.S3ObjectStore', DummyObjectStore):
+    with patch('composer.utils.object_store.utils.S3ObjectStore', DummyObjectStore):
         with patch('composer.utils.remote_uploader.multiprocessing.get_context', lambda _: fork_context):
             remote_uploader = RemoteUploader(
                 remote_folder='S3://whatever/path',