def _raise_analyst_error(err_body: Dict[str, Any]) -> None:
    """Raise a ValueError built from an Analyst API error payload.

    Certain errors carry a ``message`` field ending with a pointer to the
    semantic-model-generator GitHub repo, which we strip before surfacing
    the message to the user. Payloads without a ``message`` field are
    raised as-is.
    """
    if "message" in err_body:
        error_msg = re.sub(
            r"\s*Please use https://github\.com/Snowflake-Labs/semantic-model-generator.*",
            "",
            err_body["message"],
        )
        raise ValueError(error_msg)
    raise ValueError(err_body)


@st.cache_data(ttl=60, show_spinner=False)
def send_message(
    _conn: SnowflakeConnection, semantic_model: str, messages: list[dict[str, str]]
) -> Dict[str, Any]:
    """
    Calls the Cortex Analyst REST API with a list of messages and returns the response.

    Args:
        _conn: SnowflakeConnection, used to grab the token for auth.
        semantic_model: stringified YAML of the semantic model.
        messages: list of chat messages to pass to the Analyst API.

    Returns: The raw ChatMessage response from Analyst.

    Raises:
        ValueError: if the API responds with status >= 400; the message has
            the GitHub-repo suffix stripped when present.
    """
    request_body = {
        "messages": messages,
        "semantic_model": semantic_model,
    }

    if st.session_state["sis"]:
        # Running as Streamlit-in-Snowflake: use the internal API bridge
        # instead of an outbound HTTP call.
        import _snowflake

        resp = _snowflake.send_snow_api_request(  # type: ignore
            "POST",
            "/api/v2/cortex/analyst/message",
            {},
            {},
            request_body,
            {},
            30000,  # request timeout in milliseconds
        )
        if resp["status"] < 400:
            json_resp: Dict[str, Any] = json.loads(resp["content"])
            return json_resp
        _raise_analyst_error(json.loads(resp["content"]))
    else:
        host = st.session_state.host_name
        resp = requests.post(
            API_ENDPOINT.format(HOST=host),
            json=request_body,
            headers={
                "Authorization": f'Snowflake Token="{_conn.rest.token}"',  # type: ignore[union-attr]
                "Content-Type": "application/json",
            },
            # Without a timeout, a hung endpoint blocks the app forever;
            # 30s mirrors the 30000ms used in the SiS branch above.
            timeout=30,
        )
        if resp.status_code < 400:
            json_resp = resp.json()
            return json_resp
        _raise_analyst_error(resp.json())
show_spinner=False) -def send_message( - _conn: SnowflakeConnection, messages: list[dict[str, str]] -) -> Dict[str, Any]: - """ - Calls the REST API with a list of messages and returns the response. - Args: - _conn: SnowflakeConnection, used to grab the token for auth. - messages: list of chat messages to pass to the Analyst API. - - Returns: The raw ChatMessage response from Analyst. - """ - request_body = { - "messages": messages, - "semantic_model": proto_to_yaml(st.session_state.semantic_model), - } - - if st.session_state["sis"]: - import _snowflake - - resp = _snowflake.send_snow_api_request( # type: ignore - "POST", - f"/api/v2/cortex/analyst/message", - {}, - {}, - request_body, - {}, - 30000, - ) - if resp["status"] < 400: - json_resp: Dict[str, Any] = json.loads(resp["content"]) - return json_resp - else: - raise Exception(f"Failed request with status {resp['status']}: {resp}") - else: - host = st.session_state.host_name - resp = requests.post( - API_ENDPOINT.format( - HOST=host, - ), - json=request_body, - headers={ - "Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr] - "Content-Type": "application/json", - }, - ) - if resp.status_code < 400: - json_resp: Dict[str, Any] = resp.json() - return json_resp - else: - raise Exception( - f"Failed request with status {resp.status_code}: {resp.text}" - ) - - def process_message(_conn: SnowflakeConnection, prompt: str) -> None: """Processes a message and adds the response to the chat.""" user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]} @@ -139,14 +76,21 @@ def process_message(_conn: SnowflakeConnection, prompt: str) -> None: if st.session_state.multiturn else [user_message] ) - response = send_message(_conn=_conn, messages=request_messages) - content = response["message"]["content"] - # Grab the request ID from the response and stash it in the chat message object. 
- request_id = response["request_id"] - display_content(conn=_conn, content=content, request_id=request_id) - st.session_state.messages.append( - {"role": "analyst", "content": content, "request_id": request_id} - ) + try: + response = send_message( + _conn=_conn, + semantic_model=proto_to_yaml(st.session_state.semantic_model), + messages=request_messages, + ) + content = response["message"]["content"] + # Grab the request ID from the response and stash it in the chat message object. + request_id = response["request_id"] + display_content(conn=_conn, content=content, request_id=request_id) + st.session_state.messages.append( + {"role": "analyst", "content": content, "request_id": request_id} + ) + except ValueError as e: + st.error(e) def show_expr_for_ref(message_index: int) -> None: diff --git a/semantic_model_generator/tests/validate_model_test.py b/semantic_model_generator/tests/validate_model_test.py index f9601b1..524f77a 100644 --- a/semantic_model_generator/tests/validate_model_test.py +++ b/semantic_model_generator/tests/validate_model_test.py @@ -1,284 +1,40 @@ -import tempfile -from unittest import mock -from unittest.mock import MagicMock, patch +import json +from unittest.mock import patch, MagicMock -import pytest -from strictyaml import DuplicateKeysDisallowed, YAMLValidationError +from snowflake.connector import SnowflakeConnection -from semantic_model_generator.data_processing.proto_utils import proto_to_yaml -from semantic_model_generator.tests.samples import validate_yamls -from semantic_model_generator.validate_model import validate_from_local_path +from semantic_model_generator.validate_model import validate -@pytest.fixture -def mock_snowflake_connection(): - """Fixture to mock the snowflake_connection function.""" - with patch( - "semantic_model_generator.snowflake_utils.snowflake_connector.snowflake_connection" - ) as mock: - mock.return_value = MagicMock() - yield mock +@patch("semantic_model_generator.validate_model.send_message") +def 
@patch("semantic_model_generator.validate_model.send_message")
def test_validate_success(mock_send_message):
    """validate() returns None when Analyst accepts the semantic model."""
    # A successful Analyst call yields a (possibly empty) response payload;
    # validate() discards it and simply returns.
    mock_send_message.return_value = {}

    mock_conn = MagicMock(spec=SnowflakeConnection)

    # validate() has no return value on success — it only raises on failure.
    assert validate("valid_yaml_content", mock_conn) is None
tmp.write(validate_yamls._INVALID_YAML_UNMATCHED_QUOTE) - tmp.flush() - yield tmp.name - - -@pytest.fixture -def temp_invalid_yaml_incorrect_dtype(): - """Create a temporary YAML file with the test data.""" - with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp: - tmp.write(validate_yamls._INVALID_YAML_INCORRECT_DATA_TYPE) - tmp.flush() - yield tmp.name - - -@pytest.fixture -def temp_invalid_yaml_too_long_context(): - """Create a temporary YAML file with the test data.""" - with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp: - tmp.write(validate_yamls._INVALID_YAML_TOO_LONG_CONTEXT) - tmp.flush() - yield tmp.name - - -@pytest.fixture -def temp_valid_yaml_with_verified_query(): - """Create a temporary YAML file with the test data.""" - with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp: - tmp.write(validate_yamls._VALID_YAML_WITH_SINGLE_VERIFIED_QUERY) - tmp.flush() - yield tmp.name - - -@pytest.fixture -def temp_invalid_yaml_duplicate_verified_queries(): - """Create a temporary YAML file with the test data.""" - with tempfile.NamedTemporaryFile(mode="w", delete=True) as tmp: - tmp.write(validate_yamls._INVALID_YAML_DUPLICATE_VERIFIED_QUERIES) - tmp.flush() - yield tmp.name - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_valid_yaml_flow_style( - mock_logger, temp_valid_yaml_file_flow_style, mock_snowflake_connection -): - validate_from_local_path(temp_valid_yaml_file_flow_style, mock_snowflake_connection) - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_valid_yaml(mock_logger, temp_valid_yaml_file, mock_snowflake_connection): - validate_from_local_path(temp_valid_yaml_file, mock_snowflake_connection) - - expected_log_call_1 = mock.call.info("Successfully validated!") - expected_log_call_2 = mock.call.info("Checking logical table: ALIAS") - expected_log_call_3 = mock.call.info("Validated logical table: ALIAS") - assert ( - expected_log_call_1 in mock_logger.mock_calls - ), 
"Expected log message not found in logger calls" - assert ( - expected_log_call_2 in mock_logger.mock_calls - ), "Expected log message not found in logger calls" - assert ( - expected_log_call_3 in mock_logger.mock_calls - ), "Expected log message not found in logger calls" - snowflake_query_one = "WITH __ALIAS AS (SELECT ALIAS, ZIP_CODE FROM AUTOSQL_DATASET_BIRD_V2.ADDRESS.ALIAS) SELECT * FROM __ALIAS LIMIT 1" - snowflake_query_two = "WITH __AREA_CODE AS (SELECT ZIP_CODE, AREA_CODE FROM AUTOSQL_DATASET_BIRD_V2.ADDRESS.AREA_CODE) SELECT * FROM __AREA_CODE LIMIT 1" - assert any( - snowflake_query_one in str(call) - for call in mock_snowflake_connection.mock_calls - ), "Query not executed" - assert any( - snowflake_query_two in str(call) - for call in mock_snowflake_connection.mock_calls - ), "Query not executed" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_valid_yaml_with_long_vqr_context( - mock_logger, temp_valid_yaml_file_long_vqr_context, mock_snowflake_connection -): - validate_from_local_path( - temp_valid_yaml_file_long_vqr_context, mock_snowflake_connection - ) - - expected_log_call_1 = mock.call.info("Successfully validated!") - expected_log_call_2 = mock.call.info("Checking logical table: ALIAS") - expected_log_call_3 = mock.call.info("Validated logical table: ALIAS") - assert ( - expected_log_call_1 in mock_logger.mock_calls - ), "Expected log message not found in logger calls" - assert ( - expected_log_call_2 in mock_logger.mock_calls - ), "Expected log message not found in logger calls" - assert ( - expected_log_call_3 in mock_logger.mock_calls - ), "Expected log message not found in logger calls" - snowflake_query_one = "WITH __ALIAS AS (SELECT ALIAS, ZIP_CODE FROM AUTOSQL_DATASET_BIRD_V2.ADDRESS.ALIAS) SELECT * FROM __ALIAS LIMIT 1" - snowflake_query_two = "WITH __AREA_CODE AS (SELECT ZIP_CODE, AREA_CODE FROM AUTOSQL_DATASET_BIRD_V2.ADDRESS.AREA_CODE) SELECT * FROM __AREA_CODE LIMIT 1" - assert any( - 
snowflake_query_one in str(call) - for call in mock_snowflake_connection.mock_calls - ), "Query not executed" - assert any( - snowflake_query_two in str(call) - for call in mock_snowflake_connection.mock_calls - ), "Query not executed" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_invalid_yaml_formatting( - mock_logger, temp_invalid_yaml_formatting_file, mock_snowflake_connection -): - with pytest.raises(DuplicateKeysDisallowed): - validate_from_local_path( - temp_invalid_yaml_formatting_file, mock_snowflake_connection - ) - - expected_log_call = mock.call.info("Successfully validated!") - assert ( - expected_log_call not in mock_logger.mock_calls - ), "Unexpected log message found in logger calls" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_invalid_yaml_uppercase( - mock_logger, temp_invalid_yaml_uppercase_file, mock_snowflake_connection -): - with pytest.raises( - YAMLValidationError, match=".*when expecting one of: aggregation_type_unknown.*" - ): - validate_from_local_path( - temp_invalid_yaml_uppercase_file, mock_snowflake_connection - ) - - expected_log_call = mock.call.info("Successfully validated!") - assert ( - expected_log_call not in mock_logger.mock_calls - ), "Unexpected log message found in logger calls" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_invalid_yaml_missing_quote( - mock_logger, temp_invalid_yaml_unmatched_quote_file, mock_snowflake_connection -): - with pytest.raises(YAMLValidationError) as exc_info: - validate_from_local_path( - temp_invalid_yaml_unmatched_quote_file, mock_snowflake_connection - ) - - expected_error_fragment = "name can only contain letters, underscores, decimal digits (0-9), and dollar signs ($)." 
- assert expected_error_fragment in str(exc_info.value), "Unexpected error message" - - expected_log_call = mock.call.info("Successfully validated!") - - assert ( - expected_log_call not in mock_logger.mock_calls - ), "Unexpected log message found in logger calls" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_invalid_yaml_incorrect_datatype( - mock_logger, temp_invalid_yaml_incorrect_dtype, mock_snowflake_connection -): - with pytest.raises(ValueError) as exc_info: - validate_from_local_path( - temp_invalid_yaml_incorrect_dtype, mock_snowflake_connection - ) - - expected_error = "Unable to validate your semantic model. Error = We do not support object datatypes in the semantic model. Col ZIP_CODE has data type OBJECT. Please remove this column from your semantic model or flatten it to non-object type." - - assert expected_error in str(exc_info.value), "Unexpected error message" - - -@mock.patch("semantic_model_generator.validate_model.logger") -def test_invalid_yaml_too_long_context( - mock_logger, temp_invalid_yaml_too_long_context, mock_snowflake_connection -): - account_name = "snowflake test" - with pytest.raises(ValueError) as exc_info: - validate_from_local_path(temp_invalid_yaml_too_long_context, account_name) - - expected_error = ( - "Your semantic model is too large. " - "Passed size is 164952 characters. " - "We need you to remove 41032 characters in your semantic model. Please check: \n" - " (1) If you have long descriptions that can be truncated. \n" - " (2) If you can remove some columns that are not used within your tables. \n" - " (3) If you have extra tables you do not need." 
@patch("semantic_model_generator.validate_model.send_message")
def test_validate_error(mock_send_message):
    """validate() propagates the ValueError raised by send_message on an error response."""
    import pytest

    # send_message raises ValueError for error responses (with the GitHub-repo
    # suffix already stripped), so the mock must raise, not return. The prior
    # version of this test set a return_value, which meant validate() never
    # raised, the except branch was dead code, and the test passed vacuously.
    mock_send_message.side_effect = ValueError("This YAML is missing a name.")

    conn = MagicMock(spec=SnowflakeConnection)
    # pytest.raises fails the test if no exception is raised, unlike the
    # bare try/except the original used.
    with pytest.raises(ValueError, match="This YAML is missing a name"):
        validate("invalid_yaml_content", conn)
def validate(yaml_str: str, conn: SnowflakeConnection) -> None:
    """Pseudo-validate a semantic model by round-tripping it through Cortex Analyst.

    There is no dedicated validation endpoint, but Analyst runs validation at
    inference time, so a successful message request is a reasonable proxy. This
    removes the need to keep local validation logic in sync with Analyst.

    Args:
        yaml_str: yaml content in string format.
        conn: SnowflakeConnection Snowflake connection to pass in

    Raises:
        ValueError: propagated from send_message when Analyst rejects the model.
    """
    probe_message = {
        "role": "user",
        "content": [{"type": "text", "text": "SMG app validation"}],
    }
    # Any error response surfaces as a ValueError out of send_message.
    send_message(conn, yaml_str, [probe_message])