diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py
index 1bf34c029a..7bfcf4b6b8 100644
--- a/flytekit/extend/backend/base_agent.py
+++ b/flytekit/extend/backend/base_agent.py
@@ -138,7 +138,7 @@ def convert_to_flyte_state(state: str) -> State:
     state = state.lower()
     if state in ["failed"]:
         return RETRYABLE_FAILURE
-    elif state in ["done", "succeeded"]:
+    elif state in ["done", "succeeded", "success"]:
         return SUCCEEDED
     elif state in ["running"]:
         return RUNNING
diff --git a/plugins/flytekit-snowflake/dev-requirements.in b/plugins/flytekit-snowflake/dev-requirements.in
new file mode 100644
index 0000000000..2d73dba5b4
--- /dev/null
+++ b/plugins/flytekit-snowflake/dev-requirements.in
@@ -0,0 +1 @@
+pytest-asyncio
diff --git a/plugins/flytekit-snowflake/dev-requirements.txt b/plugins/flytekit-snowflake/dev-requirements.txt
new file mode 100644
index 0000000000..99d3f5e4e9
--- /dev/null
+++ b/plugins/flytekit-snowflake/dev-requirements.txt
@@ -0,0 +1,20 @@
+#
+# This file is autogenerated by pip-compile with Python 3.9
+# by the following command:
+#
+#    pip-compile dev-requirements.in
+#
+exceptiongroup==1.1.3
+    # via pytest
+iniconfig==2.0.0
+    # via pytest
+packaging==23.1
+    # via pytest
+pluggy==1.3.0
+    # via pytest
+pytest==7.4.0
+    # via pytest-asyncio
+pytest-asyncio==0.21.1
+    # via -r dev-requirements.in
+tomli==2.0.1
+    # via pytest
diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py
index 2875e56bdf..9c16d5398e 100644
--- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py
+++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py
@@ -11,4 +11,5 @@
    SnowflakeTask
 """
 
+from .agent import SnowflakeAgent
 from .task import SnowflakeConfig, SnowflakeTask
diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
new file mode 100644
index 0000000000..c4176228ea
--- /dev/null
+++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py
@@ -0,0 +1,172 @@
+import json
+from dataclasses import asdict, dataclass
+from typing import Optional
+
+import grpc
+import snowflake.connector
+from flyteidl.admin.agent_pb2 import (
+    PERMANENT_FAILURE,
+    SUCCEEDED,
+    CreateTaskResponse,
+    DeleteTaskResponse,
+    GetTaskResponse,
+    Resource,
+)
+from snowflake.connector import ProgrammingError
+
+from flytekit import FlyteContextManager, StructuredDataset, logger
+from flytekit.core.type_engine import TypeEngine
+from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state
+from flytekit.models import literals
+from flytekit.models.literals import LiteralMap
+from flytekit.models.task import TaskTemplate
+from flytekit.models.types import LiteralType, StructuredDatasetType
+
+TASK_TYPE = "snowflake"
+
+
+@dataclass
+class Metadata:
+    """Connection settings and query id, serialized as JSON into the task's ``resource_meta``."""
+
+    user: str
+    account: str
+    database: str
+    schema: str
+    warehouse: str
+    table: str
+    query_id: str
+
+
+class SnowflakeAgent(AgentBase):
+    """Agent that submits Snowflake queries asynchronously, polls their status and cancels them."""
+
+    def __init__(self):
+        super().__init__(task_type=TASK_TYPE)
+
+    def get_private_key(self):
+        """Load the PEM ``private_key`` secret and return it re-encoded as unencrypted DER/PKCS8 bytes."""
+        from cryptography.hazmat.backends import default_backend
+        from cryptography.hazmat.primitives import serialization
+
+        import flytekit
+
+        pk_string = flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb")
+        p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend())
+
+        pkb = p_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.PKCS8,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
+        return pkb
+
+    def get_connection(self, metadata: Metadata) -> snowflake.connector.SnowflakeConnection:
+        """Open a new key-pair authenticated connection from the settings stored in ``metadata``."""
+        return snowflake.connector.connect(
+            user=metadata.user,
+            account=metadata.account,
+            private_key=self.get_private_key(),
+            database=metadata.database,
+            schema=metadata.schema,
+            warehouse=metadata.warehouse,
+        )
+
+    async def async_create(
+        self,
+        context: grpc.ServicerContext,
+        output_prefix: str,
+        task_template: TaskTemplate,
+        inputs: Optional[LiteralMap] = None,
+    ) -> CreateTaskResponse:
+        """Submit the task's SQL statement without waiting for it and return the serialized ``Metadata``."""
+        params = None
+        if inputs:
+            ctx = FlyteContextManager.current_context()
+            python_interface_inputs = {
+                name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items()
+            }
+            native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)
+            logger.info(f"Create Snowflake params with inputs: {native_inputs}")
+            params = native_inputs
+
+        config = task_template.config
+
+        conn = snowflake.connector.connect(
+            user=config["user"],
+            account=config["account"],
+            private_key=self.get_private_key(),
+            database=config["database"],
+            schema=config["schema"],
+            warehouse=config["warehouse"],
+        )
+
+        # NOTE(review): neither the cursor nor the connection is closed here. Confirm whether closing the
+        # session would abort the in-flight asynchronous query before adding cleanup.
+        cs = conn.cursor()
+        cs.execute_async(task_template.sql.statement, params=params)
+
+        metadata = Metadata(
+            user=config["user"],
+            account=config["account"],
+            database=config["database"],
+            schema=config["schema"],
+            warehouse=config["warehouse"],
+            table=config["table"],
+            query_id=str(cs.sfqid),
+        )
+
+        return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8"))
+
+    async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse:
+        """Report the query's state; on success also return a StructuredDataset literal named ``results``."""
+        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
+        conn = self.get_connection(metadata)
+        try:
+            query_status = conn.get_query_status_throw_if_error(metadata.query_id)
+        except ProgrammingError as err:
+            logger.error(err.msg)
+            context.set_code(grpc.StatusCode.INTERNAL)
+            context.set_details(err.msg)
+            return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE))
+        finally:
+            # Each poll opens its own connection, so release it as soon as the status is known.
+            conn.close()
+        cur_state = convert_to_flyte_state(str(query_status.name))
+        res = None
+
+        if cur_state == SUCCEEDED:
+            ctx = FlyteContextManager.current_context()
+            output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.database}/{metadata.schema}/{metadata.warehouse}/{metadata.table}"
+            res = literals.LiteralMap(
+                {
+                    "results": TypeEngine.to_literal(
+                        ctx,
+                        StructuredDataset(uri=output_metadata),
+                        StructuredDataset,
+                        LiteralType(structured_dataset_type=StructuredDatasetType(format="")),
+                    )
+                }
+            ).to_flyte_idl()
+
+        return GetTaskResponse(resource=Resource(state=cur_state, outputs=res))
+
+    async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse:
+        """Cancel the query identified by ``resource_meta`` and always release the cursor and connection."""
+        metadata = Metadata(**json.loads(resource_meta.decode("utf-8")))
+        conn = self.get_connection(metadata)
+        try:
+            cs = conn.cursor()
+            try:
+                # Bind the query id instead of interpolating it into the SQL text; resource_meta is caller-supplied.
+                cs.execute("SELECT SYSTEM$CANCEL_QUERY(%s)", (metadata.query_id,))
+                cs.fetchall()
+            finally:
+                cs.close()
+        finally:
+            conn.close()
+        return DeleteTaskResponse()
+
+
+AgentRegistry.register(SnowflakeAgent())
diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
index 534acb978e..9ac9980a88 100644
--- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
+++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py
@@ -3,13 +3,16 @@
 
 from flytekit.configuration import SerializationSettings
 from flytekit.extend import SQLTask
+from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
 from flytekit.models import task as _task_model
-from flytekit.types.schema import FlyteSchema
+from flytekit.types.structured import StructuredDataset
 
+_USER_FIELD = "user"
 _ACCOUNT_FIELD = "account"
 _DATABASE_FIELD = "database"
 _SCHEMA_FIELD = "schema"
 _WAREHOUSE_FIELD = "warehouse"
+_TABLE_FIELD = "table"
 
 
 @dataclass
@@ -18,6 +21,8 @@ class SnowflakeConfig(object):
     SnowflakeConfig should be used to configure a Snowflake Task.
     """
 
+    # The user to query against
+    user: Optional[str] = None
     # The account to query against
     account: Optional[str] = None
     # The database to query against
@@ -26,9 +31,11 @@ class SnowflakeConfig(object):
     schema: Optional[str] = None
     # The optional warehouse to set for the given Snowflake query
     warehouse: Optional[str] = None
+    # The optional table to set for the given Snowflake query
+    table: Optional[str] = None
 
 
-class SnowflakeTask(SQLTask[SnowflakeConfig]):
+class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]):
     """
     This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output.
     """
@@ -42,7 +49,7 @@ def __init__(
         query_template: str,
         task_config: Optional[SnowflakeConfig] = None,
         inputs: Optional[Dict[str, Type]] = None,
-        output_schema_type: Optional[Type[FlyteSchema]] = None,
+        output_schema_type: Optional[Type[StructuredDataset]] = None,
         **kwargs,
     ):
         """
@@ -76,10 +83,12 @@ def __init__(
 
     def get_config(self, settings: SerializationSettings) -> Dict[str, str]:
         return {
+            _USER_FIELD: self.task_config.user,
             _ACCOUNT_FIELD: self.task_config.account,
             _DATABASE_FIELD: self.task_config.database,
             _SCHEMA_FIELD: self.task_config.schema,
             _WAREHOUSE_FIELD: self.task_config.warehouse,
+            _TABLE_FIELD: self.task_config.table,
         }
 
     def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py
index 219468b380..527daa2486 100644
--- a/plugins/flytekit-snowflake/setup.py
+++ b/plugins/flytekit-snowflake/setup.py
@@ -4,7 +4,7 @@
 
 microlib_name = f"flytekitplugins-{PLUGIN_NAME}"
 
-plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"]
+plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"]
 
 __version__ = "0.0.0+develop"
 
@@ -32,4 +32,5 @@
         "Topic :: Software Development :: Libraries",
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
+    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
 )
diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py
new file mode 100644
index 0000000000..25e2ad118c
--- /dev/null
+++ b/plugins/flytekit-snowflake/tests/test_agent.py
@@ -0,0 +1,121 @@
+import json
+from dataclasses import asdict
+from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock
+
+import grpc
+import pytest
+from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse
+from flytekitplugins.snowflake.agent import Metadata
+
+import flytekit.models.interface as interface_models
+from flytekit.extend.backend.base_agent import AgentRegistry
+from flytekit.interfaces.cli_identifiers import Identifier
+from flytekit.models import literals, task, types
+from flytekit.models.core.identifier import ResourceType
+from flytekit.models.task import Sql, TaskTemplate
+
+
+@mock.patch("flytekitplugins.snowflake.agent.SnowflakeAgent.get_private_key", return_value="pb")
+@mock.patch("snowflake.connector.connect")
+@pytest.mark.asyncio
+async def test_snowflake_agent(mock_conn, mock_get_private_key):
+    query_status_mock = MagicMock()
+    query_status_mock.name = "SUCCEEDED"
+
+    # Configure the mock connection to return the mock status object
+    mock_conn_instance = mock_conn.return_value
+    mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock
+
+    ctx = MagicMock(spec=grpc.ServicerContext)
+    agent = AgentRegistry.get_agent("snowflake")
+
+    task_id = Identifier(
+        resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
+    )
+
+    task_metadata = task.TaskMetadata(
+        True,
+        task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
+        timedelta(days=1),
+        literals.RetryStrategy(3),
+        True,
+        "0.1.1b0",
+        "This is deprecated!",
+        True,
+        "A",
+    )
+
+    task_config = {
+        "user": "dummy_user",
+        "account": "dummy_account",
+        "database": "dummy_database",
+        "schema": "dummy_schema",
+        "warehouse": "dummy_warehouse",
+        "table": "dummy_table",
+    }
+
+    int_type = types.LiteralType(types.SimpleType.INTEGER)
+    interfaces = interface_models.TypedInterface(
+        {
+            "a": interface_models.Variable(int_type, "description1"),
+            "b": interface_models.Variable(int_type, "description2"),
+        },
+        {},
+    )
+    task_inputs = literals.LiteralMap(
+        {
+            "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
+            "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
+        },
+    )
+
+    dummy_template = TaskTemplate(
+        id=task_id,
+        custom=None,
+        config=task_config,
+        metadata=task_metadata,
+        interface=interfaces,
+        type="snowflake",
+        sql=Sql("SELECT 1"),
+    )
+
+    metadata = Metadata(
+        user="dummy_user",
+        account="dummy_account",
+        table="dummy_table",
+        database="dummy_database",
+        schema="dummy_schema",
+        warehouse="dummy_warehouse",
+        query_id="dummy_query_id",
+    )
+
+    res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs)
+    metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id
+    metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8")
+    assert res.resource_meta == metadata_bytes
+
+    res = await agent.async_get(ctx, metadata_bytes)
+    assert res.resource.state == SUCCEEDED
+    assert (
+        res.resource.outputs.literals["results"].scalar.structured_dataset.uri
+        == "snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table"
+    )
+    # Polling opens its own connection and must close it again
+    mock_conn_instance.close.assert_called_once()
+    mock_conn_instance.close.reset_mock()
+
+    delete_response = await agent.async_delete(ctx, metadata_bytes)
+
+    # Assert the response
+    assert isinstance(delete_response, DeleteTaskResponse)
+
+    # Verify that the expected methods were called on the mock cursor
+    mock_cursor = mock_conn_instance.cursor.return_value
+    mock_cursor.execute.assert_called_once_with("SELECT SYSTEM$CANCEL_QUERY(%s)", (metadata.query_id,))
+    mock_cursor.fetchall.assert_called_once()
+
+    # Verify that the connection was closed
+    mock_cursor.close.assert_called_once()
+    mock_conn_instance.close.assert_called_once()