Skip to content

Commit

Permalink
Call Analyst to perform validation (#187)
Browse files Browse the repository at this point in the history
We currently maintain copies of the validation logic in both the
internal Analyst codepaths as well as this OSS app. Often, the OSS app
can become out of date. Instead of performing validation locally, we
will simply call Analyst with the current YAML string, as it performs
validation at inference time. Any error returned is shown to the user.

The diff for this PR seems big but it's mostly deleting unnecessary code
+ tests.
  • Loading branch information
sfc-gh-cnivera authored Oct 23, 2024
1 parent 7a3afa7 commit 2f1f675
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 392 deletions.
82 changes: 82 additions & 0 deletions app_utils/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import json
import re
from typing import Dict, Any

import requests
import streamlit as st
from snowflake.connector import SnowflakeConnection

API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message"


@st.cache_data(ttl=60, show_spinner=False)
def send_message(
_conn: SnowflakeConnection, semantic_model: str, messages: list[dict[str, str]]
) -> Dict[str, Any]:
"""
Calls the REST API with a list of messages and returns the response.
Args:
_conn: SnowflakeConnection, used to grab the token for auth.
messages: list of chat messages to pass to the Analyst API.
semantic_model: stringified YAML of the semantic model.
Returns: The raw ChatMessage response from Analyst.
"""
request_body = {
"messages": messages,
"semantic_model": semantic_model,
}

if st.session_state["sis"]:
import _snowflake

resp = _snowflake.send_snow_api_request( # type: ignore
"POST",
f"/api/v2/cortex/analyst/message",
{},
{},
request_body,
{},
30000,
)
if resp["status"] < 400:
json_resp: Dict[str, Any] = json.loads(resp["content"])
return json_resp
else:
err_body = json.loads(resp["content"])
if "message" in err_body:
# Certain errors have a message payload with a link to the github repo, which we should remove.
error_msg = re.sub(
r"\s*Please use https://github\.com/Snowflake-Labs/semantic-model-generator.*",
"",
err_body["message"],
)
raise ValueError(error_msg)
raise ValueError(err_body)

else:
host = st.session_state.host_name
resp = requests.post(
API_ENDPOINT.format(
HOST=host,
),
json=request_body,
headers={
"Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr]
"Content-Type": "application/json",
},
)
if resp.status_code < 400:
json_resp: Dict[str, Any] = resp.json()
return json_resp
else:
err_body = json.loads(resp.text)
if "message" in err_body:
# Certain errors have a message payload with a link to the github repo, which we should remove.
error_msg = re.sub(
r"\s*Please use https://github\.com/Snowflake-Labs/semantic-model-generator.*",
"",
err_body["message"],
)
raise ValueError(error_msg)
raise ValueError(err_body)
88 changes: 16 additions & 72 deletions journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from typing import Any, Dict, List, Optional

import pandas as pd
import requests
import sqlglot
import streamlit as st
from snowflake.connector import ProgrammingError, SnowflakeConnection
from streamlit.delta_generator import DeltaGenerator
from streamlit_extras.row import row
from streamlit_extras.stylable_container import stylable_container

from app_utils.chat import send_message
from app_utils.shared_utils import (
GeneratorAppScreen,
SnowflakeStage,
Expand All @@ -36,11 +36,6 @@
yaml_to_semantic_model,
)
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.snowflake_utils.env_vars import (
SNOWFLAKE_ACCOUNT_LOCATOR,
SNOWFLAKE_HOST,
SNOWFLAKE_USER,
)
from semantic_model_generator.validate_model import validate


Expand All @@ -67,64 +62,6 @@ def pretty_print_sql(sql: str) -> str:
return formatted_sql


API_ENDPOINT = "https://{HOST}/api/v2/cortex/analyst/message"


@st.cache_data(ttl=60, show_spinner=False)
def send_message(
_conn: SnowflakeConnection, messages: list[dict[str, str]]
) -> Dict[str, Any]:
"""
Calls the REST API with a list of messages and returns the response.
Args:
_conn: SnowflakeConnection, used to grab the token for auth.
messages: list of chat messages to pass to the Analyst API.
Returns: The raw ChatMessage response from Analyst.
"""
request_body = {
"messages": messages,
"semantic_model": proto_to_yaml(st.session_state.semantic_model),
}

if st.session_state["sis"]:
import _snowflake

resp = _snowflake.send_snow_api_request( # type: ignore
"POST",
f"/api/v2/cortex/analyst/message",
{},
{},
request_body,
{},
30000,
)
if resp["status"] < 400:
json_resp: Dict[str, Any] = json.loads(resp["content"])
return json_resp
else:
raise Exception(f"Failed request with status {resp['status']}: {resp}")
else:
host = st.session_state.host_name
resp = requests.post(
API_ENDPOINT.format(
HOST=host,
),
json=request_body,
headers={
"Authorization": f'Snowflake Token="{_conn.rest.token}"', # type: ignore[union-attr]
"Content-Type": "application/json",
},
)
if resp.status_code < 400:
json_resp: Dict[str, Any] = resp.json()
return json_resp
else:
raise Exception(
f"Failed request with status {resp.status_code}: {resp.text}"
)


def process_message(_conn: SnowflakeConnection, prompt: str) -> None:
"""Processes a message and adds the response to the chat."""
user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
Expand All @@ -139,14 +76,21 @@ def process_message(_conn: SnowflakeConnection, prompt: str) -> None:
if st.session_state.multiturn
else [user_message]
)
response = send_message(_conn=_conn, messages=request_messages)
content = response["message"]["content"]
# Grab the request ID from the response and stash it in the chat message object.
request_id = response["request_id"]
display_content(conn=_conn, content=content, request_id=request_id)
st.session_state.messages.append(
{"role": "analyst", "content": content, "request_id": request_id}
)
try:
response = send_message(
_conn=_conn,
semantic_model=proto_to_yaml(st.session_state.semantic_model),
messages=request_messages,
)
content = response["message"]["content"]
# Grab the request ID from the response and stash it in the chat message object.
request_id = response["request_id"]
display_content(conn=_conn, content=content, request_id=request_id)
st.session_state.messages.append(
{"role": "analyst", "content": content, "request_id": request_id}
)
except ValueError as e:
st.error(e)


def show_expr_for_ref(message_index: int) -> None:
Expand Down
Loading

0 comments on commit 2f1f675

Please sign in to comment.