Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Snowflake agent #1799

Merged
merged 17 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
elif state in ["done", "succeeded"]:
elif state in ["done", "succeeded", "success"]:

Check warning on line 141 in flytekit/extend/backend/base_agent.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extend/backend/base_agent.py#L141

Added line #L141 was not covered by tests
return SUCCEEDED
elif state in ["running"]:
return RUNNING
Expand Down
1 change: 1 addition & 0 deletions plugins/flytekit-snowflake/dev-requirements.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest-asyncio
20 changes: 20 additions & 0 deletions plugins/flytekit-snowflake/dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# pip-compile dev-requirements.in
#
exceptiongroup==1.1.3
# via pytest
iniconfig==2.0.0
# via pytest
packaging==23.1
# via pytest
pluggy==1.3.0
# via pytest
pytest==7.4.0
# via pytest-asyncio
pytest-asyncio==0.21.1
# via -r dev-requirements.in
tomli==2.0.1
# via pytest
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
SnowflakeTask
"""

from .agent import SnowflakeAgent
from .task import SnowflakeConfig, SnowflakeTask
155 changes: 155 additions & 0 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional

import grpc
import snowflake.connector
from flyteidl.admin.agent_pb2 import (
PERMANENT_FAILURE,
SUCCEEDED,
CreateTaskResponse,
DeleteTaskResponse,
GetTaskResponse,
Resource,
)
from snowflake.connector import ProgrammingError

from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
from flytekit.models import literals
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.models.types import LiteralType, StructuredDatasetType

# Task type identifier for the Snowflake agent. It is passed to AgentBase as
# ``task_type`` and is also used as the secrets group when the private key is fetched.
TASK_TYPE = "snowflake"


@dataclass
class Metadata:
    """Everything the agent needs to find a submitted Snowflake query again.

    An instance is serialized to JSON (via ``dataclasses.asdict``) and returned
    as ``resource_meta`` when a task is created; later ``get``/``delete`` calls
    rebuild it with ``Metadata(**json.loads(...))``. All fields are therefore
    plain strings, and the field order below is the order of the keys in the
    serialized JSON.
    """

    # Connection parameters handed to ``snowflake.connector.connect``.
    user: str
    account: str
    database: str
    schema: str
    warehouse: str
    # Table name; used when the output dataset URI is built.
    table: str
    # Snowflake query id of the submitted statement (the cursor's ``sfqid``).
    query_id: str


class SnowflakeAgent(AgentBase):
    """Agent that submits Snowflake queries asynchronously and tracks them by query id.

    ``async_create`` submits the SQL statement and returns the serialized
    :class:`Metadata`; ``async_get`` and ``async_delete`` rebuild that metadata
    from ``resource_meta`` to poll or cancel the query.
    """

    def __init__(self):
        super().__init__(task_type=TASK_TYPE)

    def get_private_key(self):
        """Return the Snowflake private key as unencrypted DER-encoded PKCS8 bytes.

        The PEM key is read from the Flyte secret with group ``TASK_TYPE`` and
        key ``private_key``. It is loaded with ``password=None``.
        """
        from cryptography.hazmat.backends import default_backend
        from cryptography.hazmat.primitives import serialization

        import flytekit

        pk_string = flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb")
        p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())

        # Re-encode as DER/PKCS8 without encryption; this is the value passed
        # to the connector's ``private_key`` argument.
        return p_key.private_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

    def get_connection(self, metadata: Metadata) -> snowflake.connector.SnowflakeConnection:
        """Open a Snowflake connection using the parameters stored in ``metadata``.

        :param metadata: supplies user, account, database, schema and warehouse.
        :return: the connection object returned by ``snowflake.connector.connect``.
        """
        return snowflake.connector.connect(
            user=metadata.user,
            account=metadata.account,
            private_key=self.get_private_key(),
            database=metadata.database,
            schema=metadata.schema,
            warehouse=metadata.warehouse,
        )

    async def async_create(
        self,
        context: grpc.ServicerContext,
        output_prefix: str,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
    ) -> CreateTaskResponse:
        """Submit the task's SQL statement to Snowflake without waiting for it to finish.

        :param context: gRPC servicer context (not used by this method).
        :param output_prefix: output location prefix (not used by this method).
        :param task_template: provides the connection ``config`` and the SQL statement.
        :param inputs: optional literal inputs, converted to native Python values
            and passed to the cursor as query parameters.
        :return: response whose ``resource_meta`` is the UTF-8 JSON encoding of
            the :class:`Metadata` for the submitted query.
        :raises KeyError: if a required key is missing from ``task_template.config``.
        """
        params = None
        if inputs:
            ctx = FlyteContextManager.current_context()
            python_interface_inputs = {
                name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
            }
            native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
            logger.info(f"Create Snowflake params with inputs: {native_inputs}")
            params = native_inputs

        config = task_template.config

        # Read every required config key *before* talking to Snowflake, so a
        # missing key fails fast instead of leaving an already-submitted query
        # running with no metadata to track it. query_id is filled in below.
        metadata = Metadata(
            user=config["user"],
            account=config["account"],
            database=config["database"],
            schema=config["schema"],
            warehouse=config["warehouse"],
            table=config["table"],
            query_id="",
        )

        # NOTE(review): the connection and cursor are not closed here
        # (unchanged from the previous behavior); confirm whether closing them
        # would affect the in-flight asynchronous query before adding cleanup.
        conn = self.get_connection(metadata)
        cs = conn.cursor()
        cs.execute_async(task_template.sql.statement, params=params)
        metadata.query_id = str(cs.sfqid)

        return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8"))

    async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
        """Report the current state of a previously submitted query.

        :param context: gRPC servicer context; its code and details are set when
            the status lookup raises ``ProgrammingError``.
        :param resource_meta: UTF-8 JSON produced by ``async_create``.
        :return: response carrying the Flyte state and, once the state is
            ``SUCCEEDED``, a ``results`` output holding a ``StructuredDataset``
            whose URI is built from the stored metadata.
        """
        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
        # NOTE(review): this connection is not closed (unchanged from the
        # previous behavior); consider closing it once the status is read.
        conn = self.get_connection(metadata)
        try:
            query_status = conn.get_query_status_throw_if_error(metadata.query_id)
        except ProgrammingError as err:
            logger.error(err.msg)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(err.msg)
            return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))

        cur_state = convert_to_flyte_state(str(query_status.name))
        res = None

        if cur_state == SUCCEEDED:
            ctx = FlyteContextManager.current_context()
            output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.database}/{metadata.schema}/{metadata.warehouse}/{metadata.table}"
            res = literals.LiteralMap(
                {
                    "results": TypeEngine.to_literal(
                        ctx,
                        StructuredDataset(uri=output_metadata),
                        StructuredDataset,
                        LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
                    )
                }
            ).to_flyte_idl()

        return GetTaskResponse(resource=Resource(state=cur_state, outputs=res))

    async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
        """Ask Snowflake to cancel a previously submitted query.

        :param context: gRPC servicer context (not used by this method).
        :param resource_meta: UTF-8 JSON produced by ``async_create``.
        :return: an empty ``DeleteTaskResponse``.
        """
        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
        conn = self.get_connection(metadata)
        try:
            # Creating the cursor sits inside the outer try so the connection
            # is closed even if ``cursor()`` itself raises.
            cs = conn.cursor()
            try:
                cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
                cs.fetchall()
            finally:
                cs.close()
        finally:
            conn.close()
        return DeleteTaskResponse()


# Module-level side effect: register one SnowflakeAgent instance with the agent
# registry when this module is executed.
AgentRegistry.register(SnowflakeAgent())
15 changes: 12 additions & 3 deletions plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

from flytekit.configuration import SerializationSettings
from flytekit.extend import SQLTask
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.models import task as _task_model
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured import StructuredDataset

_USER_FIELD = "user"
_ACCOUNT_FIELD = "account"
_DATABASE_FIELD = "database"
_SCHEMA_FIELD = "schema"
_WAREHOUSE_FIELD = "warehouse"
_TABLE_FIELD = "table"


@dataclass
Expand All @@ -18,6 +21,8 @@ class SnowflakeConfig(object):
SnowflakeConfig should be used to configure a Snowflake Task.
"""

# The user to query against
user: Optional[str] = None
# The account to query against
account: Optional[str] = None
# The database to query against
Expand All @@ -26,9 +31,11 @@ class SnowflakeConfig(object):
schema: Optional[str] = None
# The optional warehouse to set for the given Snowflake query
warehouse: Optional[str] = None
# The optional table to set for the given Snowflake query
table: Optional[str] = None


class SnowflakeTask(SQLTask[SnowflakeConfig]):
class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]):
"""
This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output.
"""
Expand All @@ -42,7 +49,7 @@ def __init__(
query_template: str,
task_config: Optional[SnowflakeConfig] = None,
inputs: Optional[Dict[str, Type]] = None,
output_schema_type: Optional[Type[FlyteSchema]] = None,
output_schema_type: Optional[Type[StructuredDataset]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -76,10 +83,12 @@ def __init__(

def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
    """Return the Snowflake connection settings of this task as a flat mapping.

    Each value is copied as-is from the matching attribute of
    ``self.task_config``; ``settings`` is not consulted.
    """
    cfg = self.task_config
    config = {
        _USER_FIELD: cfg.user,
        _ACCOUNT_FIELD: cfg.account,
        _DATABASE_FIELD: cfg.database,
        _SCHEMA_FIELD: cfg.schema,
        _WAREHOUSE_FIELD: cfg.warehouse,
        _TABLE_FIELD: cfg.table,
    }
    return config

def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
Expand Down
3 changes: 2 additions & 1 deletion plugins/flytekit-snowflake/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"]
plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"]

__version__ = "0.0.0+develop"

Expand Down Expand Up @@ -32,4 +32,5 @@
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
120 changes: 120 additions & 0 deletions plugins/flytekit-snowflake/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import json
from dataclasses import asdict
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest
from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse
from flytekitplugins.snowflake.agent import Metadata

import flytekit.models.interface as interface_models
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import Sql, TaskTemplate


@mock.patch("flytekitplugins.snowflake.agent.SnowflakeAgent.get_private_key", return_value="pb")
@mock.patch("snowflake.connector.connect")
@pytest.mark.asyncio
async def test_snowflake_agent(mock_conn, mock_get_private_key):
    """Drive the Snowflake agent through create, get and delete against a mocked connector.

    ``snowflake.connector.connect`` and ``SnowflakeAgent.get_private_key`` are
    patched, so no real connection or secret is needed.
    """
    query_status_mock = MagicMock()
    query_status_mock.name = "SUCCEEDED"

    # Configure the mock connection to return the mock status object
    mock_conn_instance = mock_conn.return_value
    mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock

    ctx = MagicMock(spec=grpc.ServicerContext)
    agent = AgentRegistry.get_agent("snowflake")

    task_id = Identifier(
        resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
    )

    task_metadata = task.TaskMetadata(
        True,
        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        literals.RetryStrategy(3),
        True,
        "0.1.1b0",
        "This is deprecated!",
        True,
        "A",
    )

    task_config = {
        "user": "dummy_user",
        "account": "dummy_account",
        "database": "dummy_database",
        "schema": "dummy_schema",
        "warehouse": "dummy_warehouse",
        "table": "dummy_table",
    }

    int_type = types.LiteralType(types.SimpleType.INTEGER)
    interfaces = interface_models.TypedInterface(
        {
            "a": interface_models.Variable(int_type, "description1"),
            "b": interface_models.Variable(int_type, "description2"),
        },
        {},
    )
    task_inputs = literals.LiteralMap(
        {
            "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
            "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
        },
    )

    dummy_template = TaskTemplate(
        id=task_id,
        custom=None,
        config=task_config,
        metadata=task_metadata,
        interface=interfaces,
        type="snowflake",
        sql=Sql("SELECT 1"),
    )

    metadata = Metadata(
        user="dummy_user",
        account="dummy_account",
        table="dummy_table",
        database="dummy_database",
        schema="dummy_schema",
        warehouse="dummy_warehouse",
        query_id="dummy_query_id",
    )

    res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs)
    # The query id comes from the mocked cursor, so copy it from the response
    # before comparing the full serialized metadata.
    metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id
    metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8")
    assert res.resource_meta == metadata_bytes

    res = await agent.async_get(ctx, metadata_bytes)
    assert res.resource.state == SUCCEEDED
    assert (
        res.resource.outputs.literals["results"].scalar.structured_dataset.uri
        == "snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table"
    )

    delete_response = await agent.async_delete(ctx, metadata_bytes)

    # Assert the response
    assert isinstance(delete_response, DeleteTaskResponse)

    # Verify that the expected methods were called on the mock cursor
    mock_cursor = mock_conn_instance.cursor.return_value
    mock_cursor.execute.assert_called_once_with(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
    mock_cursor.fetchall.assert_called_once()

    # Verify that the cursor and the connection were closed
    mock_cursor.close.assert_called_once()
    mock_conn_instance.close.assert_called_once()
Loading