From afe22c2a2acd2baa3d7af5e4cfc63c0bc52fb7c1 Mon Sep 17 00:00:00 2001 From: Googler Date: Tue, 13 Apr 2021 14:18:24 -0700 Subject: [PATCH] Adds telemetry to record whether explainable AI sdk is installed on the local environment. Also fixes readability and lint errors in existing code. PiperOrigin-RevId: 368293048 --- .../utils/google_api_client.py | 408 +++++------ .../tests/unit/google_api_client_test.py | 647 +++++++++--------- 2 files changed, 547 insertions(+), 508 deletions(-) diff --git a/src/python/tensorflow_cloud/utils/google_api_client.py b/src/python/tensorflow_cloud/utils/google_api_client.py index f65d9039..b4003569 100644 --- a/src/python/tensorflow_cloud/utils/google_api_client.py +++ b/src/python/tensorflow_cloud/utils/google_api_client.py @@ -46,259 +46,275 @@ class ClientEnvironment(enum.Enum): - """Types of client environment for telemetry reporting.""" - UNKNOWN = 0 - KAGGLE_NOTEBOOK = 1 - HOSTED_NOTEBOOK = 2 - DLVM = 3 - DL_CONTAINER = 4 - COLAB = 5 + """Types of client environment for telemetry reporting.""" + UNKNOWN = 0 + KAGGLE_NOTEBOOK = 1 + HOSTED_NOTEBOOK = 2 + DLVM = 3 + DL_CONTAINER = 4 + COLAB = 5 class TFCloudHttpRequest(googleapiclient_http.HttpRequest): - """HttpRequest builder that sets a customized useragent header for TF Cloud. + """HttpRequest builder that sets a customized useragent header for TF Cloud. This is used to track the usage of the TF Cloud. - """ + """ + + # Class property for passing additional telemetry fields to constructor. + _telemetry_dict = {} - # Class property for passing additional telemetry fields to constructor. - _telemetry_dict = {} + def __init__(self, *args, **kwargs): + """Construct a HttpRequest. - def __init__(self, *args, **kwargs): - """Construct a HttpRequest. + Args: + *args: Positional arguments to pass to the base class constructor. + **kwargs: Keyword arguments to pass to the base class constructor. + """ + headers = kwargs.setdefault("headers", {}) - Args: - *args: Positional arguments to pass to the base class constructor. - **kwargs: Keyword arguments to pass to the base class constructor. - """ - headers = kwargs.setdefault("headers", {}) + comments = {} + if get_or_set_consent_status(): + comments = self._telemetry_dict - comments = {} - if get_or_set_consent_status(): - comments = self._telemetry_dict + # Add the local environment to the user agent header comment field. + comments["client_environment"] = get_client_environment_name() - # Add the local environment to the user agent header comment field. - comments["client_environment"] = get_client_environment_name() + # Add whether the explainable AI SDK is installed on the client. + if is_explainable_ai_sdk_installed(): + comments["explainable_ai_sdk"] = True - # construct comment string using comments dict - user_agent_text = f"{_TF_CLOUD_USER_AGENT_HEADER} (" - for key, value in comments.items(): - user_agent_text = f"{user_agent_text}{key}:{value};" - user_agent_text = f"{user_agent_text})" + # construct comment string using comments dict + user_agent_text = f"{_TF_CLOUD_USER_AGENT_HEADER} (" + for key, value in comments.items(): + user_agent_text = f"{user_agent_text}{key}:{value};" + user_agent_text = f"{user_agent_text})" - headers["user-agent"] = user_agent_text - super(TFCloudHttpRequest, self).__init__(*args, **kwargs) + headers["user-agent"] = user_agent_text + super(TFCloudHttpRequest, self).__init__(*args, **kwargs) - # @classmethod @property chain is only supported in python 3.9+, see - # https://docs.python.org/3/howto/descriptor.html#id27. Using class - # getter and setter instead. - @classmethod - def get_telemetry_dict(cls): - telemetry_dict = cls._telemetry_dict.copy() - return telemetry_dict + # @classmethod @property chain is only supported in python 3.9+, see + # https://docs.python.org/3/howto/descriptor.html#id27. Using class + # getter and setter instead. + @classmethod + def get_telemetry_dict(cls): + telemetry_dict = cls._telemetry_dict.copy() + return telemetry_dict - @classmethod - def set_telemetry_dict(cls, telemetry_dict: Dict[Text, Text]): - cls._telemetry_dict = telemetry_dict.copy() + @classmethod + def set_telemetry_dict(cls, telemetry_dict: Dict[Text, Text]): + cls._telemetry_dict = telemetry_dict.copy() # TODO(b/176097105) Use get_client_environment_name in tfc.run and cloud_fit def get_client_environment_name() -> Text: - """Identifies the local environment where tensorflow_cloud is running. + """Identifies the local environment where tensorflow_cloud is running. - Returns: - ClientEnvironment Enum representing the environment type. - """ - if _get_env_variable(_KAGGLE_ENV_VARIABLE): - logging.info("Kaggle client environment detected.") - return ClientEnvironment.KAGGLE_NOTEBOOK.name - - if _is_module_present("google.colab"): - logging.info("Detected running in COLAB environment.") - return ClientEnvironment.COLAB.name + Returns: + ClientEnvironment Enum representing the environment type. + """ + if _get_env_variable(_KAGGLE_ENV_VARIABLE): + logging.info("Kaggle client environment detected.") + return ClientEnvironment.KAGGLE_NOTEBOOK.name - if _get_env_variable(_DL_ENV_PATH_VARIABLE): - # TODO(b/171720710) Update logic based resolution of the issue. - if _get_env_variable("USER") == "jupyter": - logging.info("Detected running in HOSTED_NOTEBOOK environment.") - return ClientEnvironment.HOSTED_NOTEBOOK.name + if _is_module_present("google.colab"): + logging.info("Detected running in COLAB environment.") + return ClientEnvironment.COLAB.name - # TODO(b/175815580) Update logic based resolution of the issue. - logging.info("Detected running in DLVM environment.") - return ClientEnvironment.DLVM.name + if _get_env_variable(_DL_ENV_PATH_VARIABLE): + # TODO(b/171720710) Update logic based resolution of the issue. + if _get_env_variable("USER") == "jupyter": + logging.info("Detected running in HOSTED_NOTEBOOK environment.") + return ClientEnvironment.HOSTED_NOTEBOOK.name # TODO(b/175815580) Update logic based resolution of the issue. - if _is_module_present("google"): - logging.info("Detected running in DL_CONTAINER environment.") - return ClientEnvironment.DL_CONTAINER.name + logging.info("Detected running in DLVM environment.") + return ClientEnvironment.DLVM.name + + # TODO(b/175815580) Update logic based resolution of the issue. + if _is_module_present("google"): + logging.info("Detected running in DL_CONTAINER environment.") + return ClientEnvironment.DL_CONTAINER.name + + logging.info("Detected running in UNKNOWN environment.") + return ClientEnvironment.UNKNOWN.name - logging.info("Detected running in UNKNOWN environment.") - return ClientEnvironment.UNKNOWN.name + +def is_explainable_ai_sdk_installed() -> bool: + """Checks whether explainable AI SDK is installed or not. + + Returns: + True if the module 'explainable_ai_sdk' is installed locally. + """ + return _is_module_present("explainable_ai_sdk") def _is_module_present(module_name: Text) -> bool: - """Checks if module_name is present in sys.modules. + """Checks if module_name is present in sys.modules. - Args: - module_name: Name of the module to look up in the system modules. - Returns: - True if module exists, False otherwise. - """ - return module_name in sys.modules + Args: + module_name: Name of the module to look up in the system modules. + + Returns: + True if module exists, False otherwise. + """ + return module_name in sys.modules def _get_env_variable(variable_name: Text) -> Union[Text, None]: - """Looks up the value of environment varialbe variable_name. + """Looks up the value of environment varialbe variable_name. - Args: - variable_name: Name of the variable to look up in the environment. - Returns: - A string representing the varialbe value or None if varialbe is not - defined in the environment. - """ - return os.getenv(variable_name) + Args: + variable_name: Name of the variable to look up in the environment. + + Returns: + A string representing the varialbe value or None if varialbe is not + defined in the environment. + """ + return os.getenv(variable_name) def get_or_set_consent_status()-> bool: - """Gets or sets the user consent status for telemetry collection. + """Gets or sets the user consent status for telemetry collection. + + Returns: + If the user has rejected client side telemetry collection returns + False, otherwise it returns true, if a consent flag is not found the + user is notified of telemetry collection and a flag is set. + """ + # Verify if user consent exists and if it is valid for current version of + # tensorflow_cloud + if os.path.exists(_LOCAL_CONFIG_PATH): + with open(_LOCAL_CONFIG_PATH) as config_json: + config_data = json.load(config_json) + if config_data.get(_TELEMETRY_REJECTED_CONFIG): + logging.info("User has opt-out of telemetry reporting.") + return False + if config_data.get(_TELEMETRY_VERSION_CONFIG) == version.__version__: + return True - Returns: - If the user has rejected client side telemetry collection returns - False, otherwise it returns true, if a consent flag is not found the - user is notified of telemetry collection and a flag is set. - """ - # Verify if user consent exists and if it is valid for current version of - # tensorflow_cloud - if os.path.exists(_LOCAL_CONFIG_PATH): - with open(_LOCAL_CONFIG_PATH) as config_json: - config_data = json.load(config_json) - if config_data.get(_TELEMETRY_REJECTED_CONFIG): - logging.info("User has opt-out of telemetry reporting.") - return False - if config_data.get( - _TELEMETRY_VERSION_CONFIG) == version.__version__: - return True - - # Either user has not been notified of telemetry collection or a different - # version of the tensorflow_cloud has been installed since the last - # notification. Notify the user and update the configuration. - logging.info(_PRIVACY_NOTICE) - print(_PRIVACY_NOTICE) - - config_data = {} - config_data[_TELEMETRY_VERSION_CONFIG] = version.__version__ - - # Create the config path if it does not already exist - os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True) - - with open(_LOCAL_CONFIG_PATH, "w") as config_json: - json.dump(config_data, config_json) - return True + # Either user has not been notified of telemetry collection or a different + # version of the tensorflow_cloud has been installed since the last + # notification. Notify the user and update the configuration. + logging.info(_PRIVACY_NOTICE) + print(_PRIVACY_NOTICE) + config_data = {} + config_data[_TELEMETRY_VERSION_CONFIG] = version.__version__ -def optout_metrics_reporting(): - """Set configuration to opt-out of client side metric reporting.""" + # Create the config path if it does not already exist + os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True) - config_data = {} - config_data["telemetry_rejected"] = True + with open(_LOCAL_CONFIG_PATH, "w") as config_json: + json.dump(config_data, config_json) + return True - # Create the config path if it does not already exist - os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True) - with open(_LOCAL_CONFIG_PATH, "w") as config_json: - json.dump(config_data, config_json) +def optout_metrics_reporting(): + """Set configuration to opt-out of client side metric reporting.""" - logging.info("Client side metrics reporting has been disabled.") + config_data = {} + config_data["telemetry_rejected"] = True + # Create the config path if it does not already exist + os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True) -def wait_for_aip_training_job_completion(job_id: Text, project_id: Text)->bool: - """Blocks until the AIP Training job is completed and returns the status. + with open(_LOCAL_CONFIG_PATH, "w") as config_json: + json.dump(config_data, config_json) - Args: - job_id: ID for AIP training job. - project_id: Project under which the AIP Training job is running. - Returns: - True if the job succeeded or it was cancelled, False if the job failed. - """ - # Wait for AIP Training job to finish - job_name = "projects/{}/jobs/{}".format(project_id, job_id) - # Disable cache_discovery to remove excessive info logs see: - # https://github.com/googleapis/google-api-python-client/issues/299 - api_client = discovery.build("ml", "v1", cache_discovery=False) + logging.info("Client side metrics reporting has been disabled.") - request = api_client.projects().jobs().get(name=job_name) +def wait_for_aip_training_job_completion(job_id: Text, project_id: Text)->bool: + """Blocks until the AIP Training job is completed and returns the status. + + Args: + job_id: ID for AIP training job. + project_id: Project under which the AIP Training job is running. + + Returns: + True if the job succeeded or it was cancelled, False if the job failed. + """ + # Wait for AIP Training job to finish + job_name = "projects/{}/jobs/{}".format(project_id, job_id) + # Disable cache_discovery to remove excessive info logs see: + # https://github.com/googleapis/google-api-python-client/issues/299 + api_client = discovery.build("ml", "v1", cache_discovery=False) + + request = api_client.projects().jobs().get(name=job_name) + + response = request.execute() + + counter = 0 + logging.info("Waiting for job to complete, polling status every %s sec.", + _POLL_INTERVAL_IN_SECONDS) + while response["state"] not in ("SUCCEEDED", "FAILED", "CANCELLED"): + logging.info("Attempt number %s to retrieve job status.", counter) + counter += 1 + time.sleep(_POLL_INTERVAL_IN_SECONDS) response = request.execute() - counter = 0 - logging.info( - "Waiting for job to complete, polling status every %s sec.", - _POLL_INTERVAL_IN_SECONDS) - while response["state"] not in ("SUCCEEDED", "FAILED", "CANCELLED"): - logging.info("Attempt number %s to retrieve job status.", counter) - counter += 1 - time.sleep(_POLL_INTERVAL_IN_SECONDS) - response = request.execute() - - if response["state"] == "FAILED": - logging.error("AIP Training job %s failed with error %s.", - job_id, response["errorMessage"]) - return False + if response["state"] == "FAILED": + logging.error("AIP Training job %s failed with error %s.", job_id, + response["errorMessage"]) + return False - # Both CANCELLED and SUCCEEDED status count as successful completion. - logging.info("AIP Training job %s completed with status %s.", - job_id, response["state"]) - return True + # Both CANCELLED and SUCCEEDED status count as successful completion. + logging.info("AIP Training job %s completed with status %s.", job_id, + response["state"]) + return True def is_aip_training_job_running(job_id: Text, project_id: Text)->bool: - """Non-blocking call that checks if AIP Training job is running. + """Non-blocking call that checks if AIP Training job is running. - Args: - job_id: ID for AIP training job. - project_id: Project under which the AIP Training job is running. - Returns: - True if the job is running, False if it has succeeded, failed, or it was - cancelled. - """ - job_name = "projects/{}/jobs/{}".format(project_id, job_id) - # Disable cache_discovery to remove excessive info logs see: - # https://github.com/googleapis/google-api-python-client/issues/299 - api_client = discovery.build("ml", "v1", cache_discovery=False) + Args: + job_id: ID for AIP training job. + project_id: Project under which the AIP Training job is running. - logging.info("Retrieving status for job %s.", job_name) + Returns: + True if the job is running, False if it has succeeded, failed, or it was + cancelled. + """ + job_name = "projects/{}/jobs/{}".format(project_id, job_id) + # Disable cache_discovery to remove excessive info logs see: + # https://github.com/googleapis/google-api-python-client/issues/299 + api_client = discovery.build("ml", "v1", cache_discovery=False) - request = api_client.projects().jobs().get(name=job_name) - response = request.execute() + logging.info("Retrieving status for job %s.", job_name) - return response["state"] not in ("SUCCEEDED", "FAILED", "CANCELLED") + request = api_client.projects().jobs().get(name=job_name) + response = request.execute() + return response["state"] not in ("SUCCEEDED", "FAILED", "CANCELLED") -def stop_aip_training_job(job_id: Text, project_id: Text): - """Cancels a running AIP Training job. - Args: - job_id: ID for AIP training job. - project_id: Project under which the AIP Training job is running. - """ - job_name = "projects/{}/jobs/{}".format(project_id, job_id) - # Disable cache_discovery to remove excessive info logs see: - # https://github.com/googleapis/google-api-python-client/issues/299 - api_client = discovery.build("ml", "v1", cache_discovery=False) - - logging.info("Canceling the job %s.", job_name) - - request = api_client.projects().jobs().cancel(name=job_name) - - try: - request.execute() - except errors.HttpError as e: - if e.resp.status == 400: - logging.info( - # If job is already completed, the request will result in a 400 - # error with similar to 'description': 'Cannot cancel an already - # completed job.' In this case we will absorb the error. - "Job %s has already completed.", job_id) - return - logging.error("Cancel Request for job %s failed.", job_name) - raise e +def stop_aip_training_job(job_id: Text, project_id: Text): + """Cancels a running AIP Training job. + + Args: + job_id: ID for AIP training job. + project_id: Project under which the AIP Training job is running. + """ + job_name = "projects/{}/jobs/{}".format(project_id, job_id) + # Disable cache_discovery to remove excessive info logs see: + # https://github.com/googleapis/google-api-python-client/issues/299 + api_client = discovery.build("ml", "v1", cache_discovery=False) + + logging.info("Canceling the job %s.", job_name) + + request = api_client.projects().jobs().cancel(name=job_name) + + try: + request.execute() + except errors.HttpError as e: + if e.resp.status == 400: + logging.info( + # If job is already completed, the request will result in a 400 + # error with similar to 'description': 'Cannot cancel an already + # completed job.' In this case we will absorb the error. + "Job %s has already completed.", + job_id) + return + logging.error("Cancel Request for job %s failed.", job_name) + raise e diff --git a/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py b/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py index 3d8e23c9..adbb7cae 100644 --- a/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py +++ b/src/python/tensorflow_cloud/utils/tests/unit/google_api_client_test.py @@ -28,318 +28,341 @@ class GoogleApiClientTest(tf.test.TestCase): - def setUp(self): - super(GoogleApiClientTest, self).setUp() - self.addCleanup(mock.patch.stopall) - - # Setting wait time to 1 sec to speed up the tests execution. - google_api_client._POLL_INTERVAL_IN_SECONDS = 1 - self._project_id = "project-a" - self._job_id = "job_id" - - self.mock_discovery_build = mock.patch.object( - discovery, "build", autospec=True - ).start() - self.mock_apiclient = mock.Mock() - self.mock_discovery_build.return_value = self.mock_apiclient - self.mock_request = mock.Mock() - self.mock_apiclient.projects().jobs( - ).get.return_value = self.mock_request - self.mock_apiclient.projects().jobs( - ).cancel.return_value = self.mock_request - self._local_config_path = os.path.join( - self.get_temp_dir(), "config.json") - google_api_client._LOCAL_CONFIG_PATH = self._local_config_path - - # TODO(b/177023448) Remove mock on logging.error here and below. - def test_wait_for_aip_training_job_completion_non_blocking_success(self): - self.mock_request.execute.return_value = { - "state": "SUCCEEDED", - } - status = google_api_client.wait_for_aip_training_job_completion( - self._job_id, self._project_id) - self.assertTrue(status) - self.mock_request.execute.assert_called_once() - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().get.assert_called_with( - name=job_name) - - def test_wait_for_aip_training_job_completion_non_blocking_cancelled(self): - self.mock_request.execute.return_value = { - "state": "CANCELLED", - } - status = google_api_client.wait_for_aip_training_job_completion( - self._job_id, self._project_id) - self.assertTrue(status) - self.mock_request.execute.assert_called_once() - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().get.assert_called_with( - name=job_name) - - def test_wait_for_aip_training_job_completion_non_blocking_failed(self): - self.mock_request.execute.return_value = { - "state": "FAILED", "errorMessage": "test_error_message"} - status = google_api_client.wait_for_aip_training_job_completion( - self._job_id, self._project_id) - self.assertFalse(status) - self.mock_request.execute.assert_called_once() - - def test_wait_for_aip_training_job_completion_multiple_checks_success(self): - self.mock_request.execute.side_effect = [ - {"state": "PREPARING"}, - {"state": "RUNNING"}, - {"state": "SUCCEEDED"} - ] - status = google_api_client.wait_for_aip_training_job_completion( - self._job_id, self._project_id) - self.assertTrue(status) - self.assertEqual(3, self.mock_request.execute.call_count) - - def test_wait_for_aip_training_job_completion_multiple_checks_failed(self): - self.mock_request.execute.side_effect = [ - {"state": "PREPARING"}, - {"state": "RUNNING"}, - {"state": "FAILED", "errorMessage": "test_error_message"}] - status = google_api_client.wait_for_aip_training_job_completion( - self._job_id, self._project_id) - self.assertFalse(status) - self.assertEqual(3, self.mock_request.execute.call_count) - - def test_is_aip_training_job_running_with_completed_job(self): - self.mock_request.execute.side_effect = [ - {"state": "SUCCEEDED"}, - {"state": "CANCELLED"}, - {"state": "FAILED", "errorMessage": "test_error_message"}] - succeeded_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertFalse(succeeded_status) - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().get.assert_called_with( - name=job_name) - cancelled_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertFalse(cancelled_status) - failed_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertFalse(failed_status) - self.assertEqual(3, self.mock_request.execute.call_count) - - def test_is_aip_training_job_running_with_running_job(self): - self.mock_request.execute.side_effect = [ - {"state": "QUEUED"}, - {"state": "PREPARING"}, - {"state": "RUNNING"}, - {"state": "CANCELLING"}] - queued_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertTrue(queued_status) - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().get.assert_called_with( - name=job_name) - preparing_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertTrue(preparing_status) - running_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertTrue(running_status) - canceling_status = google_api_client.is_aip_training_job_running( - self._job_id, self._project_id) - self.assertTrue(canceling_status) - self.assertEqual(4, self.mock_request.execute.call_count) - - def test_stop_aip_training_job_with_running_job(self): - self.mock_request.execute.return_value = {} - google_api_client.stop_aip_training_job(self._job_id, self._project_id) - - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().cancel.assert_called_with( - name=job_name) - - def test_stop_aip_training_job_with_completed_job(self): - self.mock_request.execute.side_effect = errors.HttpError( - httplib2.Response(info={"status": 400}), b"" - ) - google_api_client.stop_aip_training_job(self._job_id, self._project_id) - - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - self.mock_apiclient.projects().jobs().cancel.assert_called_with( - name=job_name) - - def test_stop_aip_training_job_with_failing_request(self): - self.mock_request.execute.side_effect = errors.HttpError( - httplib2.Response(info={"status": 404}), b"") - - job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) - with self.assertRaises(errors.HttpError): - google_api_client.stop_aip_training_job( - self._job_id, self._project_id) - self.mock_apiclient.projects().jobs().cancel.assert_called_with( - name=job_name) - - def test_get_client_environment_name_with_kaggle(self): - os.environ["KAGGLE_CONTAINER_NAME"] = "test_container_name" - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.KAGGLE_NOTEBOOK.name) - - def test_get_client_environment_name_with_hosted_notebook(self): - os.environ["DL_PATH"] = "test_dl_path" - os.environ["USER"] = "jupyter" - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.HOSTED_NOTEBOOK.name) - - def test_get_client_environment_name_with_hosted_dlvm(self): - os.environ["DL_PATH"] = "test_dl_path" - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.DLVM.name) - - @mock.patch.object(google_api_client, "_is_module_present", autospec=True) - @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) - def test_get_client_environment_name_with_hosted_unknown( - self, mock_getenv, mock_modules): - mock_getenv.return_value = None - mock_modules.return_value = {} - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.UNKNOWN.name) - - @mock.patch.object(google_api_client, "_is_module_present", autospec=True) - @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) - def test_get_client_environment_name_with_hosted_colab( - self, mock_getenv, mock_modules): - mock_getenv.return_value = None - mock_modules.return_value = True - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.COLAB.name) - - @mock.patch.object(google_api_client, "_is_module_present", autospec=True) - @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) - def test_get_client_environment_name_with_hosted_dl_container( - self, mock_getenv, mock_modules): - mock_getenv.return_value = None - mock_modules.side_effect = [False, True] - self.assertEqual( - google_api_client.get_client_environment_name(), - google_api_client.ClientEnvironment.DL_CONTAINER.name) - - def test_get_or_set_consent_status_rejected(self): - config_data = {} - config_data["telemetry_rejected"] = True - - # Create the config path if it does not already exist - os.makedirs(os.path.dirname(self._local_config_path), exist_ok=True) - - with open(self._local_config_path, "w") as config_json: - json.dump(config_data, config_json) - - self.assertFalse(google_api_client.get_or_set_consent_status()) - - def test_get_or_set_consent_status_verified(self): - config_data = {} - config_data["notification_version"] = version.__version__ - - # Create the config path if it does not already exist - os.makedirs(os.path.dirname(self._local_config_path), exist_ok=True) - - with open(self._local_config_path, "w") as config_json: - json.dump(config_data, config_json) - - self.assertTrue(google_api_client.get_or_set_consent_status()) - - def test_get_or_set_consent_status_notify_user(self): - if os.path.exists(self._local_config_path): - os.remove(self._local_config_path) - - self.assertTrue(google_api_client.get_or_set_consent_status()) - - with open(self._local_config_path) as config_json: - config_data = json.load(config_json) - self.assertDictContainsSubset( - config_data, {"notification_version": version.__version__}) - - @mock.patch.object(google_api_client, - "get_or_set_consent_status", autospec=True) - def test_TFCloudHttpRequest_with_rejected_consent( - self, mock_consent_status): - - mock_consent_status.return_value = False - http_request = google_api_client.TFCloudHttpRequest( - googleapiclient_http.HttpMockSequence([({"status": "200"}, "{}")]), - object(), - "fake_uri", - ) - self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) - self.assertIn("user-agent", http_request.headers) - self.assertDictEqual( - {"user-agent": f"tf-cloud/{version.__version__} ()"}, - http_request.headers) - - @mock.patch.object(google_api_client, - "get_or_set_consent_status", autospec=True) - @mock.patch.object(google_api_client, - "get_client_environment_name", autospec=True) - def test_TFCloudHttpRequest_with_consent( - self, mock_get_env_name, mock_consent_status): - - mock_consent_status.return_value = True - mock_get_env_name.return_value = "TEST_ENV" - google_api_client.TFCloudHttpRequest.set_telemetry_dict({}) - http_request = google_api_client.TFCloudHttpRequest( - googleapiclient_http.HttpMockSequence([({"status": "200"}, "{}")]), - object(), - "fake_uri", - ) - self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) - self.assertIn("user-agent", http_request.headers) - - header_comment = "client_environment:TEST_ENV;" - full_header = f"tf-cloud/{version.__version__} ({header_comment})" - - self.assertDictEqual({"user-agent": full_header}, http_request.headers) - - @mock.patch.object(google_api_client, - "get_or_set_consent_status", autospec=True) - @mock.patch.object(google_api_client, - "get_client_environment_name", autospec=True) - def test_TFCloudHttpRequest_with_additional_metrics( - self, mock_get_env_name, mock_consent_status): - - google_api_client.TFCloudHttpRequest.set_telemetry_dict( - {"TEST_KEY1": "TEST_VALUE1"}) - mock_consent_status.return_value = True - mock_get_env_name.return_value = "TEST_ENV" - http_request = google_api_client.TFCloudHttpRequest( - googleapiclient_http.HttpMockSequence([({"status": "200"}, "{}")]), - object(), - "fake_uri", - ) - self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) - self.assertIn("user-agent", http_request.headers) - - header_comment = "TEST_KEY1:TEST_VALUE1;client_environment:TEST_ENV;" - full_header = f"tf-cloud/{version.__version__} ({header_comment})" - - self.assertDictEqual({"user-agent": full_header}, http_request.headers) - - # Verify when telemetry dict is refreshed it is used in new http request - google_api_client.TFCloudHttpRequest.set_telemetry_dict( - {"TEST_KEY2": "TEST_VALUE2"}) - mock_consent_status.return_value = True - mock_get_env_name.return_value = "TEST_ENV" - http_request = google_api_client.TFCloudHttpRequest( - googleapiclient_http.HttpMockSequence([({"status": "200"}, "{}")]), - object(), - "fake_uri", - ) - - header_comment = "TEST_KEY2:TEST_VALUE2;client_environment:TEST_ENV;" - full_header = f"tf-cloud/{version.__version__} ({header_comment})" - - self.assertDictEqual({"user-agent": full_header}, http_request.headers) + def setUp(self): + super(GoogleApiClientTest, self).setUp() + self.addCleanup(mock.patch.stopall) + + # Setting wait time to 1 sec to speed up the tests execution. + google_api_client._POLL_INTERVAL_IN_SECONDS = 1 + self._project_id = "project-a" + self._job_id = "job_id" + + self.mock_discovery_build = mock.patch.object( + discovery, "build", autospec=True).start() + self.mock_apiclient = mock.Mock() + self.mock_discovery_build.return_value = self.mock_apiclient + self.mock_request = mock.Mock() + self.mock_apiclient.projects().jobs().get.return_value = self.mock_request + self.mock_apiclient.projects().jobs( + ).cancel.return_value = self.mock_request + self._local_config_path = os.path.join(self.get_temp_dir(), "config.json") + google_api_client._LOCAL_CONFIG_PATH = self._local_config_path + + # TODO(b/177023448) Remove mock on logging.error here and below. + def test_wait_for_aip_training_job_completion_non_blocking_success(self): + self.mock_request.execute.return_value = { + "state": "SUCCEEDED", + } + status = google_api_client.wait_for_aip_training_job_completion( + self._job_id, self._project_id) + self.assertTrue(status) + self.mock_request.execute.assert_called_once() + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().get.assert_called_with(name=job_name) + + def test_wait_for_aip_training_job_completion_non_blocking_cancelled(self): + self.mock_request.execute.return_value = { + "state": "CANCELLED", + } + status = google_api_client.wait_for_aip_training_job_completion( + self._job_id, self._project_id) + self.assertTrue(status) + self.mock_request.execute.assert_called_once() + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().get.assert_called_with(name=job_name) + + def test_wait_for_aip_training_job_completion_non_blocking_failed(self): + self.mock_request.execute.return_value = { + "state": "FAILED", + "errorMessage": "test_error_message" + } + status = google_api_client.wait_for_aip_training_job_completion( + self._job_id, self._project_id) + self.assertFalse(status) + self.mock_request.execute.assert_called_once() + + def test_wait_for_aip_training_job_completion_multiple_checks_success(self): + self.mock_request.execute.side_effect = [{ + "state": "PREPARING" + }, { + "state": "RUNNING" + }, { + "state": "SUCCEEDED" + }] + status = google_api_client.wait_for_aip_training_job_completion( + self._job_id, self._project_id) + self.assertTrue(status) + self.assertEqual(3, self.mock_request.execute.call_count) + + def test_wait_for_aip_training_job_completion_multiple_checks_failed(self): + self.mock_request.execute.side_effect = [{ + "state": "PREPARING" + }, { + "state": "RUNNING" + }, { + "state": "FAILED", + "errorMessage": "test_error_message" + }] + status = google_api_client.wait_for_aip_training_job_completion( + self._job_id, self._project_id) + self.assertFalse(status) + self.assertEqual(3, self.mock_request.execute.call_count) + + def test_is_aip_training_job_running_with_completed_job(self): + self.mock_request.execute.side_effect = [{ + "state": "SUCCEEDED" + }, { + "state": "CANCELLED" + }, { + "state": "FAILED", + "errorMessage": "test_error_message" + }] + succeeded_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertFalse(succeeded_status) + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().get.assert_called_with(name=job_name) + cancelled_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertFalse(cancelled_status) + failed_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertFalse(failed_status) + self.assertEqual(3, self.mock_request.execute.call_count) + + def test_is_aip_training_job_running_with_running_job(self): + self.mock_request.execute.side_effect = [{ + "state": "QUEUED" + }, { + "state": "PREPARING" + }, { + "state": "RUNNING" + }, { + "state": "CANCELLING" + }] + queued_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertTrue(queued_status) + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().get.assert_called_with(name=job_name) + preparing_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertTrue(preparing_status) + running_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertTrue(running_status) + canceling_status = google_api_client.is_aip_training_job_running( + self._job_id, self._project_id) + self.assertTrue(canceling_status) + self.assertEqual(4, self.mock_request.execute.call_count) + + def test_stop_aip_training_job_with_running_job(self): + self.mock_request.execute.return_value = {} + google_api_client.stop_aip_training_job(self._job_id, self._project_id) + + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().cancel.assert_called_with( + name=job_name) + + def test_stop_aip_training_job_with_completed_job(self): + self.mock_request.execute.side_effect = errors.HttpError( + httplib2.Response(info={"status": 400}), b"") + google_api_client.stop_aip_training_job(self._job_id, self._project_id) + + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + self.mock_apiclient.projects().jobs().cancel.assert_called_with( + name=job_name) + + def test_stop_aip_training_job_with_failing_request(self): + self.mock_request.execute.side_effect = errors.HttpError( + httplib2.Response(info={"status": 404}), b"") + + job_name = "projects/{}/jobs/{}".format(self._project_id, self._job_id) + with self.assertRaises(errors.HttpError): + google_api_client.stop_aip_training_job(self._job_id, self._project_id) + self.mock_apiclient.projects().jobs().cancel.assert_called_with( + name=job_name) + + def test_get_client_environment_name_with_kaggle(self): + os.environ["KAGGLE_CONTAINER_NAME"] = "test_container_name" + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.KAGGLE_NOTEBOOK.name) + + def test_get_client_environment_name_with_hosted_notebook(self): + os.environ["DL_PATH"] = "test_dl_path" + os.environ["USER"] = "jupyter" + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.HOSTED_NOTEBOOK.name) + + def test_get_client_environment_name_with_hosted_dlvm(self): + os.environ["DL_PATH"] = "test_dl_path" + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.DLVM.name) + + @mock.patch.object(google_api_client, "_is_module_present", autospec=True) + @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) + def test_get_client_environment_name_with_hosted_unknown( + self, mock_getenv, mock_modules): + mock_getenv.return_value = None + mock_modules.return_value = {} + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.UNKNOWN.name) + + @mock.patch.object(google_api_client, "_is_module_present", autospec=True) + @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) + def test_get_client_environment_name_with_hosted_colab( + self, mock_getenv, mock_modules): + mock_getenv.return_value = None + mock_modules.return_value = True + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.COLAB.name) + + @mock.patch.object(google_api_client, "_is_module_present", autospec=True) + @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) + def test_get_client_environment_name_with_hosted_dl_container( + self, mock_getenv, mock_modules): + mock_getenv.return_value = None + mock_modules.side_effect = [False, True] + self.assertEqual(google_api_client.get_client_environment_name(), + google_api_client.ClientEnvironment.DL_CONTAINER.name) + + def test_get_or_set_consent_status_rejected(self): + config_data = {} + config_data["telemetry_rejected"] = True + + # Create the config path if it does not already exist + os.makedirs(os.path.dirname(self._local_config_path), exist_ok=True) + + with open(self._local_config_path, "w") as config_json: + json.dump(config_data, config_json) + + self.assertFalse(google_api_client.get_or_set_consent_status()) + + def test_get_or_set_consent_status_verified(self): + config_data = {} + config_data["notification_version"] = version.__version__ + + # Create the config path if it does not already exist + os.makedirs(os.path.dirname(self._local_config_path), exist_ok=True) + + with open(self._local_config_path, "w") as config_json: + json.dump(config_data, config_json) + + self.assertTrue(google_api_client.get_or_set_consent_status()) + + def test_get_or_set_consent_status_notify_user(self): + if os.path.exists(self._local_config_path): + os.remove(self._local_config_path) + + self.assertTrue(google_api_client.get_or_set_consent_status()) + + with open(self._local_config_path) as config_json: + config_data = json.load(config_json) + self.assertDictContainsSubset( + config_data, {"notification_version": version.__version__}) + + @mock.patch.object( + google_api_client, "get_or_set_consent_status", autospec=True) + def test_TFCloudHttpRequest_with_rejected_consent(self, mock_consent_status): + + mock_consent_status.return_value = False + http_request = google_api_client.TFCloudHttpRequest( + googleapiclient_http.HttpMockSequence([({ + "status": "200" + }, "{}")]), + object(), + "fake_uri", + ) + self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) + self.assertIn("user-agent", http_request.headers) + self.assertDictEqual({"user-agent": f"tf-cloud/{version.__version__} ()"}, + http_request.headers) + + @mock.patch.object( + google_api_client, "get_or_set_consent_status", autospec=True) + @mock.patch.object( + google_api_client, "get_client_environment_name", autospec=True) + def test_TFCloudHttpRequest_with_consent(self, mock_get_env_name, + mock_consent_status): + + mock_consent_status.return_value = True + mock_get_env_name.return_value = "TEST_ENV" + google_api_client.TFCloudHttpRequest.set_telemetry_dict({}) + http_request = google_api_client.TFCloudHttpRequest( + googleapiclient_http.HttpMockSequence([({ + "status": "200" + }, "{}")]), + object(), + "fake_uri", + ) + self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) + self.assertIn("user-agent", http_request.headers) + + header_comment = "client_environment:TEST_ENV;" + full_header = f"tf-cloud/{version.__version__} ({header_comment})" + + self.assertDictEqual({"user-agent": full_header}, http_request.headers) + + @mock.patch.object( + google_api_client, "get_or_set_consent_status", autospec=True) + @mock.patch.object( + google_api_client, "get_client_environment_name", autospec=True) + def test_TFCloudHttpRequest_with_additional_metrics(self, mock_get_env_name, + mock_consent_status): + + google_api_client.TFCloudHttpRequest.set_telemetry_dict( + {"TEST_KEY1": "TEST_VALUE1"}) + mock_consent_status.return_value = True + mock_get_env_name.return_value = "TEST_ENV" + http_request = google_api_client.TFCloudHttpRequest( + googleapiclient_http.HttpMockSequence([({ + "status": "200" + }, "{}")]), + object(), + "fake_uri", + ) + self.assertIsInstance(http_request, googleapiclient_http.HttpRequest) + self.assertIn("user-agent", http_request.headers) + + header_comment = "TEST_KEY1:TEST_VALUE1;client_environment:TEST_ENV;" + full_header = f"tf-cloud/{version.__version__} ({header_comment})" + + self.assertDictEqual({"user-agent": full_header}, http_request.headers) + + # Verify when telemetry dict is refreshed it is used in new http request + google_api_client.TFCloudHttpRequest.set_telemetry_dict( + {"TEST_KEY2": "TEST_VALUE2"}) + mock_consent_status.return_value = True + mock_get_env_name.return_value = "TEST_ENV" + http_request = google_api_client.TFCloudHttpRequest( + googleapiclient_http.HttpMockSequence([({ + "status": "200" + }, "{}")]), + object(), + "fake_uri", + ) + + header_comment = "TEST_KEY2:TEST_VALUE2;client_environment:TEST_ENV;" + full_header = f"tf-cloud/{version.__version__} ({header_comment})" + + self.assertDictEqual({"user-agent": full_header}, http_request.headers) + + @mock.patch.object(google_api_client, "_is_module_present", autospec=True) + @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) + def test_is_explainable_ai_sdk_installed_false(self, mock_getenv, + mock_modules): + mock_getenv.return_value = None + mock_modules.return_value = False + self.assertEqual(google_api_client.is_explainable_ai_sdk_installed(), False) + + @mock.patch.object(google_api_client, "_is_module_present", autospec=True) + @mock.patch.object(google_api_client, "_get_env_variable", autospec=True) + def test_is_explainable_ai_sdk_installed_true(self, mock_getenv, + mock_modules): + mock_getenv.return_value = None + mock_modules.return_value = True + self.assertEqual(google_api_client.is_explainable_ai_sdk_installed(), True) if __name__ == "__main__": - tf.test.main() + tf.test.main()