From 602bf8c43f7afb03672ca6924c7ca43d04172709 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Fri, 18 Aug 2023 01:18:03 +0000 Subject: [PATCH 01/17] add AsyncAgentExecutorMixin in task Signed-off-by: hhcs9527 --- .../flytekitplugins/snowflake/task.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 534acb978e..4a8a1496e9 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,9 +3,12 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model -from flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset +_USER_FIELD = "user" +_PASSWORD_FIELD = "password" _ACCOUNT_FIELD = "account" _DATABASE_FIELD = "database" _SCHEMA_FIELD = "schema" @@ -18,6 +21,10 @@ class SnowflakeConfig(object): SnowflakeConfig should be used to configure a Snowflake Task. """ + # The user to query against + user: Optional[str] = None + # The password to query against + password: Optional[str] = None # The account to query against account: Optional[str] = None # The database to query against @@ -28,7 +35,7 @@ class SnowflakeConfig(object): warehouse: Optional[str] = None -class SnowflakeTask(SQLTask[SnowflakeConfig]): +class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): """ This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output. """ @@ -42,7 +49,7 @@ def __init__( query_template: str, task_config: Optional[SnowflakeConfig] = None, inputs: Optional[Dict[str, Type]] = None, - output_schema_type: Optional[Type[FlyteSchema]] = None, + output_schema_type: Optional[Type[StructuredDataset]] = None, **kwargs, ): """ @@ -76,6 +83,8 @@ def __init__( def get_config(self, settings: SerializationSettings) -> Dict[str, str]: return { + _USER_FIELD: self.task_config.user, + _PASSWORD_FIELD: self.task_config.password, _ACCOUNT_FIELD: self.task_config.account, _DATABASE_FIELD: self.task_config.database, _SCHEMA_FIELD: self.task_config.schema, From fa1c21ce0a0394ef307b0a4cc63d21444b433c80 Mon Sep 17 00:00:00 2001 From: hhcs9527 Date: Fri, 18 Aug 2023 01:18:48 +0000 Subject: [PATCH 02/17] support snowflake agent Signed-off-by: hhcs9527 --- flytekit/extend/backend/base_agent.py | 2 +- .../flytekitplugins/snowflake/__init__.py | 1 + .../flytekitplugins/snowflake/agent.py | 120 ++++++++++++++++++ 3 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 1bf34c029a..7bfcf4b6b8 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -138,7 +138,7 @@ def convert_to_flyte_state(state: str) -> State: state = state.lower() if state in ["failed"]: return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: + elif state in ["done", "succeeded", "success"]: return SUCCEEDED elif state in ["running"]: return RUNNING diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 2875e56bdf..929bed9486 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -12,3 +12,4 @@ """ from .task import SnowflakeConfig, SnowflakeTask +from .agent import SnowflakeAgent \ No newline at end of file diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py new file mode 100644 index 0000000000..86c6de2925 --- /dev/null +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -0,0 +1,120 @@ +import datetime +import json +from dataclasses import asdict, dataclass +from typing import Dict, Optional + +import snowflake.connector +from snowflake.connector import ProgrammingError + +import grpc +from flyteidl.admin.agent_pb2 import ( + PERMANENT_FAILURE, + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.models.types import LiteralType, StructuredDatasetType + +pythonTypeToBigQueryType: Dict[type, str] = { + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes + list: "ARRAY", + bool: "BOOL", + bytes: "BYTES", + datetime.datetime: "DATETIME", + float: "FLOAT64", + int: "INT64", + str: "STRING", +} + + +@dataclass +class Metadata: + query_id: int + + +class SnowflakeAgent(AgentBase): + def __init__(self): + super().__init__(task_type="snowflake") + + def create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + params = None + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + logger.info(f"Create Snowflake params with inputs: {native_inputs}") + params = native_inputs + + config = task_template.config + self.conn = snowflake.connector.connect( + user=config["user"], + password=config["password"], + account=config["account"], + database=config["database"], + schema=config["schema"], + warehouse=config["warehouse"] + ) + + self.cs = self.conn.cursor() + self.cs.execute_async(task_template.sql.statement, params=params) + metadata = Metadata(query_id=self.cs.sfqid) + + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + try: + query_status = self.conn.get_query_status_throw_if_error(metadata.query_id) + except ProgrammingError as err: + logger.error(err.msg) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(err.msg) + return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) + cur_state = convert_to_flyte_state(str(query_status.name)) + res = None + + if cur_state == SUCCEEDED: + ctx = FlyteContextManager.current_context() + self.cs.get_results_from_sfqid(metadata.query_id) + res = literals.LiteralMap( + { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(dataframe=self.cs.fetch_pandas_all()), + StructuredDataset, + LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } + ).to_flyte_idl() + print(res) + + return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + try: + self.cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + self.cs.fetchall() + finally: + self.cs.close() + return DeleteTaskResponse() + + +AgentRegistry.register(SnowflakeAgent()) From 40d504e789f06673e61e6e1860cecf42e663d5f0 Mon Sep 17 00:00:00 2001 From: HH Date: Tue, 29 Aug 2023 20:41:20 +0800 Subject: [PATCH 03/17] Move the snowflake connector setup to the right place. Signed-off-by: HH --- .../flytekit-snowflake/flytekitplugins/snowflake/__init__.py | 4 ++-- plugins/flytekit-snowflake/setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 929bed9486..4720159219 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -11,5 +11,5 @@ SnowflakeTask """ -from .task import SnowflakeConfig, SnowflakeTask -from .agent import SnowflakeAgent \ No newline at end of file +from .agent import SnowflakeAgent +from .task import SnowflakeConfig, SnowflakeTask \ No newline at end of file diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index 219468b380..8be87abfb5 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"] __version__ = "0.0.0+develop" From 4b69fdf2f899ce8dabbbdbc62544df10cc744eb6 Mon Sep 17 00:00:00 2001 From: HH Date: Tue, 29 Aug 2023 20:43:12 +0800 Subject: [PATCH 04/17] replace password auth with key-pair auth Signed-off-by: HH --- .../flytekitplugins/snowflake/agent.py | 113 ++++++++++++------ .../flytekitplugins/snowflake/task.py | 12 +- 2 files changed, 81 insertions(+), 44 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 86c6de2925..07c45021ca 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -1,12 +1,9 @@ -import datetime import json from dataclasses import asdict, dataclass -from typing import Dict, Optional - -import snowflake.connector -from snowflake.connector import ProgrammingError +from typing import Optional import grpc +import snowflake.connector from flyteidl.admin.agent_pb2 import ( PERMANENT_FAILURE, SUCCEEDED, @@ -15,6 +12,7 @@ GetTaskResponse, Resource, ) +from snowflake.connector import ProgrammingError from flytekit import FlyteContextManager, StructuredDataset, logger from flytekit.core.type_engine import TypeEngine @@ -24,28 +22,54 @@ from flytekit.models.task import TaskTemplate from flytekit.models.types import LiteralType, StructuredDatasetType -pythonTypeToBigQueryType: Dict[type, str] = { - # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_sizes - list: "ARRAY", - bool: "BOOL", - bytes: "BYTES", - datetime.datetime: "DATETIME", - float: "FLOAT64", - int: "INT64", - str: "STRING", -} +TASK_TYPE = "snowflake" @dataclass class Metadata: - query_id: int + user: str + account: str + database: str + schema: str + warehouse: str + table: str + query_id: str class SnowflakeAgent(AgentBase): def __init__(self): - super().__init__(task_type="snowflake") + super().__init__(task_type=TASK_TYPE) + + def get_private_key(self): + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + import flytekit + + pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "rsa_key.p8") - def create( + with open(pk_path, "rb") as key: + p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pkb + + def get_connection(self, metadata: Metadata) -> snowflake.connector: + return snowflake.connector.connect( + user=metadata.user, + account=metadata.account, + private_key=self.get_private_key(), + database=metadata.database, + schema=metadata.schema, + warehouse=metadata.warehouse, + ) + + async def async_create( self, context: grpc.ServicerContext, output_prefix: str, @@ -62,26 +86,37 @@ def create( logger.info(f"Create Snowflake params with inputs: {native_inputs}") params = native_inputs - config = task_template.config - self.conn = snowflake.connector.connect( - user=config["user"], - password=config["password"], - account=config["account"], - database=config["database"], - schema=config["schema"], - warehouse=config["warehouse"] + custom = task_template.custom + + conn = snowflake.connector.connect( + user=custom["user"], + account=custom["account"], + private_key=self.get_private_key(), + database=custom["database"], + schema=custom["schema"], + warehouse=custom["warehouse"], ) - self.cs = self.conn.cursor() - self.cs.execute_async(task_template.sql.statement, params=params) - metadata = Metadata(query_id=self.cs.sfqid) + cs = conn.cursor() + cs.execute_async(task_template.sql.statement, params=params) + + metadata = Metadata( + user=custom["user"], + account=custom["account"], + database=custom["database"], + schema=custom["schema"], + warehouse=custom["warehouse"], + table=custom["table"], + query_id=str(cs.sfqid), + ) return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + conn = self.get_connection(metadata) try: - query_status = self.conn.get_query_status_throw_if_error(metadata.query_id) + query_status = conn.get_query_status_throw_if_error(metadata.query_id) except ProgrammingError as err: logger.error(err.msg) context.set_code(grpc.StatusCode.INTERNAL) @@ -92,28 +127,30 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes if cur_state == SUCCEEDED: ctx = FlyteContextManager.current_context() - self.cs.get_results_from_sfqid(metadata.query_id) + output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.database}/{metadata.schema}/{metadata.warehouse}/{metadata.table}" res = literals.LiteralMap( { "results": TypeEngine.to_literal( ctx, - StructuredDataset(dataframe=self.cs.fetch_pandas_all()), + StructuredDataset(uri=output_metadata), StructuredDataset, LiteralType(structured_dataset_type=StructuredDatasetType(format="")), ) } ).to_flyte_idl() - print(res) return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + conn = self.get_connection(metadata) + cs = conn.cursor() try: - self.cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") - self.cs.fetchall() + cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + cs.fetchall() finally: - self.cs.close() + cs.close() + conn.close() return DeleteTaskResponse() diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 4a8a1496e9..425d03beb5 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -8,11 +8,11 @@ from flytekit.types.structured import StructuredDataset _USER_FIELD = "user" -_PASSWORD_FIELD = "password" _ACCOUNT_FIELD = "account" _DATABASE_FIELD = "database" _SCHEMA_FIELD = "schema" _WAREHOUSE_FIELD = "warehouse" +_TABLE_FIELD = "table" @dataclass @@ -23,9 +23,7 @@ class SnowflakeConfig(object): # The user to query against user: Optional[str] = None - # The password to query against - password: Optional[str] = None - # The account to query against + # The account to query againstk account: Optional[str] = None # The database to query against database: Optional[str] = None @@ -33,6 +31,8 @@ class SnowflakeConfig(object): schema: Optional[str] = None # The optional warehouse to set for the given Snowflake query warehouse: Optional[str] = None + # The optional table to set for the given Snowflake query + table: Optional[str] = None class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): @@ -81,14 +81,14 @@ def __init__( ) self._output_schema_type = output_schema_type - def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + def get_custom(self, settings: SerializationSettings) -> Dict[str, str]: return { _USER_FIELD: self.task_config.user, - _PASSWORD_FIELD: self.task_config.password, _ACCOUNT_FIELD: self.task_config.account, _DATABASE_FIELD: self.task_config.database, _SCHEMA_FIELD: self.task_config.schema, _WAREHOUSE_FIELD: self.task_config.warehouse, + _TABLE_FIELD: self.task_config.table, } def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: From 5d718b5c1ac200e0ddbe476ad656d7737a6bc06d Mon Sep 17 00:00:00 2001 From: HH Date: Tue, 29 Aug 2023 20:43:54 +0800 Subject: [PATCH 05/17] Add the snowflake agent unit-test and fix snowflake-task unit-test Signed-off-by: HH --- .../flytekit-snowflake/tests/test_agent.py | 121 ++++++++++++++++++ .../tests/test_snowflake.py | 8 +- 2 files changed, 125 insertions(+), 4 deletions(-) create mode 100644 plugins/flytekit-snowflake/tests/test_agent.py diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py new file mode 100644 index 0000000000..392fa61067 --- /dev/null +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -0,0 +1,121 @@ +import json +import re +from dataclasses import asdict +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flytekitplugins.snowflake.agent import Metadata +from flytekitplugins.snowflake.task import SnowflakeConfig + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Sql, TaskTemplate + + +@mock.patch("snowflake.connector.connect") +@pytest.mark.asyncio +async def test_snowflake_agent(mock_conn): + query_status_mock = MagicMock() + query_status_mock.name = "SUCCEEDED" + + # Configure the mock connection to return the mock status object + mock_conn_instance = mock_conn.return_value + mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock + + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent(ctx, "snowflake") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + + task_config = SnowflakeConfig( + user="dummy_user", + account="dummy_account", + database="dummy_database", + schema="dummy_schema", + warehouse="dummy_warehouse", + table="dummy_table", + ) + + task_config = { + "user" : "dummy_user", + "account" : "dummy_account", + "database" : "dummy_database", + "schema" : "dummy_schema", + "warehouse" : "dummy_warehouse", + "table" : "dummy_table", + } + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + interface=interfaces, + type="snowflake", + sql=Sql("SELECT 1"), + ) + + metadata = Metadata(user="dummy_user",account="dummy_account",table="dummy_table",database="dummy_database",schema="dummy_schema",warehouse="dummy_warehouse",query_id="dummy_query_id") + + res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs) + metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id + metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") + assert res.resource_meta == metadata_bytes + + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs.literals["results"].scalar.structured_dataset.uri + == "snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table" + ) + + delete_response = await agent.async_delete(ctx, metadata_bytes) + + # Assert the response + assert isinstance(delete_response, DeleteTaskResponse) + + # Verify that the expected methods were called on the mock cursor + mock_cursor = mock_conn_instance.cursor.return_value + mock_cursor.fetchall.assert_called_once() + + mock_cursor.execute.assert_called_once_with(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + mock_cursor.fetchall.assert_called_once() + + # Verify that the connection was closed + mock_cursor.close.assert_called_once() + mock_conn_instance.close.assert_called_once() \ No newline at end of file diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index 672f4a19ad..127a9b762e 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -46,10 +46,10 @@ def my_wf(ds: str) -> FlyteSchema: assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement assert "insert overwrite directory" in task_spec.template.sql.statement assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI - assert "snowflake" == task_spec.template.config["account"] - assert "my_warehouse" == task_spec.template.config["warehouse"] - assert "my_schema" == task_spec.template.config["schema"] - assert "my_database" == task_spec.template.config["database"] + assert "snowflake" == task_spec.template.custom["account"] + assert "my_warehouse" == task_spec.template.custom["warehouse"] + assert "my_schema" == task_spec.template.custom["schema"] + assert "my_database" == task_spec.template.custom["database"] assert len(task_spec.template.interface.inputs) == 1 assert len(task_spec.template.interface.outputs) == 1 From 920263aea0703861761cfc77051aeafaa8b5a3d4 Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 30 Aug 2023 17:32:43 +0800 Subject: [PATCH 06/17] roll back the custom to config Signed-off-by: HH --- .../flytekitplugins/snowflake/__init__.py | 2 +- .../flytekitplugins/snowflake/agent.py | 24 +++++++++---------- .../flytekitplugins/snowflake/task.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 4720159219..9c16d5398e 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -12,4 +12,4 @@ """ from .agent import SnowflakeAgent -from .task import SnowflakeConfig, SnowflakeTask \ No newline at end of file +from .task import SnowflakeConfig, SnowflakeTask diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 07c45021ca..168b28b744 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -86,27 +86,27 @@ async def async_create( logger.info(f"Create Snowflake params with inputs: {native_inputs}") params = native_inputs - custom = task_template.custom + config = task_template.config conn = snowflake.connector.connect( - user=custom["user"], - account=custom["account"], + user=config.user, + account=config.account, private_key=self.get_private_key(), - database=custom["database"], - schema=custom["schema"], - warehouse=custom["warehouse"], + database=config.database, + schema=config.schema, + warehouse=config.warehouse, ) cs = conn.cursor() cs.execute_async(task_template.sql.statement, params=params) metadata = Metadata( - user=custom["user"], - account=custom["account"], - database=custom["database"], - schema=custom["schema"], - warehouse=custom["warehouse"], - table=custom["table"], + user=config.user, + account=config.account, + database=config.database, + schema=config.schema, + warehouse=config.warehouse, + table=config.table, query_id=str(cs.sfqid), ) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 425d03beb5..4517b15a9b 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -81,7 +81,7 @@ def __init__( ) self._output_schema_type = output_schema_type - def get_custom(self, settings: SerializationSettings) -> Dict[str, str]: + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: return { _USER_FIELD: self.task_config.user, _ACCOUNT_FIELD: self.task_config.account, From 6d867fa4865b590bded8da01ee1cdca29957be9c Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 30 Aug 2023 17:33:08 +0800 Subject: [PATCH 07/17] fix the unit-test Signed-off-by: HH --- .../flytekit-snowflake/tests/test_agent.py | 25 +++++++++---------- .../tests/test_snowflake.py | 8 +++--- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 392fa61067..139a4580ee 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -1,5 +1,4 @@ import json -import re from dataclasses import asdict from datetime import timedelta from unittest import mock @@ -57,15 +56,6 @@ async def test_snowflake_agent(mock_conn): table="dummy_table", ) - task_config = { - "user" : "dummy_user", - "account" : "dummy_account", - "database" : "dummy_database", - "schema" : "dummy_schema", - "warehouse" : "dummy_warehouse", - "table" : "dummy_table", - } - int_type = types.LiteralType(types.SimpleType.INTEGER) interfaces = interface_models.TypedInterface( { @@ -83,14 +73,23 @@ async def test_snowflake_agent(mock_conn): dummy_template = TaskTemplate( id=task_id, - custom=task_config, + custom=None, + config=task_config, metadata=task_metadata, interface=interfaces, type="snowflake", sql=Sql("SELECT 1"), ) - metadata = Metadata(user="dummy_user",account="dummy_account",table="dummy_table",database="dummy_database",schema="dummy_schema",warehouse="dummy_warehouse",query_id="dummy_query_id") + metadata = Metadata( + user="dummy_user", + account="dummy_account", + table="dummy_table", + database="dummy_database", + schema="dummy_schema", + warehouse="dummy_warehouse", + query_id="dummy_query_id", + ) res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs) metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id @@ -118,4 +117,4 @@ async def test_snowflake_agent(mock_conn): # Verify that the connection was closed mock_cursor.close.assert_called_once() - mock_conn_instance.close.assert_called_once() \ No newline at end of file + mock_conn_instance.close.assert_called_once() diff --git a/plugins/flytekit-snowflake/tests/test_snowflake.py b/plugins/flytekit-snowflake/tests/test_snowflake.py index 127a9b762e..672f4a19ad 100644 --- a/plugins/flytekit-snowflake/tests/test_snowflake.py +++ b/plugins/flytekit-snowflake/tests/test_snowflake.py @@ -46,10 +46,10 @@ def my_wf(ds: str) -> FlyteSchema: assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement assert "insert overwrite directory" in task_spec.template.sql.statement assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI - assert "snowflake" == task_spec.template.custom["account"] - assert "my_warehouse" == task_spec.template.custom["warehouse"] - assert "my_schema" == task_spec.template.custom["schema"] - assert "my_database" == task_spec.template.custom["database"] + assert "snowflake" == task_spec.template.config["account"] + assert "my_warehouse" == task_spec.template.config["warehouse"] + assert "my_schema" == task_spec.template.config["schema"] + assert "my_database" == task_spec.template.config["database"] assert len(task_spec.template.interface.inputs) == 1 assert len(task_spec.template.interface.outputs) == 1 From 8ecffcbc85bb32a446db624a36a61b283775bccc Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 09:57:10 +0800 Subject: [PATCH 08/17] add dev-requirement.in and re-run pip compile Signed-off-by: HH --- .../flytekit-snowflake/dev-requirements.in | 1 + .../flytekit-snowflake/dev-requirements.txt | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) create mode 100644 plugins/flytekit-snowflake/dev-requirements.in create mode 100644 plugins/flytekit-snowflake/dev-requirements.txt diff --git a/plugins/flytekit-snowflake/dev-requirements.in b/plugins/flytekit-snowflake/dev-requirements.in new file mode 100644 index 0000000000..990d54ff0c --- /dev/null +++ b/plugins/flytekit-snowflake/dev-requirements.in @@ -0,0 +1 @@ +pytest-asyncio \ No newline at end of file diff --git a/plugins/flytekit-snowflake/dev-requirements.txt b/plugins/flytekit-snowflake/dev-requirements.txt new file mode 100644 index 0000000000..99d3f5e4e9 --- /dev/null +++ b/plugins/flytekit-snowflake/dev-requirements.txt @@ -0,0 +1,20 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +exceptiongroup==1.1.3 + # via pytest +iniconfig==2.0.0 + # via pytest +packaging==23.1 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.0 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest From 55fe7e6d67241cb5d46fd017daa16faf326d4f4e Mon Sep 17 00:00:00 2001 From: HH Date: Thu, 31 Aug 2023 17:35:48 +0800 Subject: [PATCH 09/17] add entry_points in setup for loading the modules Signed-off-by: HH --- plugins/flytekit-snowflake/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index 8be87abfb5..527daa2486 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -32,4 +32,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) From e9e0741ba744a97f3671b4777427e23a245cec62 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 1 Sep 2023 12:09:46 +0800 Subject: [PATCH 10/17] mock private key for snowflake-agent test in git Action Signed-off-by: HH --- plugins/flytekit-snowflake/tests/test_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 139a4580ee..030d6a26e9 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -18,9 +18,10 @@ from flytekit.models.task import Sql, TaskTemplate +@mock.patch("flytekitplugins.snowflake.agent.SnowflakeAgent.get_private_key", return_value="pb") @mock.patch("snowflake.connector.connect") @pytest.mark.asyncio -async def test_snowflake_agent(mock_conn): +async def test_snowflake_agent(mock_conn, mock_get_private_key): query_status_mock = MagicMock() query_status_mock.name = "SUCCEEDED" From fc6de28d4343fe60d840e507c609ebb4fe69e80e Mon Sep 17 00:00:00 2001 From: HH Date: Mon, 4 Sep 2023 10:48:02 +0800 Subject: [PATCH 11/17] fix tasktemplate.config issue in agent.py and its unit-test Signed-off-by: HH --- .../flytekitplugins/snowflake/agent.py | 22 +++++++++---------- .../flytekit-snowflake/tests/test_agent.py | 16 +++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 168b28b744..ede4f76dff 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -89,24 +89,24 @@ async def async_create( config = task_template.config conn = snowflake.connector.connect( - user=config.user, - account=config.account, + user=config["user"], + account=config["account"], private_key=self.get_private_key(), - database=config.database, - schema=config.schema, - warehouse=config.warehouse, + database=config["database"], + schema=config["schema"], + warehouse=config["warehouse"], ) cs = conn.cursor() cs.execute_async(task_template.sql.statement, params=params) metadata = Metadata( - user=config.user, - account=config.account, - database=config.database, - schema=config.schema, - warehouse=config.warehouse, - table=config.table, + user=config["user"], + account=config["account"], + database=config["database"], + schema=config["schema"], + warehouse=config["warehouse"], + table=config["table"], query_id=str(cs.sfqid), ) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 030d6a26e9..2a84976e07 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -48,14 +48,14 @@ async def test_snowflake_agent(mock_conn, mock_get_private_key): "A", ) - task_config = SnowflakeConfig( - user="dummy_user", - account="dummy_account", - database="dummy_database", - schema="dummy_schema", - warehouse="dummy_warehouse", - table="dummy_table", - ) + task_config = { + "user":"dummy_user", + "account":"dummy_account", + "database":"dummy_database", + "schema":"dummy_schema", + "warehouse":"dummy_warehouse", + "table":"dummy_table" + } int_type = types.LiteralType(types.SimpleType.INTEGER) interfaces = interface_models.TypedInterface( From 50f012b49a40232d42fded766a0298c02db8ed5f Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 6 Sep 2023 16:19:58 +0800 Subject: [PATCH 12/17] fix lint in test-agent Signed-off-by: HH --- plugins/flytekit-snowflake/tests/test_agent.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index 2a84976e07..ebd0b425a5 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -8,7 +8,6 @@ import pytest from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse from flytekitplugins.snowflake.agent import Metadata -from flytekitplugins.snowflake.task import SnowflakeConfig import flytekit.models.interface as interface_models from flytekit.extend.backend.base_agent import AgentRegistry @@ -49,12 +48,12 @@ async def test_snowflake_agent(mock_conn, mock_get_private_key): ) task_config = { - "user":"dummy_user", - "account":"dummy_account", - "database":"dummy_database", - "schema":"dummy_schema", - "warehouse":"dummy_warehouse", - "table":"dummy_table" + "user": "dummy_user", + "account": "dummy_account", + "database": "dummy_database", + "schema": "dummy_schema", + "warehouse": "dummy_warehouse", + "table": "dummy_table", } int_type = types.LiteralType(types.SimpleType.INTEGER) From 5a103d29218801b4daa5af1889eeae19cbf59a02 Mon Sep 17 00:00:00 2001 From: HH Date: Fri, 8 Sep 2023 16:23:27 +0800 Subject: [PATCH 13/17] merge latest flytekit Signed-off-by: HH --- plugins/flytekit-snowflake/dev-requirements.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-snowflake/dev-requirements.in b/plugins/flytekit-snowflake/dev-requirements.in index 990d54ff0c..2d73dba5b4 100644 --- a/plugins/flytekit-snowflake/dev-requirements.in +++ b/plugins/flytekit-snowflake/dev-requirements.in @@ -1 +1 @@ -pytest-asyncio \ No newline at end of file +pytest-asyncio From b8cc0f58de964497480c08bb1b76a94969857b6f Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 9 Sep 2023 15:22:47 +0800 Subject: [PATCH 14/17] fix CI error with remove ctx in get_agent function in test Signed-off-by: HH --- .../flytekit-snowflake/flytekitplugins/snowflake/__init__.py | 1 + plugins/flytekit-snowflake/tests/test_agent.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 9c16d5398e..336aa891da 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -13,3 +13,4 @@ from .agent import SnowflakeAgent from .task import SnowflakeConfig, SnowflakeTask +from .agent import SnowflakeAgent \ No newline at end of file diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py index ebd0b425a5..25e2ad118c 100644 --- a/plugins/flytekit-snowflake/tests/test_agent.py +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -29,7 +29,7 @@ async def test_snowflake_agent(mock_conn, mock_get_private_key): mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "snowflake") + agent = AgentRegistry.get_agent("snowflake") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" From 448cbd323ec90aaaa5635bf4c9d8e1ea4c86ba45 Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 9 Sep 2023 17:01:39 +0800 Subject: [PATCH 15/17] replace replace private_key instead of rsa_key.p8 Signed-off-by: HH --- plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py | 3 ++- plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index ede4f76dff..73de306f22 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -46,7 +46,8 @@ def get_private_key(self): import flytekit - pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "rsa_key.p8") + pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "private_key") + # print(pk_path) with open(pk_path, "rb") as key: p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 4517b15a9b..9ac9980a88 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -23,7 +23,7 @@ class SnowflakeConfig(object): # The user to query against user: Optional[str] = None - # The account to query againstk + # The account to query against account: Optional[str] = None # The database to query against database: Optional[str] = None From 6d8931613d0a71ddbe9082c99c3e8039b6a840b9 Mon Sep 17 00:00:00 2001 From: HH Date: Sat, 9 Sep 2023 23:38:37 +0800 Subject: [PATCH 16/17] lint fix Signed-off-by: HH --- plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py | 1 - plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py | 1 - 2 files changed, 2 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 336aa891da..9c16d5398e 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -13,4 +13,3 @@ from .agent import SnowflakeAgent from .task import SnowflakeConfig, SnowflakeTask -from .agent import SnowflakeAgent \ No newline at end of file diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 73de306f22..52c70ad29c 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -47,7 +47,6 @@ def get_private_key(self): import flytekit pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "private_key") - # print(pk_path) with open(pk_path, "rb") as key: p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) From 5851207ddefa88c85967fea81594736932ec079d Mon Sep 17 00:00:00 2001 From: HH Date: Wed, 13 Sep 2023 11:49:18 +0800 Subject: [PATCH 17/17] replace get_secrets_file with get Signed-off-by: HH --- .../flytekit-snowflake/flytekitplugins/snowflake/agent.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py index 52c70ad29c..c4176228ea 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -46,10 +46,8 @@ def get_private_key(self): import flytekit - pk_path = flytekit.current_context().secrets.get_secrets_file(TASK_TYPE, "private_key") - - with open(pk_path, "rb") as key: - p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend()) + pk_string = flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb") + p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) pkb = p_key.private_bytes( encoding=serialization.Encoding.DER,