Skip to content

Commit

Permalink
Add objectstore registry
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Oct 23, 2024
1 parent 835dd76 commit f7a6d56
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -393,13 +393,41 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.
# Registry for object store creation functions
object_store_registry: dict[str, Callable[[str, str], ObjectStore]] = {}

Currently supported backends are ``s3://``, ``oci://``, and local paths (in which case ``None`` will be returned)

def register_object_store(backend: str, factory_func: Callable[[str, str], ObjectStore]):
"""Registers a new object store backend to the registry.
Args:
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from
backend (str): The backend name (e.g., 's3', 'oci').
factory_func (Callable): A function that accepts bucket_name and path and returns an ObjectStore instance.
"""
object_store_registry[backend] = factory_func


# Register default object stores
register_object_store('s3', lambda bucket, path: S3ObjectStore(bucket=bucket))
register_object_store('gs', lambda bucket, path: GCSObjectStore(bucket=bucket))
register_object_store('oci', lambda bucket, path: OCIObjectStore(bucket=bucket))
register_object_store(
'azure',
lambda bucket,
path: LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
),
)


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an ObjectStore from supported URI formats.
Args:
uri (str): The path to (maybe) create an ObjectStore from.
Raises:
NotImplementedError: Raises when the URI format is not supported.
Expand All @@ -408,54 +436,38 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, otherwise None
"""
backend, bucket_name, path = parse_uri(uri)

# If backend is empty, assume local path and return None
if backend == '':
return None
if backend == 's3':
return S3ObjectStore(bucket=bucket_name)
elif backend == 'wandb':

# Check if backend is registered
if backend in object_store_registry:
return object_store_registry[backend](bucket_name, path)

# Handle special cases like WandB, MLFlow, etc.
if backend == 'wandb':
raise NotImplementedError(
f'There is no implementation for WandB load_object_store via URI. Please use '
'WandBLogger',
)
elif backend == 'gs':
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
f'There is no implementation for WandB load_object_store via URI. Please use WandBLogger',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
if dist.get_global_rank() == 0:
store = MLFlowObjectStore(path)

# The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
path = store.get_dbfs_path(path)

# Broadcast the rank 0 updated path to all ranks for their own object stores
path_list = [path]
dist.broadcast_object_list(path_list, src=0)
path = path_list[0]

# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)

# If backend is unknown, raise NotImplementedError
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI.')


def maybe_create_remote_uploader_downloader_from_uri(
Expand Down

0 comments on commit f7a6d56

Please sign in to comment.