diff --git a/composer/core/state.py b/composer/core/state.py
index a2d59b5e57..a0e50303d2 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -257,6 +257,7 @@ class State(Serializable):
         rank_zero_seed (int): The seed used on the rank zero process. It is assumed that each rank's seed is
             ``rank_zero_seed + dist.get_global_rank()``.
         run_name (str): The name for this training run.
+        parent_run_id (str, optional): The MLflow run id of the parent run created by the rank 0 process.
         device (Device): The device used by this process. The trainer moves the model and loaded data to this device.
         device_train_microbatch_size (int, optional): The microbatch size for each device during training.
         auto_microbatching (bool, optional): Whether automatic microbatching is enabled.
@@ -457,6 +458,7 @@ def __init__(
         self.rank_zero_seed = rank_zero_seed
         self.model = model
         self.run_name = run_name
+        self.parent_run_id = None
         self.device = device
         self.device_train_microbatch_size = device_train_microbatch_size
         self.auto_microbatching = auto_microbatching
diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py
index 8f1250cee3..2b6d226fcb 100644
--- a/composer/loggers/mlflow_logger.py
+++ b/composer/loggers/mlflow_logger.py
@@ -13,6 +13,9 @@
 from composer.loggers.logger import Logger
 from composer.loggers.logger_destination import LoggerDestination
 from composer.utils import MissingConditionalImportError, dist
+from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
+from mlflow.entities import Metric, Param
+from mlflow.utils.time_utils import get_current_time_millis
 
 __all__ = ['MLFlowLogger']
 
@@ -46,13 +49,13 @@ def __init__(
             raise MissingConditionalImportError(extra_deps_group='mlflow',
                                                 conda_package='mlflow',
                                                 conda_channel='conda-forge') from e
-        del mlflow
         self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0
 
         self.run_name = run_name
         self.experiment_name = experiment_name
         self._rank_zero_only = rank_zero_only
         self.tracking_uri = tracking_uri
+        self._mlflow_client = mlflow.MlflowClient(tracking_uri)
 
     def init(self, state: State, logger: Logger) -> None:
         import mlflow
@@ -70,36 +73,62 @@ def init(self, state: State, logger: Logger) -> None:
             self.run_name += f'-rank{dist.get_global_rank()}'
 
         if self._enabled:
-            if self.tracking_uri is not None:
-                mlflow.set_tracking_uri(self.tracking_uri)
-
-            # set experiment
-            env_exp_id = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_ID.name, None)
-            if env_exp_id is not None:
-                mlflow.set_experiment(experiment_id=env_exp_id)
-            else:
-                mlflow.set_experiment(experiment_name=self.experiment_name)
 
             # start run
-            env_run_id = os.getenv(mlflow.environment_variables.MLFLOW_RUN_ID.name, None)
-            if env_run_id is not None:
-                mlflow.start_run(run_id=env_run_id)
+            if dist.get_global_rank() == 0:
+                if self.tracking_uri is not None:
+                    mlflow.set_tracking_uri(self.tracking_uri)
+                # set experiment
+                env_exp_id = os.getenv(mlflow.environment_variables.MLFLOW_EXPERIMENT_ID.name, None)
+                if env_exp_id is not None:
+                    parent_exp_id = mlflow.set_experiment(experiment_id=env_exp_id).experiment_id
+                else:
+                    parent_exp_id = mlflow.set_experiment(experiment_name=self.experiment_name).experiment_id
+                env_run_id = os.getenv(mlflow.environment_variables.MLFLOW_RUN_ID.name, None)
+                if env_run_id is None:
+                    parent_run = mlflow.start_run(run_name=self.run_name)
+                    parent_run_id = parent_run.info.run_id
+                else:
+                    mlflow.start_run(run_id=env_run_id, run_name=self.run_name)
+                    parent_run_id = env_run_id
+                mlflow.end_run()
             else:
-                mlflow.start_run(run_name=self.run_name)
+                parent_run_id = None
+                parent_exp_id = None
+
+            run_id_list = [parent_run_id, parent_exp_id]
+            dist.broadcast_object_list(run_id_list, src=0)
+            parent_run_id = run_id_list[0]
+            parent_exp_id = run_id_list[1]
+            # mlflow.set_experiment(experiment_id=parent_exp_id)
+            # mlflow.set_tag(MLFLOW_PARENT_RUN_ID, parent_run_id)
+            self.run = self._mlflow_client.create_run(
+                experiment_id=parent_exp_id,
+                run_name=self.run_name,
+                tags={MLFLOW_PARENT_RUN_ID: parent_run_id})
 
     def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
-        import mlflow
         if self._enabled:
             # Convert all metrics to floats to placate mlflow.
             metrics = {k: float(v) for k, v in metrics.items()}
-            mlflow.log_metrics(metrics=metrics, step=step)
+            timestamp = get_current_time_millis()
+            metrics_arr = [Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
+            self._mlflow_client.log_batch(
+                run_id=self.run.info.run_id,
+                metrics=metrics_arr,
+                params=[],
+                tags=[])
 
     def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
-        import mlflow
         if self._enabled:
-            mlflow.log_params(params=hyperparameters)
+            params_arr = [Param(key, str(value)) for key, value in hyperparameters.items()]
+            self._mlflow_client.log_batch(
+                run_id=self.run.info.run_id, metrics=[], params=params_arr, tags=[])
 
     def post_close(self):
         import mlflow
         if self._enabled:
             mlflow.end_run()
+        if dist.get_global_rank() == 0:
+            mlflow.end_run()
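
For context on the pattern the patch adopts: rank 0 starts a parent MLflow run, broadcasts the parent run id and experiment id to the other ranks, and every rank then attaches its own child run through the low-level MlflowClient API, logging metrics and params with log_batch instead of the fluent mlflow.* calls. The sketch below is a minimal, standalone illustration of that parent/child flow against a local tracking store; the experiment name, run names, and logged values are placeholders, not Composer's, and the broadcast step is only described in comments.

# Minimal sketch of the parent/child run pattern (assumes only that mlflow is installed).
import mlflow
from mlflow.entities import Metric, Param
from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID
from mlflow.utils.time_utils import get_current_time_millis  # same helper the patch imports

client = mlflow.MlflowClient()  # defaults to the local ./mlruns file store
exp_id = mlflow.set_experiment('parent-child-demo').experiment_id

# Rank 0 would create the parent run and share its id with the other ranks
# (the patch does this with dist.broadcast_object_list).
parent_run_id = client.create_run(experiment_id=exp_id, run_name='parent').info.run_id

# Every rank then creates its own child run, tagged with the parent's id,
# and logs through log_batch rather than the fluent API.
child = client.create_run(
    experiment_id=exp_id,
    run_name='child-rank0',
    tags={MLFLOW_PARENT_RUN_ID: parent_run_id},
)
ts = get_current_time_millis()
client.log_batch(
    run_id=child.info.run_id,
    metrics=[Metric('loss', 0.5, ts, 0)],
    params=[Param('lr', '0.1')],
    tags=[],
)
client.set_terminated(child.info.run_id)
client.set_terminated(parent_run_id)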