diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 660c315c8e..3da777a0ec 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -312,10 +312,7 @@ def init(self, state: State, logger: Logger) -> None: if self.run_name is None: self.run_name = state.run_name - if hasattr(state, 'device'): - self._global_exception_occurred = state.device.tensor_to_device(torch.tensor([0], dtype=torch.uint8),) - else: - self._global_exception_occurred = 0 + self._global_exception_occurred = 0 # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name) @@ -615,10 +612,7 @@ def post_close(self): if hasattr(self, 'monitor_process'): # Check if there is an uncaught exception, which means `post_close()` is triggered # due to program crash. - if isinstance(self._global_exception_occurred, torch.Tensor): - finish_with_exception = (self._global_exception_occurred == 1).item() - else: - finish_with_exception = (self._global_exception_occurred == 1) + finish_with_exception = self._global_exception_occurred == 1 if finish_with_exception: self.monitor_process.crash() return