Skip to content

Commit

Permalink
Add the snowflake agent unit-test and fix snowflake-task unit-test
Browse files Browse the repository at this point in the history
Signed-off-by: HH <hhcs9527@gmail.com>
  • Loading branch information
hhcs9527 committed Aug 29, 2023
1 parent 4b69fdf commit 5d718b5
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 4 deletions.
121 changes: 121 additions & 0 deletions plugins/flytekit-snowflake/tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json
import re
from dataclasses import asdict
from datetime import timedelta
from unittest import mock
from unittest.mock import MagicMock

import grpc
import pytest
from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse
from flytekitplugins.snowflake.agent import Metadata
from flytekitplugins.snowflake.task import SnowflakeConfig

import flytekit.models.interface as interface_models
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.interfaces.cli_identifiers import Identifier
from flytekit.models import literals, task, types
from flytekit.models.core.identifier import ResourceType
from flytekit.models.task import Sql, TaskTemplate


@mock.patch("snowflake.connector.connect")
@pytest.mark.asyncio
async def test_snowflake_agent(mock_conn):
query_status_mock = MagicMock()
query_status_mock.name = "SUCCEEDED"

# Configure the mock connection to return the mock status object
mock_conn_instance = mock_conn.return_value
mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock

ctx = MagicMock(spec=grpc.ServicerContext)
agent = AgentRegistry.get_agent(ctx, "snowflake")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
)

task_metadata = task.TaskMetadata(
True,
task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
timedelta(days=1),
literals.RetryStrategy(3),
True,
"0.1.1b0",
"This is deprecated!",
True,
"A",
)

task_config = SnowflakeConfig(
user="dummy_user",
account="dummy_account",
database="dummy_database",
schema="dummy_schema",
warehouse="dummy_warehouse",
table="dummy_table",
)

task_config = {
"user" : "dummy_user",
"account" : "dummy_account",
"database" : "dummy_database",
"schema" : "dummy_schema",
"warehouse" : "dummy_warehouse",
"table" : "dummy_table",
}

int_type = types.LiteralType(types.SimpleType.INTEGER)
interfaces = interface_models.TypedInterface(
{
"a": interface_models.Variable(int_type, "description1"),
"b": interface_models.Variable(int_type, "description2"),
},
{},
)
task_inputs = literals.LiteralMap(
{
"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))),
},
)

dummy_template = TaskTemplate(
id=task_id,
custom=task_config,
metadata=task_metadata,
interface=interfaces,
type="snowflake",
sql=Sql("SELECT 1"),
)

metadata = Metadata(user="dummy_user",account="dummy_account",table="dummy_table",database="dummy_database",schema="dummy_schema",warehouse="dummy_warehouse",query_id="dummy_query_id")

res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs)
metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id
metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8")
assert res.resource_meta == metadata_bytes

res = await agent.async_get(ctx, metadata_bytes)
assert res.resource.state == SUCCEEDED
assert (
res.resource.outputs.literals["results"].scalar.structured_dataset.uri
== "snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table"
)

delete_response = await agent.async_delete(ctx, metadata_bytes)

# Assert the response
assert isinstance(delete_response, DeleteTaskResponse)

# Verify that the expected methods were called on the mock cursor
mock_cursor = mock_conn_instance.cursor.return_value
mock_cursor.fetchall.assert_called_once()

mock_cursor.execute.assert_called_once_with(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')")
mock_cursor.fetchall.assert_called_once()

# Verify that the connection was closed
mock_cursor.close.assert_called_once()
mock_conn_instance.close.assert_called_once()
8 changes: 4 additions & 4 deletions plugins/flytekit-snowflake/tests/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def my_wf(ds: str) -> FlyteSchema:
assert "{{ .rawOutputDataPrefix" in task_spec.template.sql.statement
assert "insert overwrite directory" in task_spec.template.sql.statement
assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
assert "snowflake" == task_spec.template.config["account"]
assert "my_warehouse" == task_spec.template.config["warehouse"]
assert "my_schema" == task_spec.template.config["schema"]
assert "my_database" == task_spec.template.config["database"]
assert "snowflake" == task_spec.template.custom["account"]
assert "my_warehouse" == task_spec.template.custom["warehouse"]
assert "my_schema" == task_spec.template.custom["schema"]
assert "my_database" == task_spec.template.custom["database"]
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

Expand Down

0 comments on commit 5d718b5

Please sign in to comment.