From 9dff6f8128bd8b658d664c7c7950d71db282c4a3 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Tue, 25 Jun 2024 19:10:20 -0700 Subject: [PATCH] handle crash --- composer/loggers/mlflow_logger.py | 14 +++++++++-- tests/loggers/test_mlflow_logger.py | 39 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 526a7962fd..101ca0d126 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -519,14 +519,24 @@ def log_images( step=step, ) + def fit_end(self, state: State, logger: Logger): + import mlflow + + mlflow.flush_async_logging() + # If `fit_end` is successfully executed, the run is considered successful. + mlflow.end_run(status="FINISHED") + def post_close(self): if self._enabled: import mlflow assert isinstance(self._run_id, str) mlflow.flush_async_logging() - self._mlflow_client.set_terminated(self._run_id) - mlflow.end_run() + status = mlflow.get_run(self._run_id).info.status + if status == "RUNNING": + # If the run is still running, it is considered failed because `post_close` was + # called on runtime failure. + mlflow.end_run(status="FAILED") def _convert_to_mlflow_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray: diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 5ee6aab7a5..59de5b165d 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -598,6 +598,45 @@ def test_mlflow_register_uc_error(tmp_path, monkeypatch): ) +@pytest.mark.parametrize('crash_during_fit', [True, False]) +def test_mlflow_run_status(tmp_path, crash_during_fit): + mlflow = pytest.importorskip('mlflow') + + mlflow_uri = tmp_path / Path('my-test-mlflow-uri') + experiment_name = 'mlflow_logging_test' + # mock_state = MagicMock() + # mock_logger = MagicMock() + + test_mlflow_logger = MLFlowLogger( + tracking_uri=mlflow_uri, + experiment_name=experiment_name, + log_system_metrics=True, + run_name='test_run', + ) + trainer = Trainer( + model=SimpleModel(), + loggers=test_mlflow_logger, + train_dataloader=DataLoader(RandomClassificationDataset(size=64), batch_size=4), + eval_dataloader=DataLoader(RandomClassificationDataset(size=64), batch_size=4), + max_duration=f'4ba', + eval_interval='1ba', + ) + + if crash_during_fit: + with patch.object(trainer, 'fit', side_effect=Exception('mocked exception')): + with pytest.raises(Exception, match='mocked exception'): + trainer.fit() + + test_mlflow_logger.post_close() + status = mlflow.get_run(test_mlflow_logger._run_id).info.status + assert status == 'FAILED' + + else: + trainer.fit() + test_mlflow_logger.post_close() + status = mlflow.get_run(test_mlflow_logger._run_id).info.status + assert status == 'FINISHED' + @device('cpu') def test_mlflow_log_image_works(tmp_path, device): pytest.importorskip('mlflow')