diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 383675283b..071481ab5b 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -21,6 +21,7 @@ ## New additions * Added `--retain-comments` option to `snow sql` command to allow passing comments to Snowflake. * Added `--replace` and `--if-not-exists` options to `snow object create` command. +* Added support for event sharing, which now can be specified under the `telemetry` section of an application entity. Two fields are supported: `authorize_event_sharing` and `optional_shared_events`. * `snow stage copy` supports `--recursive` flag to copy local files and subdirectories recursively to stage. Including glob support. diff --git a/src/snowflake/cli/_plugins/connection/util.py b/src/snowflake/cli/_plugins/connection/util.py index 9def61261c..2cd76acffb 100644 --- a/src/snowflake/cli/_plugins/connection/util.py +++ b/src/snowflake/cli/_plugins/connection/util.py @@ -17,7 +17,9 @@ import json import logging import os -from typing import Optional +from enum import Enum +from functools import lru_cache +from typing import Any, Dict, Optional from click.exceptions import ClickException from snowflake.connector import SnowflakeConnection @@ -25,12 +27,6 @@ log = logging.getLogger(__name__) -REGIONLESS_QUERY = """ - select value['value'] as REGIONLESS from table(flatten( - input => parse_json(SYSTEM$BOOTSTRAP_DATA_REQUEST()), - path => 'clientParamsInfo' - )) where value['name'] = 'UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT'; -""" ALLOWLIST_QUERY = "SELECT SYSTEM$ALLOWLIST()" SNOWFLAKE_DEPLOYMENT = "SNOWFLAKE_DEPLOYMENT" @@ -54,6 +50,53 @@ def __init__(self, host: str | None): ) +class UIParameter(Enum): + ENABLE_REGIONLESS_REDIRECT = "UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT" + EVENT_SHARING_V2 = "ENABLE_EVENT_SHARING_V2_IN_THE_SAME_ACCOUNT" + ENFORCE_MANDATORY_FILTERS = ( + "ENFORCE_MANDATORY_FILTERS_FOR_SAME_ACCOUNT_INSTALLATION" + ) + + +def get_ui_parameter( + conn: SnowflakeConnection, parameter: UIParameter, default: Any +) -> str: + """ + Returns the value of a single UI parameter. + If the parameter is not found, the default value is returned. + """ + + if not isinstance(parameter, UIParameter): + raise ValueError("Parameter must be a UIParameters enum") + + ui_parameters = get_ui_parameters(conn) + return ui_parameters.get(parameter, default) + + +@lru_cache(maxsize=None) +def get_ui_parameters(conn: SnowflakeConnection) -> Dict[UIParameter, Any]: + """ + Returns the UI parameters from the SYSTEM$BOOTSTRAP_DATA_REQUEST function + """ + + parameters_to_fetch = sorted( + [param.value for param in UIParameter.__members__.values()] + ) + + query = f""" + select value['value']::string as PARAM_VALUE, value['name']::string as PARAM_NAME from table(flatten( + input => parse_json(SYSTEM$BOOTSTRAP_DATA_REQUEST()), + path => 'clientParamsInfo' + )) where value['name'] in ('{"', '".join(parameters_to_fetch)}'); + """ + + *_, cursor = conn.execute_string(query, cursor_class=DictCursor) + + return { + UIParameter(row["PARAM_NAME"]): row["PARAM_VALUE"] for row in cursor.fetchall() + } + + def is_regionless_redirect(conn: SnowflakeConnection) -> bool: """ Determines if the deployment this connection refers to uses @@ -62,8 +105,12 @@ def is_regionless_redirect(conn: SnowflakeConnection) -> bool: assume it's regionless, as this is true for most production deployments. """ try: - *_, cursor = conn.execute_string(REGIONLESS_QUERY, cursor_class=DictCursor) - return cursor.fetchone()["REGIONLESS"].lower() == "true" + return ( + get_ui_parameter( + conn, UIParameter.ENABLE_REGIONLESS_REDIRECT, "true" + ).lower() + == "true" + ) except: log.warning( "Cannot determine regionless redirect; assuming True.", exc_info=True diff --git a/src/snowflake/cli/_plugins/nativeapp/artifacts.py b/src/snowflake/cli/_plugins/nativeapp/artifacts.py index d880c75d9f..c23e87893b 100644 --- a/src/snowflake/cli/_plugins/nativeapp/artifacts.py +++ b/src/snowflake/cli/_plugins/nativeapp/artifacts.py @@ -164,7 +164,7 @@ def put(self, src: Path, dest: Path, dest_is_dir: bool) -> None: if src_is_dir: # mark all subdirectories of this source as directories so that we can # detect accidental clobbering - for (root, _, files) in os.walk(absolute_src, followlinks=True): + for root, _, files in os.walk(absolute_src, followlinks=True): canonical_subdir = Path(root).relative_to(absolute_src) canonical_dest_subdir = dest / canonical_subdir self._update_dest_is_dir(canonical_dest_subdir, is_dir=True) @@ -383,7 +383,7 @@ def _expand_artifact_mapping( if absolute_src.is_dir() and expand_directories: # both src and dest are directories, and expanding directories was requested. Traverse src, and map each # file to the dest directory - for (root, subdirs, files) in os.walk(absolute_src, followlinks=True): + for root, subdirs, files in os.walk(absolute_src, followlinks=True): relative_root = Path(root).relative_to(absolute_src) for name in itertools.chain(subdirs, files): src_file_for_output = src_for_output / relative_root / name @@ -683,7 +683,7 @@ def build_bundle( "No artifacts mapping found in project definition, nothing to do." ) - for (absolute_src, absolute_dest) in bundle_map.all_mappings( + for absolute_src, absolute_dest in bundle_map.all_mappings( absolute=True, expand_directories=False ): symlink_or_copy(absolute_src, absolute_dest, deploy_root=deploy_root) @@ -765,3 +765,34 @@ def find_version_info_in_manifest_file( patch_number = int(version_info[patch_field]) return version_name, patch_number + + +def find_mandatory_events_in_manifest_file( + deploy_root: Path, +) -> List[str]: + """ + Find mandatory events, if available, in the manifest.yml file. + Mandatory events can be found under this section in the manifest.yml file: + + configuration: + telemetry_event_definitions: + - type: ERRORS_AND_WARNINGS + sharing: MANDATORY + - type: DEBUG_LOGS + sharing: OPTIONAL + """ + manifest_content = find_and_read_manifest_file(deploy_root=deploy_root) + + mandatory_events: List[str] = [] + + configuration_section = manifest_content.get("configuration", None) + if configuration_section: + telemetry_section = configuration_section.get( + "telemetry_event_definitions", None + ) + if telemetry_section: + for event in telemetry_section: + if event.get("sharing", None) == "MANDATORY": + mandatory_events.append(event["type"]) + + return mandatory_events diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/application.py b/src/snowflake/cli/_plugins/nativeapp/entities/application.py index 19c17382c4..1b4cc3754b 100644 --- a/src/snowflake/cli/_plugins/nativeapp/entities/application.py +++ b/src/snowflake/cli/_plugins/nativeapp/entities/application.py @@ -10,7 +10,14 @@ import typer from click import ClickException, UsageError from pydantic import Field, field_validator -from snowflake.cli._plugins.connection.util import make_snowsight_url +from snowflake.cli._plugins.connection.util import ( + UIParameter, + get_ui_parameter, + make_snowsight_url, +) +from snowflake.cli._plugins.nativeapp.artifacts import ( + find_mandatory_events_in_manifest_file, +) from snowflake.cli._plugins.nativeapp.common_flags import ( ForceOption, InteractiveOption, @@ -27,6 +34,9 @@ ApplicationPackageEntity, ApplicationPackageEntityModel, ) +from snowflake.cli._plugins.nativeapp.entities.models.event_sharing_telemetry import ( + EventSharingTelemetry, +) from snowflake.cli._plugins.nativeapp.exceptions import ( ApplicationPackageDoesNotExistError, NoEventTableForAccount, @@ -73,7 +83,7 @@ to_identifier, unquote_identifier, ) -from snowflake.connector import DictCursor, ProgrammingError +from snowflake.connector import DictCursor, ProgrammingError, SnowflakeConnection # Reasons why an `alter application ... upgrade` might fail UPGRADE_RESTRICTION_CODES = { @@ -97,6 +107,10 @@ class ApplicationEntityModel(EntityModelBase): title="Whether to enable debug mode when using a named stage to create an application object", default=None, ) + telemetry: Optional[EventSharingTelemetry] = Field( + title="Telemetry configuration for the application", + default=None, + ) @field_validator("identifier") @classmethod @@ -427,6 +441,71 @@ def get_objects_owned_by_application(self) -> List[ApplicationOwnedObject]: ).fetchall() return [{"name": row[1], "type": row[2]} for row in results] + def _should_authorize_event_sharing( + self, + install_method: SameAccountInstallMethod, + connection: SnowflakeConnection, + deploy_root: str, + ) -> Optional[bool]: + """ + Logic to determine whether event sharing should be authorized or not. + If the return value is None, it means that authorize_event_sharing should not be updated. + If the return value is True/False, it means that authorize_event_sharing should be set to True/False respectively. + """ + + model = self._entity_model + workspace_ctx = self._workspace_ctx + console = workspace_ctx.console + project_root = workspace_ctx.project_root + is_dev_mode = install_method.is_dev_mode + + mandatory_events_found = ( + len(find_mandatory_events_in_manifest_file(project_root / deploy_root)) > 0 + ) + event_sharing_enabled = ( + get_ui_parameter(connection, UIParameter.EVENT_SHARING_V2, "true").lower() + == "true" + ) + event_sharing_enforced = ( + get_ui_parameter( + connection, UIParameter.ENFORCE_MANDATORY_FILTERS, "true" + ).lower() + == "true" + ) + authorize_event_sharing = ( + model.telemetry and model.telemetry.authorize_event_sharing + ) + + if not event_sharing_enabled: + # We cannot set AUTHORIZE_TELEMETRY_EVENT_SHARING to True or False if event sharing is not enabled, + # so we will ignore the field in both cases, but warn only if it is set to True + if authorize_event_sharing: + console.warning( + "WARNING: Same-account event sharing is not enabled in your account, therefore, application telemetry field will be ignored." + ) + return None + + if event_sharing_enabled and not event_sharing_enforced: + # if event sharing is enabled but not enforced yet, warn about future enforcement + if mandatory_events_found and not authorize_event_sharing: + console.warning( + "WARNING: Mandatory events are present in the manifest file, but event sharing is not authorized in the application telemetry field. This will soon be required to set in order to deploy applications." + ) + + if event_sharing_enabled and event_sharing_enforced: + if mandatory_events_found: + if authorize_event_sharing is None and is_dev_mode: + console.warning( + "WARNING: Mandatory events are present in the manifest file. Automatically authorizing event sharing in dev mode. To suppress this warning, please add authorize_event_sharing field in the application telemetry section." + ) + return True + elif not authorize_event_sharing: + raise ClickException( + "Mandatory events are present in the manifest file, but event sharing is not authorized in the application telemetry field. This is required to deploy applications." + ) + + return authorize_event_sharing + def create_or_upgrade_app( self, package: ApplicationPackageEntity, @@ -443,6 +522,14 @@ def create_or_upgrade_app( sql_executor = get_sql_executor() with sql_executor.use_role(self.role): + authorize_event_sharing = self._should_authorize_event_sharing( + install_method, + sql_executor._conn, # noqa: SLF001 + package_model.deploy_root, + ) + optional_shared_events = ( + model.telemetry and model.telemetry.optional_shared_events + ) # 1. Need to use a warehouse to create an application object with sql_executor.use_warehouse(self.warehouse): @@ -470,6 +557,15 @@ def create_or_upgrade_app( ) print_messages(console, upgrade_cursor) + if authorize_event_sharing is not None: + sql_executor.execute_query( + f"alter application {app_name} set AUTHORIZE_TELEMETRY_EVENT_SHARING = {str(authorize_event_sharing).upper()}" + ) + if optional_shared_events: + sql_executor.execute_query( + f"""alter application {app_name} set shared telemetry events ('{"', '".join(optional_shared_events)}')""" + ) + if install_method.is_dev_mode: # if debug_mode is present (controlled), ensure it is up-to-date if debug_mode is not None: @@ -517,18 +613,27 @@ def create_or_upgrade_app( ) debug_mode_clause = f"debug_mode = {initial_debug_mode}" + authorize_telemetry_clause = "" + if authorize_event_sharing is not None: + authorize_telemetry_clause = f" AUTHORIZE_TELEMETRY_EVENT_SHARING = {str(authorize_event_sharing).upper()}" + using_clause = install_method.using_clause(stage_fqn) create_cursor = sql_executor.execute_query( dedent( f"""\ create application {self.name} - from application package {package.name} {using_clause} {debug_mode_clause} + from application package {package.name} {using_clause} {debug_mode_clause}{authorize_telemetry_clause} comment = {SPECIAL_COMMENT} """ ), ) print_messages(console, create_cursor) + if optional_shared_events: + sql_executor.execute_query( + f"""alter application {app_name} set shared telemetry events ('{"', '".join(optional_shared_events)}')""" + ) + # hooks always executed after a create or upgrade self.execute_post_deploy_hooks() diff --git a/src/snowflake/cli/_plugins/nativeapp/entities/models/event_sharing_telemetry.py b/src/snowflake/cli/_plugins/nativeapp/entities/models/event_sharing_telemetry.py new file mode 100644 index 0000000000..d3190eb31e --- /dev/null +++ b/src/snowflake/cli/_plugins/nativeapp/entities/models/event_sharing_telemetry.py @@ -0,0 +1,48 @@ +from typing import List, Optional + +from click import ClickException +from pydantic import Field, field_validator, model_validator +from snowflake.cli.api.project.schemas.updatable_model import UpdatableModel + + +class EventSharingTelemetry(UpdatableModel): + authorize_event_sharing: Optional[bool] = Field( + title="Whether to authorize Snowflake to share application usage data with application package provider. This automatically enables the sharing of required telemetry events.", + default=None, + ) + optional_shared_events: Optional[List[str]] = Field( + title="List of optional telemetry events that application owner would like to share with application package provider.", + default=None, + ) + + @model_validator(mode="after") + @classmethod + def validate_authorize_event_sharing(cls, value): + if value.optional_shared_events and not value.authorize_event_sharing: + raise ClickException( + "telemetry.authorize_event_sharing is required to be true in order to use telemetry.optional_shared_events." + ) + return value + + @field_validator("optional_shared_events") + @classmethod + def transform_artifacts( + cls, original_shared_events: Optional[List[str]] + ) -> Optional[List[str]]: + if original_shared_events is None: + return None + + # make sure that each event is made of letters and underscores: + for event in original_shared_events: + if not event.isalpha() and not event.replace("_", "").isalpha(): + raise ClickException( + f"Event {event} from optional_shared_events field is not a valid event name." + ) + + # make sure events are unique: + if len(original_shared_events) != len(set(original_shared_events)): + raise ClickException( + "Events in optional_shared_events field must be unique." + ) + + return original_shared_events diff --git a/tests/nativeapp/test_artifacts.py b/tests/nativeapp/test_artifacts.py index bef9ae2d98..8d4096df90 100644 --- a/tests/nativeapp/test_artifacts.py +++ b/tests/nativeapp/test_artifacts.py @@ -29,6 +29,7 @@ SourceNotFoundError, TooManyFilesError, build_bundle, + find_mandatory_events_in_manifest_file, find_version_info_in_manifest_file, resolve_without_follow, symlink_or_copy, @@ -100,10 +101,9 @@ def verify_mappings( expected_mappings: Dict[ Union[str, Path], Optional[Union[str, Path, List[str], List[Path]]] ], - expected_deploy_paths: Dict[ - Union[str, Path], Optional[Union[str, Path, List[str], List[Path]]] - ] - | None = None, + expected_deploy_paths: ( + Dict[Union[str, Path], Optional[Union[str, Path, List[str], List[Path]]]] | None + ) = None, **kwargs, ): def normalize_expected_dest( @@ -1415,3 +1415,50 @@ def test_find_version_info_in_manifest_file(version_name, patch_name): assert p is None else: assert p == int(patch_name) + + +@pytest.mark.parametrize( + "configuration_section, expected_output", + [ + [ + {}, + [], + ], + [ + { + "telemetry_event_definitions": [ + {"type": "USAGE_LOGS", "sharing": "MANDATORY"} + ] + }, + ["USAGE_LOGS"], + ], + [ + { + "telemetry_event_definitions": [ + {"type": "ERRORS_AND_WARNINGS", "sharing": "MANDATORY"}, + {"type": "DEBUG_LOGS", "sharing": "OPTIONAL"}, + ] + }, + ["ERRORS_AND_WARNINGS"], + ], + [ + { + "telemetry_event_definitions": [ + {"type": "ERRORS_AND_WARNINGS", "sharing": "MANDATORY"}, + {"type": "ALL", "sharing": "MANDATORY"}, + ] + }, + ["ERRORS_AND_WARNINGS", "ALL"], + ], + ], +) +def test_find_mandatory_events_in_manifest_file(configuration_section, expected_output): + manifest_contents = {"manifest_version": 1, "version": {"name": "v1", "patch": 1}} + manifest_contents["configuration"] = configuration_section + + deploy_root_structure = {"manifest.yml": safe_dump(manifest_contents)} + with temp_local_dir(deploy_root_structure) as deploy_root: + assert ( + find_mandatory_events_in_manifest_file(deploy_root=deploy_root) + == expected_output + ) diff --git a/tests/nativeapp/test_event_sharing.py b/tests/nativeapp/test_event_sharing.py new file mode 100644 index 0000000000..13d8a2607e --- /dev/null +++ b/tests/nativeapp/test_event_sharing.py @@ -0,0 +1,1086 @@ +# Copyright (c) 2024 Snowflake Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from textwrap import dedent +from unittest import mock +from unittest.mock import MagicMock + +import pytest +from click import ClickException +from snowflake.cli._plugins.connection.util import UIParameter +from snowflake.cli._plugins.nativeapp.constants import ( + SPECIAL_COMMENT, +) +from snowflake.cli._plugins.nativeapp.entities.application import ( + ApplicationEntity, + ApplicationEntityModel, +) +from snowflake.cli._plugins.nativeapp.entities.application_package import ( + ApplicationPackageEntity, + ApplicationPackageEntityModel, +) +from snowflake.cli._plugins.nativeapp.policy import ( + AllowAlwaysPolicy, + AskAlwaysPolicy, + DenyAlwaysPolicy, + PolicyBase, +) +from snowflake.cli._plugins.nativeapp.same_account_install_method import ( + SameAccountInstallMethod, +) +from snowflake.cli._plugins.stage.diff import DiffResult +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext +from snowflake.cli.api.console import cli_console as cc +from snowflake.cli.api.console.abc import AbstractConsole +from snowflake.cli.api.project.definition_manager import DefinitionManager + +from tests.nativeapp.factories import ( + ApplicationEntityModelFactory, + ApplicationPackageEntityModelFactory, + ProjectV2Factory, +) +from tests.nativeapp.patch_utils import ( + mock_connection, +) +from tests.nativeapp.utils import ( + APP_ENTITY_GET_EXISTING_APP_INFO, + GET_UI_PARAMETERS, + SQL_EXECUTOR_EXECUTE, + mock_execute_helper, +) +from tests.testing_utils.fixtures import MockConnectionCtx + +allow_always_policy = AllowAlwaysPolicy() +ask_always_policy = AskAlwaysPolicy() +deny_always_policy = DenyAlwaysPolicy() +test_manifest_contents = dedent( + """\ + manifest_version: 1 + + version: + name: dev + label: "Dev Version" + comment: "Default version used for development. Override for actual deployment." + + artifacts: + setup_script: setup.sql + readme: README.md + + configuration: + log_level: INFO + trace_level: ALWAYS +""" +) + +test_manifest_with_mandatory_events = dedent( + """\ + manifest_version: 1 + + version: + name: dev + label: "Dev Version" + comment: "Default version used for development. Override for actual deployment." + + artifacts: + setup_script: setup.sql + readme: README.md + + configuration: + telemetry_event_definitions: + - type: ERRORS_AND_WARNINGS + sharing: MANDATORY + - type: DEBUG_LOGS + sharing: OPTIONAL +""" +) + + +def _create_or_upgrade_app( + policy: PolicyBase, + install_method: SameAccountInstallMethod, + is_interactive: bool = False, + package_id: str = "app_pkg", + app_id: str = "myapp", + console: AbstractConsole | None = None, +): + dm = DefinitionManager() + pd = dm.project_definition + pkg_model: ApplicationPackageEntityModel = pd.entities[package_id] + app_model: ApplicationEntityModel = pd.entities[app_id] + ctx = WorkspaceContext( + console=console or cc, + project_root=dm.project_root, + get_default_role=lambda: "mock_role", + get_default_warehouse=lambda: "mock_warehouse", + ) + app = ApplicationEntity(app_model, ctx) + pkg = ApplicationPackageEntity(pkg_model, ctx) + stage_fqn = f"{pkg_model.fqn.name}.{pkg_model.stage}" + + def drop_application_before_upgrade(cascade: bool = False): + app.drop_application_before_upgrade( + console=console or cc, + app_name=app_model.fqn.identifier, + app_role=app_model.meta.role, + policy=policy, + is_interactive=is_interactive, + cascade=cascade, + ) + + pkg.action_bundle(action_ctx=ActionContext(get_entity=lambda *args: None)) + + return app.create_or_upgrade_app( + package_model=pkg_model, + stage_fqn=stage_fqn, + install_method=install_method, + drop_application_before_upgrade=drop_application_before_upgrade, + ) + + +def _setup_project( + app_pkg_role="package_role", + app_pkg_warehouse="pkg_warehouse", + app_role="app_role", + app_warehouse="app_warehouse", + setup_sql_contents="CREATE OR ALTER VERSIONED SCHEMA core;", + readme_contents="\n", + manifest_contents=test_manifest_contents, + authorize_event_sharing=None, + optional_shared_events=None, +): + telemetry = {} + if authorize_event_sharing is not None: + telemetry["authorize_event_sharing"] = authorize_event_sharing + if optional_shared_events is not None: + telemetry["optional_shared_events"] = optional_shared_events + ProjectV2Factory( + pdf__entities=dict( + app_pkg=ApplicationPackageEntityModelFactory( + identifier="app_pkg", + meta={"role": app_pkg_role, "warehouse": app_pkg_warehouse}, + ), + myapp=ApplicationEntityModelFactory( + identifier="myapp", + fromm__target="app_pkg", + meta={"role": app_role, "warehouse": app_warehouse}, + telemetry=(telemetry), + ), + ), + files={ + "setup.sql": setup_sql_contents, + "README.md": readme_contents, + "manifest.yml": manifest_contents, + }, + ) + + +def _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + authorize_telemetry_flag=None, + optional_shared_events=None, + is_prod=False, + is_upgrade=False, +): + if is_upgrade: + return _setup_mocks_for_upgrade_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + authorize_telemetry_flag=authorize_telemetry_flag, + optional_shared_events=optional_shared_events, + is_prod=is_prod, + ) + else: + return _setup_mocks_for_create_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + authorize_telemetry_flag=authorize_telemetry_flag, + optional_shared_events=optional_shared_events, + is_prod=is_prod, + ) + + +def _setup_mocks_for_create_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + authorize_telemetry_flag=None, + optional_shared_events=None, + is_prod=False, +): + mock_get_existing_app_info.return_value = None + + authorize_telemetry_clause = "" + if authorize_telemetry_flag is not None: + authorize_telemetry_clause = ( + f" AUTHORIZE_TELEMETRY_EVENT_SHARING = {authorize_telemetry_flag}".upper() + ) + install_clause = "using @app_pkg.app_src.stage debug_mode = True" + if is_prod: + install_clause = " " + + calls = [ + ( + mock_cursor([("old_role",)], []), + mock.call("select current_role()"), + ), + (None, mock.call("use role app_role")), + ( + mock_cursor([("old_wh",)], []), + mock.call("select current_warehouse()"), + ), + (None, mock.call("use warehouse app_warehouse")), + ( + mock_cursor([("app_role",)], []), + mock.call("select current_role()"), + ), + (None, mock.call("use role package_role")), + ( + None, + mock.call( + "grant install, develop on application package app_pkg to role app_role" + ), + ), + ( + None, + mock.call("grant usage on schema app_pkg.app_src to role app_role"), + ), + ( + None, + mock.call("grant read on stage app_pkg.app_src.stage to role app_role"), + ), + (None, mock.call("use role app_role")), + ( + None, + mock.call( + dedent( + f"""\ + create application myapp + from application package app_pkg {install_clause}{authorize_telemetry_clause} + comment = {SPECIAL_COMMENT} + """ + ) + ), + ), + ] + + if optional_shared_events is not None: + calls.append( + ( + None, + mock.call( + f"""alter application myapp set shared telemetry events ('{"', '".join(optional_shared_events)}')""" + ), + ), + ) + + calls.extend( + [ + (None, mock.call("use warehouse old_wh")), + (None, mock.call("use role old_role")), + ] + ) + side_effects, mock_execute_query_expected = mock_execute_helper(calls) + mock_execute_query.side_effect = side_effects + return mock_execute_query_expected + + +def _setup_mocks_for_upgrade_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + authorize_telemetry_flag=None, + optional_shared_events=None, + is_prod=False, +): + mock_get_existing_app_info.return_value = {"comment": "GENERATED_BY_SNOWFLAKECLI"} + install_clause = "using @app_pkg.app_src.stage" + if is_prod: + install_clause = "" + + calls = [ + ( + mock_cursor([("old_role",)], []), + mock.call("select current_role()"), + ), + (None, mock.call("use role app_role")), + ( + mock_cursor([("old_wh",)], []), + mock.call("select current_warehouse()"), + ), + (None, mock.call("use warehouse app_warehouse")), + (None, mock.call(f"alter application myapp upgrade {install_clause}")), + ] + + if authorize_telemetry_flag is not None: + calls.append( + ( + None, + mock.call( + f"alter application myapp set AUTHORIZE_TELEMETRY_EVENT_SHARING = {str(authorize_telemetry_flag).upper()}" + ), + ), + ) + + if optional_shared_events is not None: + calls.append( + ( + None, + mock.call( + f"""alter application myapp set shared telemetry events ('{"', '".join(optional_shared_events)}')""" + ), + ), + ) + + calls.extend( + [ + (None, mock.call("use warehouse old_wh")), + (None, mock.call("use role old_role")), + ] + ) + side_effects, mock_execute_query_expected = mock_execute_helper(calls) + mock_execute_query.side_effect = side_effects + return mock_execute_query_expected + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [ + test_manifest_contents, + test_manifest_with_mandatory_events, + ], +) +@pytest.mark.parametrize( + "authorize_event_sharing", + [ + False, + None, + ], +) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize( + "is_upgrade", + [False, True], +) +def test_event_sharing_disabled_no_change_to_current_behavior( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [ + test_manifest_contents, + test_manifest_with_mandatory_events, + ], +) +@pytest.mark.parametrize("authorize_event_sharing", [True]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_event_sharing_disabled_but_we_add_event_sharing_flag_in_project_definition_file( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=None, # treat it as unset + is_upgrade=is_upgrade, + ) + + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_has_calls( + [ + mock.call( + "WARNING: Same-account event sharing is not enabled in your account, therefore, application telemetry field will be ignored." + ) + ] + ) + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [ + test_manifest_contents, + ], +) +@pytest.mark.parametrize("authorize_event_sharing", [True, False, None]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_event_sharing_enabled_not_enforced_no_mandatory_events_then_flag_respected( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [True, False, None]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_event_sharing_enabled_with_mandatory_events_and_explicit_authorization_then_flag_respected_with_potential_warning( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + # Warn if the flag is not set, but there are mandatory events + if authorize_event_sharing: + mock_console.warning.assert_not_called() + else: + mock_console.warning.assert_has_calls( + [ + mock.call( + "WARNING: Mandatory events are present in the manifest file, but event sharing is not authorized in the application telemetry field. This will soon be required to set in order to deploy applications." + ) + ] + ) + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_contents], +) +@pytest.mark.parametrize("authorize_event_sharing", [True, False, None]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_enforced_events_sharing_with_no_mandatory_events_then_use_value_provided_for_authorization( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [True]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_enforced_events_sharing_with_mandatory_events_and_authorization_provided( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [False]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_enforced_events_sharing_with_mandatory_events_and_authorization_refused_then_error( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + with pytest.raises(ClickException) as e: + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert ( + e.value.message + == "Mandatory events are present in the manifest file, but event sharing is not authorized in the application telemetry field. This is required to deploy applications." + ) + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [None]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_enforced_events_sharing_with_mandatory_events_and_dev_mode_then_default_to_true_with_warning( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=True, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + expected_warning = "WARNING: Mandatory events are present in the manifest file. Automatically authorizing event sharing in dev mode. To suppress this warning, please add authorize_event_sharing field in the application telemetry section." + mock_console.warning.assert_has_calls([mock.call(expected_warning)]) + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [None]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_enforced_events_sharing_with_mandatory_events_and_authorization_not_specified_and_prod_mode_then_error( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=True, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + with pytest.raises(ClickException) as e: + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert ( + e.value.message + == "Mandatory events are present in the manifest file, but event sharing is not authorized in the application telemetry field. This is required to deploy applications." + ) + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events], +) +@pytest.mark.parametrize("authorize_event_sharing", [None, False]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False]) +def test_optional_shared_events_with_no_authorization_then_error( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + optional_shared_events=["DEBUG_LOGS"], + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + with pytest.raises(ClickException) as e: + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert ( + e.value.message + == "telemetry.authorize_event_sharing is required to be true in order to use telemetry.optional_shared_events." + ) + mock_console.warning.assert_not_called() + + +@mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) +@mock.patch(SQL_EXECUTOR_EXECUTE) +@mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "true", + UIParameter.ENFORCE_MANDATORY_FILTERS: "true", + }, +) +@pytest.mark.parametrize( + "manifest_contents", + [test_manifest_with_mandatory_events, test_manifest_contents], +) +@pytest.mark.parametrize("authorize_event_sharing", [True]) +@pytest.mark.parametrize( + "install_method", + [ + SameAccountInstallMethod.unversioned_dev(), + SameAccountInstallMethod.release_directive(), + ], +) +@pytest.mark.parametrize("is_upgrade", [False, True]) +def test_optional_shared_events_with_authorization_then_success( + mock_param, + mock_conn, + mock_execute_query, + mock_get_existing_app_info, + manifest_contents, + authorize_event_sharing, + install_method, + is_upgrade, + temp_dir, + mock_cursor, +): + optional_shared_events = ["DEBUG_LOGS", "ERRORS_AND_WARNINGS"] + mock_execute_query_expected = _setup_mocks_for_app( + mock_execute_query, + mock_cursor, + mock_get_existing_app_info, + is_prod=not install_method.is_dev_mode, + authorize_telemetry_flag=authorize_event_sharing, + is_upgrade=is_upgrade, + optional_shared_events=optional_shared_events, + ) + mock_conn.return_value = MockConnectionCtx() + mock_diff_result = DiffResult() + _setup_project( + manifest_contents=manifest_contents, + authorize_event_sharing=authorize_event_sharing, + optional_shared_events=optional_shared_events, + ) + assert not mock_diff_result.has_changes() + mock_console = MagicMock() + + _create_or_upgrade_app( + policy=MagicMock(), + install_method=install_method, + console=mock_console, + ) + + assert mock_execute_query.mock_calls == mock_execute_query_expected + mock_console.warning.assert_not_called() diff --git a/tests/nativeapp/test_run_processor.py b/tests/nativeapp/test_run_processor.py index 1d38eaea18..3349ecdf8a 100644 --- a/tests/nativeapp/test_run_processor.py +++ b/tests/nativeapp/test_run_processor.py @@ -20,6 +20,7 @@ import pytest import typer from click import UsageError +from snowflake.cli._plugins.connection.util import UIParameter from snowflake.cli._plugins.nativeapp.constants import ( LOOSE_FILES_MAGIC_VERSION, SPECIAL_COMMENT, @@ -46,7 +47,7 @@ SameAccountInstallMethod, ) from snowflake.cli._plugins.stage.diff import DiffResult -from snowflake.cli._plugins.workspace.context import WorkspaceContext +from snowflake.cli._plugins.workspace.context import ActionContext, WorkspaceContext from snowflake.cli._plugins.workspace.manager import WorkspaceManager from snowflake.cli.api.console import cli_console as cc from snowflake.cli.api.console.abc import AbstractConsole @@ -72,10 +73,10 @@ from tests.nativeapp.utils import ( APP_ENTITY_GET_EXISTING_APP_INFO, APP_PACKAGE_ENTITY_GET_EXISTING_VERSION_INFO, + GET_UI_PARAMETERS, SQL_EXECUTOR_EXECUTE, TYPER_CONFIRM, mock_execute_helper, - mock_snowflake_yml_file_v2, quoted_override_yml_file_v2, ) from tests.testing_utils.files_and_dirs import create_named_file @@ -84,6 +85,24 @@ allow_always_policy = AllowAlwaysPolicy() ask_always_policy = AskAlwaysPolicy() deny_always_policy = DenyAlwaysPolicy() +test_manifest_contents = dedent( + """\ + manifest_version: 1 + + version: + name: dev + label: "Dev Version" + comment: "Default version used for development. Override for actual deployment." + + artifacts: + setup_script: setup.sql + readme: README.md + + configuration: + log_level: INFO + trace_level: ALWAYS +""" +) def _get_wm(): @@ -123,6 +142,8 @@ def drop_application_before_upgrade(cascade: bool = False): cascade=cascade, ) + pkg.action_bundle(action_ctx=ActionContext(get_entity=lambda *args: None)) + return app.create_or_upgrade_app( package=pkg, stage_fqn=stage_fqn, @@ -131,11 +152,70 @@ def drop_application_before_upgrade(cascade: bool = False): ) +test_pdf = dedent( + """\ + definition_version: 2 + entities: + app_pkg: + type: application package + stage: app_src.stage + manifest: app/manifest.yml + artifacts: + - setup.sql + - src: app/manifest.yml + dest: manifest.yml + meta: + role: package_role + warehouse: pkg_warehouse + myapp: + type: application + debug: true + from: + target: app_pkg + meta: + role: app_role + warehouse: app_warehouse + """ +) + + +def setup_project_file(current_working_directory: str, pdf=None): + create_named_file( + file_name="snowflake.yml", + dir_name=current_working_directory, + contents=[pdf or test_pdf], + ) + + create_named_file( + file_name="manifest.yml", + dir_name=f"{current_working_directory}/app", + contents=[test_manifest_contents], + ) + create_named_file( + file_name="README.md", + dir_name=f"{current_working_directory}/app", + contents=["# This is readme"], + ) + + create_named_file( + file_name="setup.sql", + dir_name=current_working_directory, + contents=["-- hi"], + ) + + # Test create_dev_app with exception thrown trying to use the warehouse @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_w_warehouse_access_exception( - mock_conn, mock_execute, temp_dir, mock_cursor + mock_param, mock_conn, mock_execute, temp_dir, mock_cursor ): side_effects, expected = mock_execute_helper( [ @@ -165,12 +245,7 @@ def test_create_dev_app_w_warehouse_access_exception( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) assert not mock_diff_result.has_changes() @@ -191,8 +266,20 @@ def test_create_dev_app_w_warehouse_access_exception( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_create_new_w_no_additional_privileges( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): side_effects, expected = mock_execute_helper( [ @@ -226,17 +313,18 @@ def test_create_dev_app_create_new_w_no_additional_privileges( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2.replace("package_role", "app_role")], - ) + + setup_project_file(os.getcwd(), test_pdf.replace("package_role", "app_role")) assert not mock_diff_result.has_changes() - _create_or_upgrade_app( - policy=MagicMock(), install_method=SameAccountInstallMethod.unversioned_dev() - ) + try: + _create_or_upgrade_app( + policy=MagicMock(), + install_method=SameAccountInstallMethod.unversioned_dev(), + ) + except Exception as e: + print(mock_execute.mock_calls) + raise e assert mock_execute.mock_calls == expected @@ -244,6 +332,13 @@ def test_create_dev_app_create_new_w_no_additional_privileges( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "existing_app_info", [ @@ -257,6 +352,7 @@ def test_create_dev_app_create_new_w_no_additional_privileges( ], ) def test_create_or_upgrade_dev_app_with_warning( + mock_param, mock_conn, mock_execute, mock_get_existing_app_info, @@ -318,12 +414,7 @@ def test_create_or_upgrade_dev_app_with_warning( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2.replace("package_role", "app_role")], - ) + setup_project_file(os.getcwd(), test_pdf.replace("package_role", "app_role")) assert not mock_diff_result.has_changes() mock_console = mock.MagicMock() @@ -341,7 +432,15 @@ def test_create_or_upgrade_dev_app_with_warning( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_create_new_with_additional_privileges( + mock_param, mock_conn, mock_execute_query, mock_get_existing_app_info, @@ -400,12 +499,7 @@ def test_create_dev_app_create_new_with_additional_privileges( mock_execute_query.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) assert not mock_diff_result.has_changes() _create_or_upgrade_app( @@ -418,8 +512,20 @@ def test_create_dev_app_create_new_with_additional_privileges( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_create_new_w_missing_warehouse_exception( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): side_effects, expected = mock_execute_helper( [ @@ -456,12 +562,7 @@ def test_create_dev_app_create_new_w_missing_warehouse_exception( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2.replace("package_role", "app_role")], - ) + setup_project_file(os.getcwd(), test_pdf.replace("package_role", "app_role")) assert not mock_diff_result.has_changes() @@ -480,6 +581,13 @@ def test_create_dev_app_create_new_w_missing_warehouse_exception( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "comment, version", [ @@ -488,6 +596,7 @@ def test_create_dev_app_create_new_w_missing_warehouse_exception( ], ) def test_create_dev_app_incorrect_properties( + mock_param, mock_conn, mock_execute, mock_get_existing_app_info, @@ -522,12 +631,7 @@ def test_create_dev_app_incorrect_properties( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(ApplicationCreatedExternallyError): assert not mock_diff_result.has_changes() @@ -543,8 +647,20 @@ def test_create_dev_app_incorrect_properties( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_incorrect_owner( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): mock_get_existing_app_info.return_value = { "name": "MYAPP", @@ -581,12 +697,7 @@ def test_create_dev_app_incorrect_owner( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(ProgrammingError): assert not mock_diff_result.has_changes() @@ -601,9 +712,21 @@ def test_create_dev_app_incorrect_owner( # Test create_dev_app with existing application AND diff has no changes @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock.patch(SQL_EXECUTOR_EXECUTE) +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @mock_connection() def test_create_dev_app_no_diff_changes( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): mock_get_existing_app_info.return_value = { "name": "MYAPP", @@ -638,12 +761,7 @@ def test_create_dev_app_no_diff_changes( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) assert not mock_diff_result.has_changes() _create_or_upgrade_app( @@ -656,8 +774,20 @@ def test_create_dev_app_no_diff_changes( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_w_diff_changes( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): mock_get_existing_app_info.return_value = { "name": "MYAPP", @@ -692,12 +822,7 @@ def test_create_dev_app_w_diff_changes( mock_execute.side_effect = side_effects mock_diff_result = DiffResult(different=["setup.sql"]) - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) assert mock_diff_result.has_changes() _create_or_upgrade_app( @@ -710,8 +835,20 @@ def test_create_dev_app_w_diff_changes( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_recreate_w_missing_warehouse_exception( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): mock_get_existing_app_info.return_value = { "name": "MYAPP", @@ -747,12 +884,7 @@ def test_create_dev_app_recreate_w_missing_warehouse_exception( mock_execute.side_effect = side_effects mock_diff_result = DiffResult(different=["setup.sql"]) - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) assert mock_diff_result.has_changes() @@ -770,8 +902,20 @@ def test_create_dev_app_recreate_w_missing_warehouse_exception( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_create_new_quoted( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): side_effects, expected = mock_execute_helper( [ @@ -805,42 +949,36 @@ def test_create_dev_app_create_new_quoted( mock_execute.side_effect = side_effects mock_diff_result = DiffResult() - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[ - dedent( - """\ - definition_version: 2 - entities: - app_pkg: - type: application package - identifier: '"My Package"' - artifacts: - - setup.sql - - app/README.md - - src: app/streamlit/*.py - dest: ui/ - manifest: app/manifest.yml - stage: app_src.stage - meta: - role: app_role - post_deploy: - - sql_script: shared_content.sql - myapp: - type: application - identifier: '"My Application"' - debug: true - from: - target: app_pkg - meta: - role: app_role - warehouse: app_warehouse - """ - ) - ], + pdf_content = dedent( + """\ + definition_version: 2 + entities: + app_pkg: + type: application package + identifier: '"My Package"' + artifacts: + - setup.sql + - app/README.md + - src: app/manifest.yml + dest: manifest.yml + manifest: app/manifest.yml + stage: app_src.stage + meta: + role: app_role + post_deploy: + - sql_script: shared_content.sql + myapp: + type: application + identifier: '"My Application"' + debug: true + from: + target: app_pkg + meta: + role: app_role + warehouse: app_warehouse + """ ) + setup_project_file(os.getcwd(), pdf_content) assert not mock_diff_result.has_changes() _create_or_upgrade_app( @@ -853,8 +991,20 @@ def test_create_dev_app_create_new_quoted( @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO, return_value=None) @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_create_new_quoted_override( - mock_conn, mock_execute, mock_get_existing_app_info, temp_dir, mock_cursor + mock_param, + mock_conn, + mock_execute, + mock_get_existing_app_info, + temp_dir, + mock_cursor, ): side_effects, expected = mock_execute_helper( [ @@ -889,10 +1039,8 @@ def test_create_dev_app_create_new_quoted_override( mock_diff_result = DiffResult() current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2.replace("package_role", "app_role")], + setup_project_file( + current_working_directory, test_pdf.replace("package_role", "app_role") ) create_named_file( file_name="snowflake.local.yml", @@ -915,7 +1063,15 @@ def test_create_dev_app_create_new_quoted_override( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_recreate_app_when_orphaned( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -986,12 +1142,7 @@ def test_create_dev_app_recreate_app_when_orphaned( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy=MagicMock(), install_method=SameAccountInstallMethod.unversioned_dev() @@ -1008,7 +1159,15 @@ def test_create_dev_app_recreate_app_when_orphaned( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_recreate_app_when_orphaned_requires_cascade( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1096,12 +1255,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy=MagicMock(), install_method=SameAccountInstallMethod.unversioned_dev() @@ -1119,7 +1273,15 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_create_dev_app_recreate_app_when_orphaned_requires_cascade_unknown_objects( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1202,12 +1364,7 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade_unknown_obje mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy=MagicMock(), install_method=SameAccountInstallMethod.unversioned_dev() @@ -1218,11 +1375,18 @@ def test_create_dev_app_recreate_app_when_orphaned_requires_cascade_unknown_obje # Test upgrade app method for release directives AND throws warehouse error @mock.patch(SQL_EXECUTOR_EXECUTE) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param", [allow_always_policy, ask_always_policy, deny_always_policy] ) def test_upgrade_app_warehouse_error( - mock_conn, mock_execute, policy_param, temp_dir, mock_cursor + mock_param, mock_conn, mock_execute, policy_param, temp_dir, mock_cursor ): side_effects, expected = mock_execute_helper( [ @@ -1251,12 +1415,7 @@ def test_upgrade_app_warehouse_error( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(CouldNotUseObjectError): _create_or_upgrade_app( @@ -1271,10 +1430,18 @@ def test_upgrade_app_warehouse_error( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param", [allow_always_policy, ask_always_policy, deny_always_policy] ) def test_upgrade_app_incorrect_owner( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1313,12 +1480,7 @@ def test_upgrade_app_incorrect_owner( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(ProgrammingError): _create_or_upgrade_app( @@ -1333,10 +1495,18 @@ def test_upgrade_app_incorrect_owner( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param", [allow_always_policy, ask_always_policy, deny_always_policy] ) def test_upgrade_app_succeeds( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1369,12 +1539,7 @@ def test_upgrade_app_succeeds( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy=policy_param, @@ -1388,10 +1553,18 @@ def test_upgrade_app_succeeds( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param", [allow_always_policy, ask_always_policy, deny_always_policy] ) def test_upgrade_app_fails_generic_error( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1429,12 +1602,7 @@ def test_upgrade_app_fails_generic_error( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(ProgrammingError): _create_or_upgrade_app( @@ -1454,11 +1622,19 @@ def test_upgrade_app_fails_generic_error( f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=False ) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param, is_interactive_param, expected_code", [(deny_always_policy, False, 1), (ask_always_policy, True, 0)], ) def test_upgrade_app_fails_upgrade_restriction_error( + mock_param, mock_conn, mock_typer_confirm, mock_get_existing_app_info, @@ -1499,12 +1675,7 @@ def test_upgrade_app_fails_upgrade_restriction_error( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(typer.Exit): result = _create_or_upgrade_app( @@ -1519,7 +1690,15 @@ def test_upgrade_app_fails_upgrade_restriction_error( @mock.patch(SQL_EXECUTOR_EXECUTE) @mock.patch(APP_ENTITY_GET_EXISTING_APP_INFO) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) def test_versioned_app_upgrade_to_unversioned( + mock_param, mock_conn, mock_get_existing_app_info, mock_execute, @@ -1597,12 +1776,7 @@ def test_versioned_app_upgrade_to_unversioned( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy=AllowAlwaysPolicy(), @@ -1621,11 +1795,19 @@ def test_versioned_app_upgrade_to_unversioned( f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True ) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize( "policy_param, is_interactive_param", [(allow_always_policy, False), (ask_always_policy, True)], ) def test_upgrade_app_fails_drop_fails( + mock_param, mock_conn, mock_typer_confirm, mock_get_existing_app_info, @@ -1671,12 +1853,7 @@ def test_upgrade_app_fails_drop_fails( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) with pytest.raises(ProgrammingError): _create_or_upgrade_app( @@ -1694,8 +1871,16 @@ def test_upgrade_app_fails_drop_fails( f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True ) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize("policy_param", [allow_always_policy, ask_always_policy]) def test_upgrade_app_recreate_app( + mock_param, mock_conn, mock_typer_confirm, mock_get_existing_app_info, @@ -1767,12 +1952,7 @@ def test_upgrade_app_recreate_app( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) _create_or_upgrade_app( policy_param, @@ -1788,12 +1968,7 @@ def test_upgrade_app_recreate_app( return_value=None, ) def test_upgrade_app_from_version_throws_usage_error_one(mock_existing, temp_dir): - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) wm = _get_wm() with pytest.raises(UsageError): @@ -1818,12 +1993,7 @@ def test_upgrade_app_from_version_throws_usage_error_two( mock_existing, temp_dir, ): - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) wm = _get_wm() with pytest.raises(UsageError): @@ -1850,8 +2020,16 @@ def test_upgrade_app_from_version_throws_usage_error_two( f"snowflake.cli._plugins.nativeapp.policy.{TYPER_CONFIRM}", return_value=True ) @mock_connection() +@mock.patch( + GET_UI_PARAMETERS, + return_value={ + UIParameter.EVENT_SHARING_V2: "false", + UIParameter.ENFORCE_MANDATORY_FILTERS: "false", + }, +) @pytest.mark.parametrize("policy_param", [allow_always_policy, ask_always_policy]) def test_upgrade_app_recreate_app_from_version( + mock_param, mock_conn, mock_typer_confirm, mock_get_existing_app_info, @@ -1925,14 +2103,13 @@ def test_upgrade_app_recreate_app_from_version( mock_conn.return_value = MockConnectionCtx() mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) wm = _get_wm() + wm.perform_action( + "app_pkg", + EntityActions.BUNDLE, + ) wm.perform_action( "myapp", EntityActions.DEPLOY, @@ -1982,12 +2159,7 @@ def test_get_existing_version_info( ) mock_execute.side_effect = side_effects - current_working_directory = os.getcwd() - create_named_file( - file_name="snowflake.yml", - dir_name=current_working_directory, - contents=[mock_snowflake_yml_file_v2], - ) + setup_project_file(os.getcwd()) dm = DefinitionManager() pd = dm.project_definition diff --git a/tests/nativeapp/utils.py b/tests/nativeapp/utils.py index 1554dc65e4..47af0f1279 100644 --- a/tests/nativeapp/utils.py +++ b/tests/nativeapp/utils.py @@ -25,6 +25,7 @@ TYPER_PROMPT = "typer.prompt" ENTITIES_COMMON_MODULE = "snowflake.cli.api.entities.common" ENTITIES_UTILS_MODULE = "snowflake.cli.api.entities.utils" +PLUGIN_UTIL_MODULE = "snowflake.cli._plugins.connection.util" APPLICATION_PACKAGE_ENTITY_MODULE = ( "snowflake.cli._plugins.nativeapp.entities.application_package" ) @@ -61,6 +62,7 @@ SQL_EXECUTOR_EXECUTE = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_query" SQL_EXECUTOR_EXECUTE_QUERIES = f"{ENTITIES_COMMON_MODULE}.SqlExecutor._execute_queries" +GET_UI_PARAMETERS = f"{PLUGIN_UTIL_MODULE}.get_ui_parameters" SQL_FACADE_MODULE = "snowflake.cli._plugins.nativeapp.sf_facade" SQL_FACADE = f"{SQL_FACADE_MODULE}.SnowflakeSQLFacade" diff --git a/tests/streamlit/test_commands.py b/tests/streamlit/test_commands.py index 84968e5aab..f83e2086b3 100644 --- a/tests/streamlit/test_commands.py +++ b/tests/streamlit/test_commands.py @@ -17,10 +17,12 @@ from unittest import mock import pytest -from snowflake.cli._plugins.connection.util import REGIONLESS_QUERY +from snowflake.cli._plugins.connection.util import UIParameter from snowflake.cli._plugins.streamlit.manager import StreamlitManager from snowflake.cli.api.identifiers import FQN +from tests.nativeapp.utils import GET_UI_PARAMETERS + STREAMLIT_NAME = "test_streamlit" TEST_WAREHOUSE = "test_warehouse" @@ -63,8 +65,13 @@ def _put_query(source: str, dest: str): @mock.patch("snowflake.cli._plugins.connection.util.get_account") @mock.patch("snowflake.cli._plugins.streamlit.commands.typer") @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_only_streamlit_file( + mock_param, mock_connector, mock_typer, mock_get_account, @@ -77,7 +84,6 @@ def test_deploy_only_streamlit_file( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "my_account"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -107,7 +113,6 @@ def test_deploy_only_streamlit_file( """ ).strip(), "select system$get_snowsight_host()", - REGIONLESS_QUERY, ] mock_typer.launch.assert_not_called() @@ -115,8 +120,13 @@ def test_deploy_only_streamlit_file( @mock.patch("snowflake.cli._plugins.connection.util.get_account") @mock.patch("snowflake.cli._plugins.streamlit.commands.typer") @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_only_streamlit_file_no_stage( + mock_param, mock_connector, mock_typer, mock_get_account, @@ -129,7 +139,6 @@ def test_deploy_only_streamlit_file_no_stage( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "my_account"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -158,7 +167,6 @@ def test_deploy_only_streamlit_file_no_stage( """ ).strip(), "select system$get_snowsight_host()", - REGIONLESS_QUERY, ] mock_typer.launch.assert_not_called() @@ -166,8 +174,13 @@ def test_deploy_only_streamlit_file_no_stage( @mock.patch("snowflake.cli._plugins.connection.util.get_account") @mock.patch("snowflake.cli._plugins.streamlit.commands.typer") @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_with_empty_pages( + mock_param, mock_connector, mock_typer, mock_get_account, @@ -180,7 +193,6 @@ def test_deploy_with_empty_pages( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "my_account"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -211,7 +223,6 @@ def test_deploy_with_empty_pages( """ ).strip(), "select system$get_snowsight_host()", - REGIONLESS_QUERY, ] assert "Skipping empty directory: pages" in result.output @@ -219,8 +230,13 @@ def test_deploy_with_empty_pages( @mock.patch("snowflake.cli._plugins.connection.util.get_account") @mock.patch("snowflake.cli._plugins.streamlit.commands.typer") @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_only_streamlit_file_replace( + mock_param, mock_connector, mock_typer, mock_get_account, @@ -233,7 +249,6 @@ def test_deploy_only_streamlit_file_replace( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "my_account"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -263,7 +278,6 @@ def test_deploy_only_streamlit_file_replace( """ ).strip(), "select system$get_snowsight_host()", - REGIONLESS_QUERY, ] mock_typer.launch.assert_not_called() @@ -287,15 +301,24 @@ def test_artifacts_must_exists( @mock.patch("snowflake.cli._plugins.streamlit.commands.typer") @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_launch_browser( - mock_connector, mock_typer, mock_cursor, runner, mock_ctx, project_directory + mock_param, + mock_connector, + mock_typer, + mock_cursor, + runner, + mock_ctx, + project_directory, ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], ) @@ -313,15 +336,18 @@ def test_deploy_launch_browser( @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_and_environment_files( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -350,21 +376,23 @@ def test_deploy_streamlit_and_environment_files( """ ).strip(), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_and_pages_files( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -392,21 +420,23 @@ def test_deploy_streamlit_and_pages_files( """ ).strip(), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_all_streamlit_files( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -435,21 +465,23 @@ def test_deploy_all_streamlit_files( """ ).strip(), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, "select current_account_name()", ] @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_put_files_on_stage( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -480,21 +512,23 @@ def test_deploy_put_files_on_stage( """ ).strip(), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_all_streamlit_files_not_defaults( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -521,7 +555,6 @@ def test_deploy_all_streamlit_files_not_defaults( """ ).strip(), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @@ -529,8 +562,13 @@ def test_deploy_all_streamlit_files_not_defaults( @mock.patch("snowflake.connector.connect") @pytest.mark.parametrize("enable_streamlit_versioned_stage", [True, False]) @pytest.mark.parametrize("enable_streamlit_no_checkouts", [True, False]) +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_main_and_pages_files_experimental( + mock_param, mock_connector, mock_cursor, runner, @@ -543,7 +581,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -595,7 +632,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental( _put_query("environment.yml", f"{root_path}"), _put_query("pages/*", f"{root_path}/pages"), "select system$get_snowsight_host()", - REGIONLESS_QUERY, "select current_account_name()", ] if cmd is not None @@ -603,8 +639,13 @@ def test_deploy_streamlit_main_and_pages_files_experimental( @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_main_and_pages_files_experimental_double_deploy( + mock_param, mock_connector, mock_cursor, runner, @@ -615,7 +656,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental_double_deploy( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -633,7 +673,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental_double_deploy( ctx.cs = mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -661,15 +700,19 @@ def test_deploy_streamlit_main_and_pages_files_experimental_double_deploy( _put_query("environment.yml", f"{root_path}"), _put_query("pages/*", f"{root_path}/pages"), "select system$get_snowsight_host()", - REGIONLESS_QUERY, "select current_account_name()", ] @mock.patch("snowflake.connector.connect") @pytest.mark.parametrize("enable_streamlit_versioned_stage", [True, False]) +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_main_and_pages_files_experimental_no_stage( + mock_param, mock_connector, mock_cursor, runner, @@ -681,7 +724,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental_no_stage( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -721,21 +763,23 @@ def test_deploy_streamlit_main_and_pages_files_experimental_no_stage( _put_query("environment.yml", f"{root_path}"), _put_query("pages/*", f"{root_path}/pages"), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) @mock_streamlit_exists def test_deploy_streamlit_main_and_pages_files_experimental_replace( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -762,7 +806,6 @@ def test_deploy_streamlit_main_and_pages_files_experimental_replace( _put_query("environment.yml", f"{root_path}"), _put_query("pages/*", f"{root_path}/pages"), f"select system$get_snowsight_host()", - REGIONLESS_QUERY, f"select current_account_name()", ] @@ -818,12 +861,15 @@ def test_drop_streamlit(mock_connector, runner, mock_ctx): @mock.patch("snowflake.connector.connect") -def test_get_streamlit_url(mock_connector, mock_cursor, runner, mock_ctx): +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) +def test_get_streamlit_url(mock_param, mock_connector, mock_cursor, runner, mock_ctx): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -836,7 +882,6 @@ def test_get_streamlit_url(mock_connector, mock_cursor, runner, mock_ctx): assert result.exit_code == 0, result.output assert ctx.get_queries() == [ "select system$get_snowsight_host()", - REGIONLESS_QUERY, "select current_account_name()", ] @@ -898,14 +943,17 @@ def test_multiple_streamlit_raise_error_if_multiple_entities( @mock.patch("snowflake.connector.connect") +@mock.patch( + GET_UI_PARAMETERS, + return_value={UIParameter.ENABLE_REGIONLESS_REDIRECT: "false"}, +) def test_deploy_streamlit_with_comment_v2( - mock_connector, mock_cursor, runner, mock_ctx, project_directory + mock_param, mock_connector, mock_cursor, runner, mock_ctx, project_directory ): ctx = mock_ctx( mock_cursor( rows=[ {"SYSTEM$GET_SNOWSIGHT_HOST()": "https://snowsight.domain"}, - {"REGIONLESS": "false"}, {"CURRENT_ACCOUNT_NAME()": "https://snowsight.domain"}, ], columns=["SYSTEM$GET_SNOWSIGHT_HOST()"], @@ -935,7 +983,6 @@ def test_deploy_streamlit_with_comment_v2( """ ).strip(), "select system$get_snowsight_host()", - REGIONLESS_QUERY, "select current_account_name()", ] diff --git a/tests/test_utils.py b/tests/test_utils.py index 8b3ef8c84e..19f8c61cef 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,15 +16,18 @@ import os from pathlib import Path from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest import snowflake.cli._plugins.snowpark.models import snowflake.cli._plugins.snowpark.package.utils from snowflake.cli._plugins.connection.util import ( LOCAL_DEPLOYMENT_REGION, + UIParameter, get_context, get_host_region, + get_ui_parameter, + get_ui_parameters, guess_regioned_host_from_allowlist, make_snowsight_url, ) @@ -33,6 +36,7 @@ from snowflake.cli.api.secure_path import SecurePath from snowflake.cli.api.utils import path_utils from snowflake.connector import SnowflakeConnection +from snowflake.connector.cursor import DictCursor from tests.test_data import test_data @@ -286,3 +290,119 @@ def test_get_context_local_non_regionless_gets_local_region( ) def test_get_host_region(host, expected): assert get_host_region(host) == expected + + +def test_get_ui_parameters_no_params(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [] + assert get_ui_parameters(connection) == {} + + assert connection.execute_string.has_calls([]) + + +expected_ui_params_query = f""" + select value['value']::string as PARAM_VALUE, value['name']::string as PARAM_NAME from table(flatten( + input => parse_json(SYSTEM$BOOTSTRAP_DATA_REQUEST()), + path => 'clientParamsInfo' + )) where value['name'] in ('ENABLE_EVENT_SHARING_V2_IN_THE_SAME_ACCOUNT', 'ENFORCE_MANDATORY_FILTERS_FOR_SAME_ACCOUNT_INSTALLATION', 'UI_SNOWSIGHT_ENABLE_REGIONLESS_REDIRECT'); + """ + + +def test_get_ui_parameters_no_param(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [] + assert get_ui_parameters(connection) == {} + + connection.execute_string.assert_called_with( + expected_ui_params_query, cursor_class=DictCursor + ) + + +def test_get_ui_parameters_one_param(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [ + { + "PARAM_NAME": UIParameter.ENABLE_REGIONLESS_REDIRECT.value, + "PARAM_VALUE": "true", + } + ] + assert get_ui_parameters(connection) == { + UIParameter.ENABLE_REGIONLESS_REDIRECT: "true" + } + + connection.execute_string.assert_called_with( + expected_ui_params_query, cursor_class=DictCursor + ) + + +def test_get_ui_parameters_multiple_params(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [ + { + "PARAM_NAME": UIParameter.ENABLE_REGIONLESS_REDIRECT.value, + "PARAM_VALUE": "true", + }, + { + "PARAM_NAME": UIParameter.EVENT_SHARING_V2.value, + "PARAM_VALUE": "false", + }, + ] + assert get_ui_parameters(connection) == { + UIParameter.ENABLE_REGIONLESS_REDIRECT: "true", + UIParameter.EVENT_SHARING_V2: "false", + } + + connection.execute_string.assert_called_with( + expected_ui_params_query, cursor_class=DictCursor + ) + + +def test_get_ui_parameter_with_value(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [ + { + "PARAM_NAME": UIParameter.ENABLE_REGIONLESS_REDIRECT.value, + "PARAM_VALUE": "true", + } + ] + assert ( + get_ui_parameter(connection, UIParameter.ENABLE_REGIONLESS_REDIRECT, "false") + == "true" + ) + + +def test_get_ui_parameter_with_empty_value_then_use_empty_value(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [ + { + "PARAM_NAME": UIParameter.ENABLE_REGIONLESS_REDIRECT.value, + "PARAM_VALUE": "", + } + ] + assert ( + get_ui_parameter(connection, UIParameter.ENABLE_REGIONLESS_REDIRECT, "false") + == "" + ) + + +def test_get_ui_parameter_with_no_value_then_use_default(): + connection = MagicMock() + cursor = MagicMock() + connection.execute_string.return_value = (None, cursor) + cursor.fetchall.return_value = [] + assert ( + get_ui_parameter(connection, UIParameter.ENABLE_REGIONLESS_REDIRECT, "false") + == "false" + ) diff --git a/tests/testing_utils/files_and_dirs.py b/tests/testing_utils/files_and_dirs.py index 5c5bd6bc2f..d1a55cbcc8 100644 --- a/tests/testing_utils/files_and_dirs.py +++ b/tests/testing_utils/files_and_dirs.py @@ -31,6 +31,7 @@ def create_temp_file(suffix: str, dir_name: str, contents: List[str]) -> str: def create_named_file(file_name: str, dir_name: str, contents: List[str]): file_path = os.path.join(dir_name, file_name) + os.makedirs(os.path.dirname(file_path), exist_ok=True) _write_to_file(file_path, contents) return file_path