From 8365c530294150ed23958b74d3962ab40bc5e431 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 7 Sep 2023 10:34:10 +0800 Subject: [PATCH 01/64] drafte chatgpt agent Signed-off-by: Future Outlier --- .../flytekitplugins/chatgpt/agent.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py new file mode 100644 index 0000000000..ac738ec0e3 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py @@ -0,0 +1,54 @@ +import json +import pickle +import typing +from dataclasses import dataclass +from typing import Optional + +import aiohttp +import grpc +from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource + +import flytekit +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +class ChatGPTAgent(AgentBase): + def __init__(self): + super().__init__(task_type="chatgpt") + + + async def async_do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> GetTaskResponse: + custom = task_template.custom + chatgpt_job = custom["chatgptConf"] + openai_organization = custom["openaiOrganization"] + + openai_url = "https://api.openai.com/v1/chat/completions" + data = json.dumps(chatgpt_job) + + async with aiohttp.ClientSession() as session: + async with session.post(openai_url, headers=get_header(openai_organization), data=data) as resp: + if resp.status != 200: + raise Exception(f"Failed to execute chathpt job with error: {resp.reason}") + response = await resp.json() + + print("Do Response: ", response) + + return 
GetTaskResponse(resource=Resource(state=SUCCEEDED)) + + + +def get_header(openai_organization: str): + token = flytekit.current_context().secrets.get("openai", "token") + return { + 'OpenAI-Organization': openai_organization, + 'Authorization': f'Bearer {token}', + 'content-type': 'application/json' + } \ No newline at end of file From 25d6a5d70c3bbe75562d30c41f46da9147f50202 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 7 Sep 2023 10:38:39 +0800 Subject: [PATCH 02/64] others Signed-off-by: Future Outlier --- .../flytekitplugins/chatgpt/__init__.py | 8 +++++ .../flytekitplugins/chatgpt/task.py | 1 + plugins/flytekit-openai-chatgpt/setup.py | 36 +++++++++++++++++++ 3 files changed, 45 insertions(+) create mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py create mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py create mode 100644 plugins/flytekit-openai-chatgpt/setup.py diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py new file mode 100644 index 0000000000..3973ab8213 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -0,0 +1,8 @@ +""" + currentmodule:: flytekitplugins.chatgpt +""" + +from flytekit.configuration import internal as _internal + +from .agent import ChatGPTAgent + diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py new file mode 100644 index 0000000000..112e260dca --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -0,0 +1 @@ +# Can we don't implement it but still use??? 
\ No newline at end of file diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py new file mode 100644 index 0000000000..5b7b96dcaa --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "chatgpt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="chatgpt plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) From 970bf3b0369e0eb356f1e1a3fa88b66df1a8d5ca Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 7 Sep 2023 10:41:55 +0800 Subject: [PATCH 03/64] sync plugin, DoTask function interface Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 25 ++++++++++++++++++++++++ flytekit/extend/backend/base_agent.py | 24 +++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 470bd01e2e..430c7a46f5 100644 --- a/flytekit/extend/backend/agent_service.py +++ 
b/flytekit/extend/backend/agent_service.py @@ -91,3 +91,28 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon logger.error(f"failed to delete task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to delete task with error {e}") + + async def DoTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: + try: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + agent = AgentRegistry.get_agent(context, tmp.type) + logger.info(f"{agent.task_type} agent start doing the job") + if agent.asynchronous: + try: + return await agent.async_do( + context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + ) + except Exception as e: + logger.error(f"failed to run async do with error {e}") + raise e + try: + return await asyncio.to_thread( + agent.do, context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + ) + except Exception as e: + logger + except Exception as e: + logger.error(f"failed to do task with error {e}") + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(f"failed to create task with error {e}") \ No newline at end of file diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 1bf34c029a..45dc983ad8 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -79,6 +79,18 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT """ raise NotImplementedError + def do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + """ + Return the result of executing a task. It should return error code if the task creation failed. 
+ """ + raise NotImplementedError + async def async_create( self, context: grpc.ServicerContext, @@ -105,6 +117,18 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes """ raise NotImplementedError + def async_do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + """ + Return the result of executing a task. It should return error code if the task creation failed. + """ + raise NotImplementedError + class AgentRegistry(object): """ From 6087c5fcc09c3c4686b206104f043e674e70d542 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 11 Sep 2023 16:33:04 +0800 Subject: [PATCH 04/64] requester agent base and chatgpt agent Signed-off-by: Future Outlier --- flytekit/requester/__init__.py | 2 + flytekit/requester/base_requester.py | 53 ++++++++++++++++++ .../requester/chatgpt_requester.py | 36 ++++++++++-- flytekit/requester/requester_engine.py | 56 +++++++++++++++++++ .../flytekitplugins/chatgpt/__init__.py | 8 --- .../flytekitplugins/chatgpt/task.py | 1 - plugins/flytekit-openai-chatgpt/setup.py | 36 ------------ 7 files changed, 142 insertions(+), 50 deletions(-) create mode 100644 flytekit/requester/__init__.py create mode 100644 flytekit/requester/base_requester.py rename plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py => flytekit/requester/chatgpt_requester.py (59%) create mode 100644 flytekit/requester/requester_engine.py delete mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py delete mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py delete mode 100644 plugins/flytekit-openai-chatgpt/setup.py diff --git a/flytekit/requester/__init__.py b/flytekit/requester/__init__.py new file mode 100644 index 0000000000..20760d5990 --- /dev/null +++ b/flytekit/requester/__init__.py @@ -0,0 +1,2 @@ +from .base_requester import BaseRequester +from 
.chatgpt_requester import ChatGPTRequester \ No newline at end of file diff --git a/flytekit/requester/base_requester.py b/flytekit/requester/base_requester.py new file mode 100644 index 0000000000..d1446e7727 --- /dev/null +++ b/flytekit/requester/base_requester.py @@ -0,0 +1,53 @@ +import collections +import inspect +from abc import abstractmethod +from typing import Any, Dict, Optional, TypeVar + +import jsonpickle +from typing_extensions import get_type_hints +from flyteidl.admin.agent_pb2 import GetTaskResponse +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin + +T = TypeVar("T") +REQUESTER_MODULE = "requester_module" +REQUESTER_NAME = "requester_name" +REQUESTER_CONFIG_PKL = "requester_config_pkl" +INPUTS = "inputs" + + +class BaseRequester(AsyncAgentExecutorMixin, PythonTask): + """ + TODO: Write the docstring + Base class for all requesters. Sensors are tasks that are designed to run forever, and periodically check for some + condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the + sensor agent, and not by the Flyte engine. 
+ """ + + def __init__( + self, + name: str, + task_type: str = "requester", + **kwargs, + ): + + super().__init__( + task_type=task_type, + name=name, + **kwargs, + ) + + @abstractmethod + async def do(self, **kwargs) -> GetTaskResponse: + raise NotImplementedError + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + cfg = { + REQUESTER_MODULE: type(self).__module__, + REQUESTER_NAME: type(self).__name__, + } + if self._requester_config is not None: + cfg[REQUESTER_CONFIG_PKL] = jsonpickle.encode(self._requester_config) + return cfg diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py b/flytekit/requester/chatgpt_requester.py similarity index 59% rename from plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py rename to flytekit/requester/chatgpt_requester.py index ac738ec0e3..aaa4479d0a 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/agent.py +++ b/flytekit/requester/chatgpt_requester.py @@ -1,9 +1,17 @@ +from typing import Any, Dict, Optional, TypeVar + +from flytekit import FlyteContextManager +from flytekit.configuration import SerializationSettings +from flytekit.requester.base_requester import BaseRequester import json import pickle import typing from dataclasses import dataclass from typing import Optional +from google.protobuf.json_format import MessageToDict + + import aiohttp import grpc from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource @@ -13,12 +21,31 @@ from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate +T = TypeVar("T") + + +@dataclass +class ChatGPT(object): -class ChatGPTAgent(AgentBase): - def __init__(self): - super().__init__(task_type="chatgpt") + openai_organization: str = None + chatgpt_conf: Dict[str, str] = None + +class ChatGPTRequester(BaseRequester): + # TODO, + def __init__(self, name: str, task_config: ChatGPT, **kwargs): + 
super().__init__(name=name, task_config=task_config, **kwargs) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + job = super().get_custom() + if isinstance(self.task_config, ChatGPT): + job["chatgptConf"] = self.task_config.chatgpt_conf + job["openaiOrganization"] = self.task_config.openai_organization + + return MessageToDict(job.to_flyte_idl()) + + # TODO, Know how to write the input output, maybe like google bigquery async def async_do( self, context: grpc.ServicerContext, @@ -42,11 +69,10 @@ async def async_do( print("Do Response: ", response) return GetTaskResponse(resource=Resource(state=SUCCEEDED)) - def get_header(openai_organization: str): - token = flytekit.current_context().secrets.get("openai", "token") + token = flytekit.current_context().secrets.get("openai", "access_token") return { 'OpenAI-Organization': openai_organization, 'Authorization': f'Bearer {token}', diff --git a/flytekit/requester/requester_engine.py b/flytekit/requester/requester_engine.py new file mode 100644 index 0000000000..24e8900bdd --- /dev/null +++ b/flytekit/requester/requester_engine.py @@ -0,0 +1,56 @@ +import importlib +import typing +from typing import Optional + +import cloudpickle +import grpc +import jsonpickle +from flyteidl.admin.agent_pb2 import ( + RETRYABLE_FAILURE, + SUCCEEDED, + DoTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) + + +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.requester.base_requester import INPUTS, REQUESTER_CONFIG_PKL, REQUESTER_MODULE, REQUESTER_NAME + +T = typing.TypeVar("T") + +class RequesterEngine(AgentBase): + def __init__(self): + super().__init__(task_type="requester", asynchronous=True) + + async def async_do( + self, + context: grpc.ServicerContext, + 
output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> DoTaskResponse: + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + ctx = FlyteContextManager.current_context() + if inputs: + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + task_template.custom[INPUTS] = native_inputs + + meta = task_template.custom + + requester_module = importlib.import_module(name=meta[REQUESTER_MODULE]) + requester_def = getattr(requester_module, meta[REQUESTER_NAME]) + requester_config = jsonpickle.decode(meta[REQUESTER_CONFIG_PKL]) if meta.get(REQUESTER_CONFIG_PKL) else None + + cur_state = SUCCEEDED if await requester_def("requester", config=requester_config).do(**inputs) else RETRYABLE_FAILURE + + return DoTaskResponse(resource=Resource(state=cur_state, outputs=None)) + +AgentRegistry.register(RequesterEngine()) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py deleted file mode 100644 index 3973ab8213..0000000000 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -""" - currentmodule:: flytekitplugins.chatgpt -""" - -from flytekit.configuration import internal as _internal - -from .agent import ChatGPTAgent - diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py deleted file mode 100644 index 112e260dca..0000000000 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ /dev/null @@ -1 +0,0 @@ -# Can we don't implement it but still use??? 
\ No newline at end of file diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py deleted file mode 100644 index 5b7b96dcaa..0000000000 --- a/plugins/flytekit-openai-chatgpt/setup.py +++ /dev/null @@ -1,36 +0,0 @@ -from setuptools import setup - -PLUGIN_NAME = "chatgpt" - -microlib_name = f"flytekitplugins-{PLUGIN_NAME}" - -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] - -__version__ = "0.0.0+develop" - -setup( - name=microlib_name, - version=__version__, - author="flyteorg", - author_email="admin@flyte.org", - description="chatgpt plugin for flytekit", - namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}"], - install_requires=plugin_requires, - license="apache2", - python_requires=">=3.8", - classifiers=[ - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, -) From b22531026668b7caf76110240a1637d97bc83203 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 11 Sep 2023 16:45:52 +0800 Subject: [PATCH 05/64] Add do task interface Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 11 +++++++---- flytekit/extend/backend/base_agent.py | 5 +++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index c401816864..44ae86ba7f 100644 --- a/flytekit/extend/backend/agent_service.py +++ 
b/flytekit/extend/backend/agent_service.py @@ -6,6 +6,8 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, + DoTaskRequest, + DoTaskResponse, GetTaskRequest, GetTaskResponse, ) @@ -84,11 +86,11 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to delete task with error {e}") - async def DoTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: + async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: try: tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(context, tmp.type) + agent = AgentRegistry.get_agent(tmp.type) logger.info(f"{agent.task_type} agent start doing the job") if agent.asynchronous: try: @@ -103,8 +105,9 @@ async def DoTask(self, request: CreateTaskRequest, context: grpc.ServicerContext agent.do, context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp ) except Exception as e: - logger + logger.error(f"failed to run sync do with error {e}") + raise except Exception as e: logger.error(f"failed to do task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to create task with error {e}") \ No newline at end of file + context.set_details(f"failed to do task with error {e}") diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 1f5c14438d..2e0492004d 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -16,6 +16,7 @@ SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, + DoTaskResponse, GetTaskResponse, State, ) @@ -85,7 +86,7 @@ def do( output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> CreateTaskResponse: + ) -> DoTaskResponse: """ Return the result 
of executing a task. It should return error code if the task creation failed. """ @@ -123,7 +124,7 @@ def async_do( output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> CreateTaskResponse: + ) -> DoTaskResponse: """ Return the result of executing a task. It should return error code if the task creation failed. """ From 09bc23aa317571cfb4684f4797a1451dfdc02011 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 11 Sep 2023 17:45:39 +0800 Subject: [PATCH 06/64] fix lint Signed-off-by: Future Outlier --- flytekit/requester/__init__.py | 2 +- flytekit/requester/base_requester.py | 5 ++-- flytekit/requester/chatgpt_requester.py | 34 ++++++++++--------------- flytekit/requester/requester_engine.py | 9 +++---- 4 files changed, 22 insertions(+), 28 deletions(-) diff --git a/flytekit/requester/__init__.py b/flytekit/requester/__init__.py index 20760d5990..eae3462a75 100644 --- a/flytekit/requester/__init__.py +++ b/flytekit/requester/__init__.py @@ -1,2 +1,2 @@ from .base_requester import BaseRequester -from .chatgpt_requester import ChatGPTRequester \ No newline at end of file +from .chatgpt_requester import ChatGPTRequester diff --git a/flytekit/requester/base_requester.py b/flytekit/requester/base_requester.py index d1446e7727..63238062bb 100644 --- a/flytekit/requester/base_requester.py +++ b/flytekit/requester/base_requester.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional, TypeVar import jsonpickle +from flyteidl.admin.agent_pb2 import DoTaskResponse from typing_extensions import get_type_hints -from flyteidl.admin.agent_pb2 import GetTaskResponse + from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface @@ -40,7 +41,7 @@ def __init__( ) @abstractmethod - async def do(self, **kwargs) -> GetTaskResponse: + async def do(self, **kwargs) -> DoTaskResponse: raise NotImplementedError def get_custom(self, settings: 
SerializationSettings) -> Dict[str, Any]: diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/requester/chatgpt_requester.py index aaa4479d0a..7383d73796 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/requester/chatgpt_requester.py @@ -1,25 +1,21 @@ -from typing import Any, Dict, Optional, TypeVar - -from flytekit import FlyteContextManager -from flytekit.configuration import SerializationSettings -from flytekit.requester.base_requester import BaseRequester import json import pickle import typing from dataclasses import dataclass -from typing import Optional - -from google.protobuf.json_format import MessageToDict - +from typing import Any, Dict, Optional, TypeVar import aiohttp import grpc -from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from google.protobuf.json_format import MessageToDict import flytekit +from flytekit import FlyteContextManager +from flytekit.configuration import SerializationSettings from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate +from flytekit.requester.base_requester import BaseRequester T = TypeVar("T") @@ -31,9 +27,7 @@ class ChatGPT(object): chatgpt_conf: Dict[str, str] = None - class ChatGPTRequester(BaseRequester): - # TODO, def __init__(self, name: str, task_config: ChatGPT, **kwargs): super().__init__(name=name, task_config=task_config, **kwargs) @@ -42,8 +36,8 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: if isinstance(self.task_config, ChatGPT): job["chatgptConf"] = self.task_config.chatgpt_conf job["openaiOrganization"] = self.task_config.openai_organization - - return MessageToDict(job.to_flyte_idl()) + + return MessageToDict(job.to_flyte_idl()) # TODO, Know how to write the input output, 
maybe like google bigquery async def async_do( @@ -52,7 +46,7 @@ async def async_do( output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - ) -> GetTaskResponse: + ) -> DoTaskResponse: custom = task_template.custom chatgpt_job = custom["chatgptConf"] openai_organization = custom["openaiOrganization"] @@ -68,13 +62,13 @@ async def async_do( print("Do Response: ", response) - return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + return DoTaskResponse(resource=Resource(state=SUCCEEDED)) def get_header(openai_organization: str): token = flytekit.current_context().secrets.get("openai", "access_token") return { - 'OpenAI-Organization': openai_organization, - 'Authorization': f'Bearer {token}', - 'content-type': 'application/json' - } \ No newline at end of file + "OpenAI-Organization": openai_organization, + "Authorization": f"Bearer {token}", + "content-type": "application/json", + } diff --git a/flytekit/requester/requester_engine.py b/flytekit/requester/requester_engine.py index 24e8900bdd..5148c885f2 100644 --- a/flytekit/requester/requester_engine.py +++ b/flytekit/requester/requester_engine.py @@ -9,12 +9,9 @@ RETRYABLE_FAILURE, SUCCEEDED, DoTaskResponse, - DeleteTaskResponse, - GetTaskResponse, Resource, ) - from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry @@ -42,15 +39,17 @@ async def async_do( if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) task_template.custom[INPUTS] = native_inputs - meta = task_template.custom requester_module = importlib.import_module(name=meta[REQUESTER_MODULE]) requester_def = getattr(requester_module, meta[REQUESTER_NAME]) requester_config = jsonpickle.decode(meta[REQUESTER_CONFIG_PKL]) if meta.get(REQUESTER_CONFIG_PKL) else None - cur_state = SUCCEEDED if await requester_def("requester", config=requester_config).do(**inputs) else 
RETRYABLE_FAILURE + cur_state = ( + SUCCEEDED if await requester_def("requester", config=requester_config).do(**inputs) else RETRYABLE_FAILURE + ) return DoTaskResponse(resource=Resource(state=cur_state, outputs=None)) + AgentRegistry.register(RequesterEngine()) From 7c9dcbc9ff285b6c966b911dee07ad613c7a2a57 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 20 Sep 2023 12:19:33 +0800 Subject: [PATCH 07/64] upload data v1 Signed-off-by: Future Outlier --- flytekit/requester/chatgpt_requester.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/requester/chatgpt_requester.py index 7383d73796..9313642252 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/requester/chatgpt_requester.py @@ -3,12 +3,12 @@ import typing from dataclasses import dataclass from typing import Any, Dict, Optional, TypeVar - +import os import aiohttp import grpc from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from google.protobuf.json_format import MessageToDict - +from flytekit.core import utils import flytekit from flytekit import FlyteContextManager from flytekit.configuration import SerializationSettings @@ -16,6 +16,7 @@ from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.requester.base_requester import BaseRequester +from flytekit.core.type_engine import TypeEngine T = TypeVar("T") @@ -61,6 +62,12 @@ async def async_do( response = await resp.json() print("Do Response: ", response) + message = response.choices[0].message.content + lt = TypeEngine.to_literal_type(str) + ctx = FlyteContextManager.current_context() + message = TypeEngine.to_literal(ctx, message, str, lt).to_flyte_idl() + + utils.write_proto_to_file(message, os.path.join(output_prefix, "message")) return DoTaskResponse(resource=Resource(state=SUCCEEDED)) From 18a9e5d1d4a6eb6ac2a43673bf379ce036e9b599 Mon Sep 17 00:00:00 2001 From: Future Outlier 
Date: Sat, 23 Sep 2023 22:07:40 +0800 Subject: [PATCH 08/64] base_requester and chatgpt_requester succeed Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 55 +++++++++++++++- flytekit/requester/__init__.py | 1 + flytekit/requester/base_requester.py | 20 ++++-- flytekit/requester/chatgpt_requester.py | 85 +++++++++++-------------- flytekit/requester/requester_engine.py | 20 ++---- flytekit/sensor/base_sensor.py | 2 +- 6 files changed, 116 insertions(+), 67 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2e0492004d..239c020546 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,4 +1,5 @@ import asyncio +import os import signal import sys import time @@ -18,13 +19,16 @@ DeleteTaskResponse, DoTaskResponse, GetTaskResponse, + Resource, State, ) +from flyteidl.core import literals_pb2 from flyteidl.core.tasks_pb2 import TaskTemplate from rich.progress import Progress from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.core import utils from flytekit.core.base_task import PythonTask from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap @@ -191,14 +195,43 @@ def execute(self, **kwargs) -> typing.Any: task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template self._agent = AgentRegistry.get_agent(task_template.type) - res = asyncio.run(self._create(task_template, kwargs)) - res = asyncio.run(self._get(resource_meta=res.resource_meta)) + if _is_method_overridden(self._agent, "do", AgentBase) or _is_method_overridden( + self._agent, "async_do", AgentBase + ): + res = asyncio.run(self._do(task_template, kwargs)) + else: + res = asyncio.run(self._create(task_template, kwargs)) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) if res.resource.state != SUCCEEDED: 
raise Exception(f"Failed to run the task {self._entity.name}") return LiteralMap.from_flyte_idl(res.resource.outputs) + async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None) -> DoTaskResponse: + ctx = FlyteContext.current_context() + grpc_ctx = _get_grpc_context() + + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) + inputs = LiteralMap(literals) if literals else None + output_prefix = ctx.file_access.get_random_local_directory() + + progress = Progress(transient=True) + task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) + with progress: + progress.start_task(task) + if self._agent.asynchronous: + res = await self._agent.async_do(grpc_ctx, output_prefix, task_template, inputs) + else: + res = self._agent.do(grpc_ctx, output_prefix, task_template, inputs) + + output_filename = os.path.join(output_prefix, "do.proto") + outpus = utils.load_proto_from_file(literals_pb2.LiteralMap, output_filename) + + return DoTaskResponse(resource=Resource(state=res.resource.state, outputs=outpus)) + async def _create( self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None ) -> CreateTaskResponse: @@ -256,3 +289,21 @@ def _get_grpc_context(): grpc_ctx = MagicMock(spec=grpc.ServicerContext) return grpc_ctx + + +def _is_method_overridden(instance, method_name, base_class): + """ + Check if a method with the given method_name is overridden in instance's class + relative to the given base_class. 
+ """ + method = getattr(instance, method_name) + base_method = getattr(base_class, method_name) + + # Check if method is bound method or just a function + if hasattr(method, "__func__"): + method = method.__func__ + + if hasattr(base_method, "__func__"): + base_method = base_method.__func__ + + return method is not base_method diff --git a/flytekit/requester/__init__.py b/flytekit/requester/__init__.py index eae3462a75..35aeeef0f5 100644 --- a/flytekit/requester/__init__.py +++ b/flytekit/requester/__init__.py @@ -1,2 +1,3 @@ from .base_requester import BaseRequester from .chatgpt_requester import ChatGPTRequester +from .requester_engine import RequesterEngine diff --git a/flytekit/requester/base_requester.py b/flytekit/requester/base_requester.py index 63238062bb..12047d0665 100644 --- a/flytekit/requester/base_requester.py +++ b/flytekit/requester/base_requester.py @@ -22,29 +22,41 @@ class BaseRequester(AsyncAgentExecutorMixin, PythonTask): """ TODO: Write the docstring - Base class for all requesters. Sensors are tasks that are designed to run forever, and periodically check for some - condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the - sensor agent, and not by the Flyte engine. 
""" def __init__( self, name: str, + requester_config: Optional[T] = None, task_type: str = "requester", **kwargs, ): + type_hints = get_type_hints(self.do, include_extras=True) + signature = inspect.signature(self.do) + inputs = collections.OrderedDict() + outputs = collections.OrderedDict() + + for k, _ in signature.parameters.items(): # type: ignore + annotation = type_hints.get(k, None) + inputs[k] = annotation + + if "return" in type_hints: + outputs["o0"] = type_hints["return"] super().__init__( task_type=task_type, name=name, + task_config=None, + interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) + self._requester_config = requester_config @abstractmethod async def do(self, **kwargs) -> DoTaskResponse: raise NotImplementedError - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { REQUESTER_MODULE: type(self).__module__, REQUESTER_NAME: type(self).__name__, diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/requester/chatgpt_requester.py index 9313642252..04c2c089a6 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/requester/chatgpt_requester.py @@ -1,73 +1,64 @@ import json -import pickle -import typing -from dataclasses import dataclass -from typing import Any, Dict, Optional, TypeVar import os +from typing import Any, Dict + import aiohttp -import grpc from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource -from google.protobuf.json_format import MessageToDict -from flytekit.core import utils + import flytekit from flytekit import FlyteContextManager -from flytekit.configuration import SerializationSettings -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.core import utils +from flytekit.core.type_engine import TypeEngine from flytekit.models.literals import LiteralMap -from flytekit.models.task import TaskTemplate from 
flytekit.requester.base_requester import BaseRequester -from flytekit.core.type_engine import TypeEngine - -T = TypeVar("T") - - -@dataclass -class ChatGPT(object): - - openai_organization: str = None - chatgpt_conf: Dict[str, str] = None class ChatGPTRequester(BaseRequester): - def __init__(self, name: str, task_config: ChatGPT, **kwargs): - super().__init__(name=name, task_config=task_config, **kwargs) + """ + TODO: Write the docstring + """ - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = super().get_custom() - if isinstance(self.task_config, ChatGPT): - job["chatgptConf"] = self.task_config.chatgpt_conf - job["openaiOrganization"] = self.task_config.openai_organization + _openai_organization: str = None + _chatgpt_conf: Dict[str, Any] = None - return MessageToDict(job.to_flyte_idl()) + # TODO, such as Value Error + def __init__(self, name: str, config: Dict[str, Any], **kwargs): + super().__init__(name=name, requester_config=config, **kwargs) + self._openai_organization = config["openai_organization"] + self._chatgpt_conf = config["chatgpt_conf"] - # TODO, Know how to write the input output, maybe like google bigquery - async def async_do( + async def do( self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, + output_prefix: str = None, + message: str = None, ) -> DoTaskResponse: - custom = task_template.custom - chatgpt_job = custom["chatgptConf"] - openai_organization = custom["openaiOrganization"] - + self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] openai_url = "https://api.openai.com/v1/chat/completions" - data = json.dumps(chatgpt_job) + data = json.dumps(self._chatgpt_conf) async with aiohttp.ClientSession() as session: - async with session.post(openai_url, headers=get_header(openai_organization), data=data) as resp: + async with session.post( + openai_url, headers=get_header(openai_organization=self._openai_organization), 
data=data + ) as resp: if resp.status != 200: - raise Exception(f"Failed to execute chathpt job with error: {resp.reason}") + raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") response = await resp.json() - print("Do Response: ", response) - message = response.choices[0].message.content - lt = TypeEngine.to_literal_type(str) + message = response["choices"][0]["message"]["content"] ctx = FlyteContextManager.current_context() - message = TypeEngine.to_literal(ctx, message, str, lt).to_flyte_idl() - - utils.write_proto_to_file(message, os.path.join(output_prefix, "message")) + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + message, + type(message), + TypeEngine.to_literal_type(type(message)), + ) + } + ).to_flyte_idl() + + output_filename = os.path.join(output_prefix, "do.proto") + utils.write_proto_to_file(outputs, output_filename) return DoTaskResponse(resource=Resource(state=SUCCEEDED)) diff --git a/flytekit/requester/requester_engine.py b/flytekit/requester/requester_engine.py index 5148c885f2..ee64742204 100644 --- a/flytekit/requester/requester_engine.py +++ b/flytekit/requester/requester_engine.py @@ -2,15 +2,9 @@ import typing from typing import Optional -import cloudpickle import grpc import jsonpickle -from flyteidl.admin.agent_pb2 import ( - RETRYABLE_FAILURE, - SUCCEEDED, - DoTaskResponse, - Resource, -) +from flyteidl.admin.agent_pb2 import DoTaskResponse from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine @@ -21,6 +15,7 @@ T = typing.TypeVar("T") + class RequesterEngine(AgentBase): def __init__(self): super().__init__(task_type="requester", asynchronous=True) @@ -39,17 +34,16 @@ async def async_do( if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) task_template.custom[INPUTS] = native_inputs + else: + raise ValueError("Requester needs a input!") + meta = task_template.custom requester_module = 
importlib.import_module(name=meta[REQUESTER_MODULE]) requester_def = getattr(requester_module, meta[REQUESTER_NAME]) requester_config = jsonpickle.decode(meta[REQUESTER_CONFIG_PKL]) if meta.get(REQUESTER_CONFIG_PKL) else None - - cur_state = ( - SUCCEEDED if await requester_def("requester", config=requester_config).do(**inputs) else RETRYABLE_FAILURE - ) - - return DoTaskResponse(resource=Resource(state=cur_state, outputs=None)) + inputs = meta.get(INPUTS, {}) + return await requester_def("requester", config=requester_config).do(output_prefix=output_prefix, **inputs) AgentRegistry.register(RequesterEngine()) diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 60beb6aa2b..0e40055ea5 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -35,7 +35,7 @@ def __init__( type_hints = get_type_hints(self.poke, include_extras=True) signature = inspect.signature(self.poke) inputs = collections.OrderedDict() - for k, v in signature.parameters.items(): # type: ignore + for k, _ in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) inputs[k] = annotation From 8de3fa801adfe8784c51a1a2fb0b1cbce33c1060 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 26 Sep 2023 16:10:31 +0800 Subject: [PATCH 09/64] chatgpt requester v2 Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 41 ++++++----- flytekit/extend/backend/base_agent.py | 88 +++++++++++------------- flytekit/requester/base_requester.py | 7 +- flytekit/requester/chatgpt_requester.py | 11 ++- flytekit/requester/requester_engine.py | 7 +- 5 files changed, 70 insertions(+), 84 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index fd079c66f2..2af8c37080 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -117,7 +117,7 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: 
grpc.ServicerCon logger.error(f"failed to run async delete with error {e}") raise try: - res = asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta) + res = await asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta) request_success_count.labels(task_type=request.task_type, operation=delete_operation).inc() return res except Exception as e: @@ -130,27 +130,32 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: try: - tmp = TaskTemplate.from_flyte_idl(request.template) - inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(tmp.type) - logger.info(f"{agent.task_type} agent start doing the job") - if agent.asynchronous: + with request_latency.labels(task_type=request.task_type, operation="do").time(): + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + agent = AgentRegistry.get_agent(tmp.type) + logger.info(f"{agent.task_type} agent start doing the job") + if agent.asynchronous: + try: + res = await agent.async_do( + context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + ) + request_success_count.labels(task_type=request.task_type, operation=do_operation).inc() + return res + except Exception as e: + logger.error(f"failed to run async do with error {e}") + raise e try: - return await agent.async_do( - context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + res = await asyncio.to_thread( + agent.do, context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp ) + request_success_count.labels(task_type=request.task_type, operation=do_operation).inc() + return res except Exception as e: - logger.error(f"failed to run async do with error {e}") - 
raise e - try: - return await asyncio.to_thread( - agent.do, context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp - ) - except Exception as e: - logger.error(f"failed to run sync do with error {e}") - raise + logger.error(f"failed to run sync do with error {e}") + raise except Exception as e: logger.error(f"failed to do task with error {e}") context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"failed to do task with error {e}") - + request_failure_count.labels(task_type=request.task_type, operation=do_operation).inc() diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 239c020546..bf6f2f3f2c 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,5 +1,4 @@ import asyncio -import os import signal import sys import time @@ -26,6 +25,7 @@ from flyteidl.core.tasks_pb2 import TaskTemplate from rich.progress import Progress +import flytekit from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core import utils @@ -122,7 +122,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes """ raise NotImplementedError - def async_do( + async def async_do( self, context: grpc.ServicerContext, output_prefix: str, @@ -135,6 +135,36 @@ def async_do( raise NotImplementedError +class RequesterAgent(AgentBase): + def __init__(self, task_type: str = "requester"): + super().__init__(task_type=task_type) + + async def _do( + self, + entity: PythonTask, + task_template: TaskTemplate, + agent: AgentBase, + inputs: typing.Dict[str, typing.Any] = None, + ) -> DoTaskResponse: + ctx = FlyteContext.current_context() + grpc_ctx = _get_grpc_context() + + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) + inputs = LiteralMap(literals) if literals else None + output_prefix = 
ctx.file_access.get_random_local_path() + + progress = Progress(transient=True) + task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) + with progress: + progress.start_task(task) + res = await agent.async_do(grpc_ctx, output_prefix, task_template, inputs) + outpus = utils.load_proto_from_file(literals_pb2.LiteralMap, output_prefix) + + return DoTaskResponse(resource=Resource(state=res.resource.state, outputs=outpus)) + + class AgentRegistry(object): """ This is the registry for all agents. The agent service will look up the agent @@ -178,6 +208,10 @@ def is_terminal_state(state: State) -> bool: return state in [SUCCEEDED, RETRYABLE_FAILURE, PERMANENT_FAILURE] +def get_secret(secret_key: str) -> str: + return flytekit.current_context().secrets.get("flyteagent", secret_key) + + class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. @@ -195,10 +229,8 @@ def execute(self, **kwargs) -> typing.Any: task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template self._agent = AgentRegistry.get_agent(task_template.type) - if _is_method_overridden(self._agent, "do", AgentBase) or _is_method_overridden( - self._agent, "async_do", AgentBase - ): - res = asyncio.run(self._do(task_template, kwargs)) + if isinstance(self._agent, RequesterAgent): + res = asyncio.run(self._agent._do(self._entity, task_template, self._agent, kwargs)) else: res = asyncio.run(self._create(task_template, kwargs)) res = asyncio.run(self._get(resource_meta=res.resource_meta)) @@ -208,30 +240,6 @@ def execute(self, **kwargs) -> typing.Any: return LiteralMap.from_flyte_idl(res.resource.outputs) - async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None) -> DoTaskResponse: - ctx = FlyteContext.current_context() - grpc_ctx = _get_grpc_context() - - literals = {} - for k, v in inputs.items(): - literals[k] = 
TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) - inputs = LiteralMap(literals) if literals else None - output_prefix = ctx.file_access.get_random_local_directory() - - progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) - with progress: - progress.start_task(task) - if self._agent.asynchronous: - res = await self._agent.async_do(grpc_ctx, output_prefix, task_template, inputs) - else: - res = self._agent.do(grpc_ctx, output_prefix, task_template, inputs) - - output_filename = os.path.join(output_prefix, "do.proto") - outpus = utils.load_proto_from_file(literals_pb2.LiteralMap, output_filename) - - return DoTaskResponse(resource=Resource(state=res.resource.state, outputs=outpus)) - async def _create( self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None ) -> CreateTaskResponse: @@ -284,26 +292,8 @@ def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> sys.exit(1) -def _get_grpc_context(): +def _get_grpc_context() -> grpc.ServicerContext: from unittest.mock import MagicMock grpc_ctx = MagicMock(spec=grpc.ServicerContext) return grpc_ctx - - -def _is_method_overridden(instance, method_name, base_class): - """ - Check if a method with the given method_name is overridden in instance's class - relative to the given base_class. 
- """ - method = getattr(instance, method_name) - base_method = getattr(base_class, method_name) - - # Check if method is bound method or just a function - if hasattr(method, "__func__"): - method = method.__func__ - - if hasattr(base_method, "__func__"): - base_method = base_method.__func__ - - return method is not base_method diff --git a/flytekit/requester/base_requester.py b/flytekit/requester/base_requester.py index 12047d0665..bd28bc6825 100644 --- a/flytekit/requester/base_requester.py +++ b/flytekit/requester/base_requester.py @@ -34,15 +34,12 @@ def __init__( type_hints = get_type_hints(self.do, include_extras=True) signature = inspect.signature(self.do) inputs = collections.OrderedDict() - outputs = collections.OrderedDict() + outputs = collections.OrderedDict({"o0": DoTaskResponse}) for k, _ in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) inputs[k] = annotation - if "return" in type_hints: - outputs["o0"] = type_hints["return"] - super().__init__( task_type=task_type, name=name, @@ -53,7 +50,7 @@ def __init__( self._requester_config = requester_config @abstractmethod - async def do(self, **kwargs) -> DoTaskResponse: + async def async_do(self, **kwargs) -> DoTaskResponse: raise NotImplementedError def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/requester/chatgpt_requester.py index 04c2c089a6..5b1711e9a3 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/requester/chatgpt_requester.py @@ -1,14 +1,13 @@ import json -import os from typing import Any, Dict import aiohttp from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource -import flytekit from flytekit import FlyteContextManager from flytekit.core import utils from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import get_secret from flytekit.models.literals import LiteralMap from 
flytekit.requester.base_requester import BaseRequester @@ -21,7 +20,7 @@ class ChatGPTRequester(BaseRequester): _openai_organization: str = None _chatgpt_conf: Dict[str, Any] = None - # TODO, such as Value Error + # TODO, Add Value Error def __init__(self, name: str, config: Dict[str, Any], **kwargs): super().__init__(name=name, requester_config=config, **kwargs) self._openai_organization = config["openai_organization"] @@ -56,15 +55,13 @@ async def do( ) } ).to_flyte_idl() - - output_filename = os.path.join(output_prefix, "do.proto") - utils.write_proto_to_file(outputs, output_filename) + utils.write_proto_to_file(outputs, output_prefix) return DoTaskResponse(resource=Resource(state=SUCCEEDED)) def get_header(openai_organization: str): - token = flytekit.current_context().secrets.get("openai", "access_token") + token = get_secret(secret_key="OPENAI_ACCESS_TOKEN") return { "OpenAI-Organization": openai_organization, "Authorization": f"Bearer {token}", diff --git a/flytekit/requester/requester_engine.py b/flytekit/requester/requester_engine.py index ee64742204..58af293374 100644 --- a/flytekit/requester/requester_engine.py +++ b/flytekit/requester/requester_engine.py @@ -8,7 +8,7 @@ from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.extend.backend.base_agent import AgentRegistry, RequesterAgent from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate from flytekit.requester.base_requester import INPUTS, REQUESTER_CONFIG_PKL, REQUESTER_MODULE, REQUESTER_NAME @@ -16,10 +16,7 @@ T = typing.TypeVar("T") -class RequesterEngine(AgentBase): - def __init__(self): - super().__init__(task_type="requester", asynchronous=True) - +class RequesterEngine(RequesterAgent): async def async_do( self, context: grpc.ServicerContext, From be7d22d1dae0376b3f6f461eea507d4c053e1c3d Mon Sep 17 00:00:00 2001 From: Future Outlier 
Date: Tue, 26 Sep 2023 20:03:28 +0800 Subject: [PATCH 10/64] output type Signed-off-by: Future Outlier --- flytekit/requester/base_requester.py | 3 ++- flytekit/requester/chatgpt_requester.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/flytekit/requester/base_requester.py b/flytekit/requester/base_requester.py index bd28bc6825..2b394175b8 100644 --- a/flytekit/requester/base_requester.py +++ b/flytekit/requester/base_requester.py @@ -29,12 +29,13 @@ def __init__( name: str, requester_config: Optional[T] = None, task_type: str = "requester", + return_type: Optional[T] = None, **kwargs, ): type_hints = get_type_hints(self.do, include_extras=True) signature = inspect.signature(self.do) inputs = collections.OrderedDict() - outputs = collections.OrderedDict({"o0": DoTaskResponse}) + outputs = collections.OrderedDict({"o0": return_type}) for k, _ in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/requester/chatgpt_requester.py index 5b1711e9a3..383a6d7cdd 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/requester/chatgpt_requester.py @@ -22,7 +22,7 @@ class ChatGPTRequester(BaseRequester): # TODO, Add Value Error def __init__(self, name: str, config: Dict[str, Any], **kwargs): - super().__init__(name=name, requester_config=config, **kwargs) + super().__init__(name=name, requester_config=config, return_type=str, **kwargs) self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] From bca202b1e732a121da8a553345c9e28b4639d8f5 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 27 Sep 2023 22:16:21 +0800 Subject: [PATCH 11/64] dispatcher Signed-off-by: Future Outlier --- flytekit/__init__.py | 1 + flytekit/dispatcher/__init__.py | 3 ++ .../base_dispatcher.py} | 22 +++++++------- .../chatgpt_dispatcher.py} | 29 ++++++++++--------- .../dispatcher_engine.py} | 20 ++++++------- 
flytekit/extend/backend/agent_service.py | 11 +++---- flytekit/extend/backend/base_agent.py | 17 +++++------ flytekit/requester/__init__.py | 3 -- 8 files changed, 53 insertions(+), 53 deletions(-) create mode 100644 flytekit/dispatcher/__init__.py rename flytekit/{requester/base_requester.py => dispatcher/base_dispatcher.py} (72%) rename flytekit/{requester/chatgpt_requester.py => dispatcher/chatgpt_dispatcher.py} (66%) rename flytekit/{requester/requester_engine.py => dispatcher/dispatcher_engine.py} (57%) delete mode 100644 flytekit/requester/__init__.py diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 75037d3370..815215c1ce 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -237,6 +237,7 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.sensor.sensor_engine import SensorEngine +from flytekit.dispatcher.dispatcher_engine import DispatcherEngine from flytekit.types import directory, file, iterator from flytekit.types.structured.structured_dataset import ( StructuredDataset, diff --git a/flytekit/dispatcher/__init__.py b/flytekit/dispatcher/__init__.py new file mode 100644 index 0000000000..eeb2ed1457 --- /dev/null +++ b/flytekit/dispatcher/__init__.py @@ -0,0 +1,3 @@ +from .base_dispatcher import BaseDispatcher +from .chatgpt_dispatcher import ChatGPTDispatcher +from .dispatcher_engine import DispatcherEngine diff --git a/flytekit/requester/base_requester.py b/flytekit/dispatcher/base_dispatcher.py similarity index 72% rename from flytekit/requester/base_requester.py rename to flytekit/dispatcher/base_dispatcher.py index 2b394175b8..4f6aef7c23 100644 --- a/flytekit/requester/base_requester.py +++ b/flytekit/dispatcher/base_dispatcher.py @@ -13,13 +13,13 @@ from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin T = TypeVar("T") -REQUESTER_MODULE = "requester_module" -REQUESTER_NAME = "requester_name" -REQUESTER_CONFIG_PKL = 
"requester_config_pkl" +DISPATCHER_MODULE = "dispatcher_module" +DISPATCHER_NAME = "dispatcher_name" +DISPATCHER_CONFIG_PKL = "dispatcher_config_pkl" INPUTS = "inputs" -class BaseRequester(AsyncAgentExecutorMixin, PythonTask): +class BaseDispatcher(AsyncAgentExecutorMixin, PythonTask): """ TODO: Write the docstring """ @@ -27,8 +27,8 @@ class BaseRequester(AsyncAgentExecutorMixin, PythonTask): def __init__( self, name: str, - requester_config: Optional[T] = None, - task_type: str = "requester", + dispatcher_config: Optional[T] = None, + task_type: str = "dispatcher", return_type: Optional[T] = None, **kwargs, ): @@ -48,7 +48,7 @@ def __init__( interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) - self._requester_config = requester_config + self._dispatcher_config = dispatcher_config @abstractmethod async def async_do(self, **kwargs) -> DoTaskResponse: @@ -56,9 +56,9 @@ async def async_do(self, **kwargs) -> DoTaskResponse: def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { - REQUESTER_MODULE: type(self).__module__, - REQUESTER_NAME: type(self).__name__, + DISPATCHER_MODULE: type(self).__module__, + DISPATCHER_NAME: type(self).__name__, } - if self._requester_config is not None: - cfg[REQUESTER_CONFIG_PKL] = jsonpickle.encode(self._requester_config) + if self._dispatcher_config is not None: + cfg[DISPATCHER_CONFIG_PKL] = jsonpickle.encode(self._dispatcher_config) return cfg diff --git a/flytekit/requester/chatgpt_requester.py b/flytekit/dispatcher/chatgpt_dispatcher.py similarity index 66% rename from flytekit/requester/chatgpt_requester.py rename to flytekit/dispatcher/chatgpt_dispatcher.py index 383a6d7cdd..64804365d3 100644 --- a/flytekit/requester/chatgpt_requester.py +++ b/flytekit/dispatcher/chatgpt_dispatcher.py @@ -9,10 +9,10 @@ from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import get_secret from flytekit.models.literals import LiteralMap -from 
flytekit.requester.base_requester import BaseRequester +from flytekit.dispatcher.base_dispatcher import BaseDispatcher -class ChatGPTRequester(BaseRequester): +class ChatGPTDispatcher(BaseDispatcher): """ TODO: Write the docstring """ @@ -22,28 +22,30 @@ class ChatGPTRequester(BaseRequester): # TODO, Add Value Error def __init__(self, name: str, config: Dict[str, Any], **kwargs): - super().__init__(name=name, requester_config=config, return_type=str, **kwargs) + super().__init__(name=name, dispatcher_config=config, return_type=str, **kwargs) self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] async def do( self, - output_prefix: str = None, message: str = None, ) -> DoTaskResponse: self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] openai_url = "https://api.openai.com/v1/chat/completions" data = json.dumps(self._chatgpt_conf) - async with aiohttp.ClientSession() as session: - async with session.post( - openai_url, headers=get_header(openai_organization=self._openai_organization), data=data - ) as resp: - if resp.status != 200: - raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") - response = await resp.json() + message = "TEST SYNC PLUGIN" + + # async with aiohttp.ClientSession() as session: + # async with session.post( + # openai_url, headers=get_header(openai_organization=self._openai_organization), data=data + # ) as resp: + # if resp.status != 200: + # raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") + # response = await resp.json() + + # message = response["choices"][0]["message"]["content"] - message = response["choices"][0]["message"]["content"] ctx = FlyteContextManager.current_context() outputs = LiteralMap( { @@ -55,9 +57,8 @@ async def do( ) } ).to_flyte_idl() - utils.write_proto_to_file(outputs, output_prefix) - return DoTaskResponse(resource=Resource(state=SUCCEEDED)) + return DoTaskResponse(resource=Resource(state=SUCCEEDED, 
outputs=outputs)) def get_header(openai_organization: str): diff --git a/flytekit/requester/requester_engine.py b/flytekit/dispatcher/dispatcher_engine.py similarity index 57% rename from flytekit/requester/requester_engine.py rename to flytekit/dispatcher/dispatcher_engine.py index 58af293374..d58b86eac3 100644 --- a/flytekit/requester/requester_engine.py +++ b/flytekit/dispatcher/dispatcher_engine.py @@ -8,21 +8,21 @@ from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import AgentRegistry, RequesterAgent +from flytekit.extend.backend.base_agent import AgentRegistry, DispatcherAgent from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.requester.base_requester import INPUTS, REQUESTER_CONFIG_PKL, REQUESTER_MODULE, REQUESTER_NAME +from flytekit.dispatcher.base_dispatcher import INPUTS, DISPATCHER_CONFIG_PKL, DISPATCHER_MODULE, DISPATCHER_NAME T = typing.TypeVar("T") -class RequesterEngine(RequesterAgent): +class DispatcherEngine(DispatcherAgent): async def async_do( self, context: grpc.ServicerContext, - output_prefix: str, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, + output_prefix: Optional[str] = None, ) -> DoTaskResponse: python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() @@ -32,15 +32,15 @@ async def async_do( native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) task_template.custom[INPUTS] = native_inputs else: - raise ValueError("Requester needs a input!") + raise ValueError("Dispatcher needs a input!") meta = task_template.custom - requester_module = importlib.import_module(name=meta[REQUESTER_MODULE]) - requester_def = getattr(requester_module, meta[REQUESTER_NAME]) - requester_config = jsonpickle.decode(meta[REQUESTER_CONFIG_PKL]) if meta.get(REQUESTER_CONFIG_PKL) else None + dispatcher_module 
= importlib.import_module(name=meta[DISPATCHER_MODULE]) + dispatcher_def = getattr(dispatcher_module, meta[DISPATCHER_NAME]) + dispatcher_config = jsonpickle.decode(meta[DISPATCHER_CONFIG_PKL]) if meta.get(DISPATCHER_CONFIG_PKL) else None inputs = meta.get(INPUTS, {}) - return await requester_def("requester", config=requester_config).do(output_prefix=output_prefix, **inputs) + return await dispatcher_def("dispatcher", config=dispatcher_config).do(**inputs) -AgentRegistry.register(RequesterEngine()) +AgentRegistry.register(DispatcherEngine()) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 2af8c37080..7f26eba829 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -130,26 +130,27 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: try: - with request_latency.labels(task_type=request.task_type, operation="do").time(): + with request_latency.labels(task_type=request.template.type, operation=do_operation).time(): tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + input_literal_size.labels(task_type=tmp.type).observe(request.inputs.ByteSize()) agent = AgentRegistry.get_agent(tmp.type) logger.info(f"{agent.task_type} agent start doing the job") if agent.asynchronous: try: res = await agent.async_do( - context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + context=context, inputs=inputs, task_template=tmp ) - request_success_count.labels(task_type=request.task_type, operation=do_operation).inc() + request_success_count.labels(task_type=tmp.type, operation=do_operation).inc() return res except Exception as e: logger.error(f"failed to run async do with error {e}") raise e try: res = await asyncio.to_thread( - agent.do, 
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + agent.do, context=context, inputs=inputs, task_template=tmp ) - request_success_count.labels(task_type=request.task_type, operation=do_operation).inc() + request_success_count.labels(task_type=tmp.type, operation=do_operation).inc() return res except Exception as e: logger.error(f"failed to run sync do with error {e}") diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index bf6f2f3f2c..21de59f34c 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -125,9 +125,9 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async def async_do( self, context: grpc.ServicerContext, - output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, + output_prefix: typing.Optional[str] = None, ) -> DoTaskResponse: """ Return the result of executing a task. It should return error code if the task creation failed. 
@@ -135,9 +135,9 @@ async def async_do( raise NotImplementedError -class RequesterAgent(AgentBase): - def __init__(self, task_type: str = "requester"): - super().__init__(task_type=task_type) +class DispatcherAgent(AgentBase): + def __init__(self): + super().__init__(task_type="dispatcher") async def _do( self, @@ -153,16 +153,13 @@ async def _do( for k, v in inputs.items(): literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) inputs = LiteralMap(literals) if literals else None - output_prefix = ctx.file_access.get_random_local_path() + output_prefix = ctx.file_access.get_random_local_directory() progress = Progress(transient=True) task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) with progress: progress.start_task(task) - res = await agent.async_do(grpc_ctx, output_prefix, task_template, inputs) - outpus = utils.load_proto_from_file(literals_pb2.LiteralMap, output_prefix) - - return DoTaskResponse(resource=Resource(state=res.resource.state, outputs=outpus)) + return await agent.async_do(context=grpc_ctx, output_prefix=output_prefix, task_template=task_template, inputs=inputs) class AgentRegistry(object): @@ -229,7 +226,7 @@ def execute(self, **kwargs) -> typing.Any: task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template self._agent = AgentRegistry.get_agent(task_template.type) - if isinstance(self._agent, RequesterAgent): + if isinstance(self._agent, DispatcherAgent): res = asyncio.run(self._agent._do(self._entity, task_template, self._agent, kwargs)) else: res = asyncio.run(self._create(task_template, kwargs)) diff --git a/flytekit/requester/__init__.py b/flytekit/requester/__init__.py deleted file mode 100644 index 35aeeef0f5..0000000000 --- a/flytekit/requester/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base_requester import BaseRequester -from .chatgpt_requester import ChatGPTRequester -from .requester_engine import RequesterEngine From 
cf9ff07b60afe57511fc8fadd1f3f8ae871b19ae Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sat, 30 Sep 2023 22:16:33 +0800 Subject: [PATCH 12/64] move to plugins directory Signed-off-by: Future Outlier --- flytekit/dispatcher/__init__.py | 1 - .../flytekitplugins/chatgpt/__init__.py | 13 +++++++ .../flytekitplugins/chatgpt/task.py | 1 - .../flytekit-openai-chatgpt/requirements.in | 2 ++ .../flytekit-openai-chatgpt/requirements.txt | 0 plugins/flytekit-openai-chatgpt/setup.py | 36 +++++++++++++++++++ .../flytekit-openai-chatgpt/tests/__init__.py | 0 .../tests/test_chatgpt.py | 0 8 files changed, 51 insertions(+), 2 deletions(-) create mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py rename flytekit/dispatcher/chatgpt_dispatcher.py => plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py (98%) create mode 100644 plugins/flytekit-openai-chatgpt/requirements.in create mode 100644 plugins/flytekit-openai-chatgpt/requirements.txt create mode 100644 plugins/flytekit-openai-chatgpt/setup.py create mode 100644 plugins/flytekit-openai-chatgpt/tests/__init__.py create mode 100644 plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py diff --git a/flytekit/dispatcher/__init__.py b/flytekit/dispatcher/__init__.py index eeb2ed1457..7d2a507dbf 100644 --- a/flytekit/dispatcher/__init__.py +++ b/flytekit/dispatcher/__init__.py @@ -1,3 +1,2 @@ from .base_dispatcher import BaseDispatcher -from .chatgpt_dispatcher import ChatGPTDispatcher from .dispatcher_engine import DispatcherEngine diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py new file mode 100644 index 0000000000..ca0bb68c6f --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.chatgpt + +This package contains things that are useful when extending Flytekit. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + ChatGPTDispatcher +""" + +from .task import ChatGPTDispatcher diff --git a/flytekit/dispatcher/chatgpt_dispatcher.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py similarity index 98% rename from flytekit/dispatcher/chatgpt_dispatcher.py rename to plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 64804365d3..71e5683834 100644 --- a/flytekit/dispatcher/chatgpt_dispatcher.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -5,7 +5,6 @@ from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager -from flytekit.core import utils from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import get_secret from flytekit.models.literals import LiteralMap diff --git a/plugins/flytekit-openai-chatgpt/requirements.in b/plugins/flytekit-openai-chatgpt/requirements.in new file mode 100644 index 0000000000..03afde6b3a --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-openai-chatgpt diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py new file mode 100644 index 0000000000..a85fbec0fa --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -0,0 +1,36 @@ +from setuptools import setup + +PLUGIN_NAME = "chatgpt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Bigquery plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-openai-chatgpt/tests/__init__.py b/plugins/flytekit-openai-chatgpt/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py new file mode 100644 index 0000000000..e69de29bb2 From 3b13b48698d07e6c3a7a762069027becfd66f6ce Mon Sep 
17 00:00:00 2001 From: Future Outlier Date: Mon, 2 Oct 2023 17:12:14 +0800 Subject: [PATCH 13/64] do task agent service Signed-off-by: Future Outlier --- flytekit/dispatcher/dispatcher_engine.py | 2 +- flytekit/extend/backend/agent_service.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/flytekit/dispatcher/dispatcher_engine.py b/flytekit/dispatcher/dispatcher_engine.py index d58b86eac3..c45e4a35bf 100644 --- a/flytekit/dispatcher/dispatcher_engine.py +++ b/flytekit/dispatcher/dispatcher_engine.py @@ -8,10 +8,10 @@ from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.dispatcher.base_dispatcher import DISPATCHER_CONFIG_PKL, DISPATCHER_MODULE, DISPATCHER_NAME, INPUTS from flytekit.extend.backend.base_agent import AgentRegistry, DispatcherAgent from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.dispatcher.base_dispatcher import INPUTS, DISPATCHER_CONFIG_PKL, DISPATCHER_MODULE, DISPATCHER_NAME T = typing.TypeVar("T") diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 7f26eba829..942c52a3a0 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -138,18 +138,14 @@ async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> logger.info(f"{agent.task_type} agent start doing the job") if agent.asynchronous: try: - res = await agent.async_do( - context=context, inputs=inputs, task_template=tmp - ) + res = await agent.async_do(context=context, inputs=inputs, task_template=tmp) request_success_count.labels(task_type=tmp.type, operation=do_operation).inc() return res except Exception as e: logger.error(f"failed to run async do with error {e}") raise e try: - res = await asyncio.to_thread( - agent.do, context=context, inputs=inputs, task_template=tmp - ) + res = await asyncio.get_running_loop().run_in_executor(None, 
agent.do, context, inputs, tmp) request_success_count.labels(task_type=tmp.type, operation=do_operation).inc() return res except Exception as e: From 324996398a7a07334bc8f978f64d2d6cad7aff06 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 10 Oct 2023 22:09:09 +0800 Subject: [PATCH 14/64] merge master Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 17 ++++++++++++++++- .../flytekitplugins/chatgpt/task.py | 4 ++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 2d4cb73369..c22e3e38fd 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -47,7 +47,7 @@ def agent_exception_handler(func): async def wrapper( self, - request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest], + request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest, DoTaskRequest], context: grpc.ServicerContext, *args, **kwargs, @@ -63,6 +63,11 @@ async def wrapper( elif isinstance(request, DeleteTaskRequest): task_type = request.task_type operation = delete_operation + elif isinstance(request, DoTaskRequest): + task_type = request.template.type + operation = do_operation + if request.inputs: + input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize()) else: context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") @@ -125,3 +130,13 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon if agent.asynchronous: return await agent.async_delete(context=context, resource_meta=request.resource_meta) return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta) + + @agent_exception_handler + async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = 
LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + agent = AgentRegistry.get_agent(tmp.type) + logger.info(f"{tmp.type} agent start doing the job") + if agent.asynchronous: + return await agent.async_do(context=context, inputs=inputs, task_template=tmp) + return await asyncio.get_running_loop().run_in_executor(None, agent.do, context, "", inputs, tmp) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 71e5683834..a49a89d9f1 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -6,7 +6,7 @@ from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import get_secret +from flytekit.extend.backend.base_agent import get_agent_secret from flytekit.models.literals import LiteralMap from flytekit.dispatcher.base_dispatcher import BaseDispatcher @@ -61,7 +61,7 @@ async def do( def get_header(openai_organization: str): - token = get_secret(secret_key="OPENAI_ACCESS_TOKEN") + token = get_agent_secret(secret_key="OPENAI_ACCESS_TOKEN") return { "OpenAI-Organization": openai_organization, "Authorization": f"Bearer {token}", From c91ee4cb1fba7290b48b575938fa66d29cf53722 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 15 Oct 2023 00:36:18 -0700 Subject: [PATCH 15/64] refactor Signed-off-by: Kevin Su --- flytekit/__init__.py | 2 +- .../external_api_task.py} | 35 +++--- flytekit/dispatcher/__init__.py | 2 - flytekit/extend/backend/base_agent.py | 107 +++++++----------- .../backend/task_executor.py} | 24 ++-- .../flytekitplugins/chatgpt/task.py | 5 +- 6 files changed, 78 insertions(+), 97 deletions(-) rename flytekit/{dispatcher/base_dispatcher.py => core/external_api_task.py} (55%) delete mode 100644 flytekit/dispatcher/__init__.py rename flytekit/{dispatcher/dispatcher_engine.py => 
extend/backend/task_executor.py} (54%) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 815215c1ce..96dc7b8ac1 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -228,6 +228,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck +from flytekit.extend.backend.task_executor import DispatcherEngine from flytekit.image_spec import ImageSpec from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -237,7 +238,6 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.sensor.sensor_engine import SensorEngine -from flytekit.dispatcher.dispatcher_engine import DispatcherEngine from flytekit.types import directory, file, iterator from flytekit.types.structured.structured_dataset import ( StructuredDataset, diff --git a/flytekit/dispatcher/base_dispatcher.py b/flytekit/core/external_api_task.py similarity index 55% rename from flytekit/dispatcher/base_dispatcher.py rename to flytekit/core/external_api_task.py index 4f6aef7c23..6d07ae7cff 100644 --- a/flytekit/dispatcher/base_dispatcher.py +++ b/flytekit/core/external_api_task.py @@ -13,22 +13,25 @@ from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin T = TypeVar("T") -DISPATCHER_MODULE = "dispatcher_module" -DISPATCHER_NAME = "dispatcher_name" -DISPATCHER_CONFIG_PKL = "dispatcher_config_pkl" -INPUTS = "inputs" +TASK_MODULE = "task_module" +TASK_NAME = "task_name" +TASK_CONFIG_PKL = "task_config_pkl" +TASK_TYPE = "api_task" +USE_SYNC_PLUGIN = "use_sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used ro run this task -class BaseDispatcher(AsyncAgentExecutorMixin, PythonTask): +class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): """ - TODO: Write the docstring + Base 
class for all external API tasks. External API tasks are tasks that are designed to run until they receive a + response from an external service. When the response is received, the task will complete. External API tasks are + designed to be run by the flyte agent. """ def __init__( self, name: str, - dispatcher_config: Optional[T] = None, - task_type: str = "dispatcher", + config: Optional[T] = None, + task_type: str = TASK_TYPE, return_type: Optional[T] = None, **kwargs, ): @@ -48,17 +51,21 @@ def __init__( interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) - self._dispatcher_config = dispatcher_config + self._config = config @abstractmethod - async def async_do(self, **kwargs) -> DoTaskResponse: + async def do(self, **kwargs) -> DoTaskResponse: + """ + Initiate an HTTP request to an external service such as OpenAI or Vertex AI and retrieve the response. + """ raise NotImplementedError def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { - DISPATCHER_MODULE: type(self).__module__, - DISPATCHER_NAME: type(self).__name__, + TASK_MODULE: type(self).__module__, + TASK_NAME: type(self).__name__, + USE_SYNC_PLUGIN: True, } - if self._dispatcher_config is not None: - cfg[DISPATCHER_CONFIG_PKL] = jsonpickle.encode(self._dispatcher_config) + if self._config is not None: + cfg[TASK_CONFIG_PKL] = jsonpickle.encode(self._config) return cfg diff --git a/flytekit/dispatcher/__init__.py b/flytekit/dispatcher/__init__.py deleted file mode 100644 index 7d2a507dbf..0000000000 --- a/flytekit/dispatcher/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .base_dispatcher import BaseDispatcher -from .dispatcher_engine import DispatcherEngine diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2c45050966..7100f24371 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -6,7 +6,7 @@ from abc import ABC from collections import OrderedDict from functools 
import partial -from types import FrameType +from types import FrameType, coroutine import grpc from flyteidl.admin.agent_pb2 import ( @@ -18,20 +18,18 @@ DeleteTaskResponse, DoTaskResponse, GetTaskResponse, - Resource, State, ) -from flyteidl.core import literals_pb2 from flyteidl.core.tasks_pb2 import TaskTemplate from rich.progress import Progress import flytekit from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings -from flytekit.core import utils from flytekit.core.base_task import PythonTask from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.system import FlyteAgentNotFound +from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.models.literals import LiteralMap @@ -88,7 +86,6 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT def do( self, context: grpc.ServicerContext, - output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -128,7 +125,6 @@ async def async_do( context: grpc.ServicerContext, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - output_prefix: typing.Optional[str] = None, ) -> DoTaskResponse: """ Return the result of executing a task. It should return error code if the task creation failed. 
@@ -136,33 +132,6 @@ async def async_do( raise NotImplementedError -class DispatcherAgent(AgentBase): - def __init__(self): - super().__init__(task_type="dispatcher") - - async def _do( - self, - entity: PythonTask, - task_template: TaskTemplate, - agent: AgentBase, - inputs: typing.Dict[str, typing.Any] = None, - ) -> DoTaskResponse: - ctx = FlyteContext.current_context() - grpc_ctx = _get_grpc_context() - - literals = {} - for k, v in inputs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) - inputs = LiteralMap(literals) if literals else None - output_prefix = ctx.file_access.get_random_local_directory() - - progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) - with progress: - progress.start_task(task) - return await agent.async_do(context=grpc_ctx, output_prefix=output_prefix, task_template=task_template, inputs=inputs) - - class AgentRegistry(object): """ This is the registry for all agents. The agent service will look up the agent @@ -190,7 +159,7 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - if state in ["failed", "timedout", "canceled"]: + if state in ["failed", "timeout", "canceled"]: return RETRYABLE_FAILURE elif state in ["done", "succeeded", "success"]: return SUCCEEDED @@ -210,15 +179,24 @@ def get_agent_secret(secret_key: str) -> str: return flytekit.current_context().secrets.get(secret_key) +def _get_grpc_context() -> grpc.ServicerContext: + from unittest.mock import MagicMock + + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + return grpc_ctx + + class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. 
""" - _is_canceled = None - _agent = None - _entity = None + _clean_up_task: coroutine = None + _agent: AgentBase = None + _entity: PythonTask = None + _ctx: FlyteContext = FlyteContext.current_context() + _grpc_ctx: grpc.ServicerContext = _get_grpc_context() def execute(self, **kwargs) -> typing.Any: from flytekit.tools.translator import get_serializable @@ -227,8 +205,8 @@ def execute(self, **kwargs) -> typing.Any: task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template self._agent = AgentRegistry.get_agent(task_template.type) - if isinstance(self._agent, DispatcherAgent): - res = asyncio.run(self._agent._do(self._entity, task_template, self._agent, kwargs)) + if isinstance(self._agent, TaskExecutor): + res = asyncio.run(self._do(task_template, kwargs)) else: res = asyncio.run(self._create(task_template, kwargs)) res = asyncio.run(self._get(resource_meta=res.resource_meta)) @@ -241,27 +219,19 @@ def execute(self, **kwargs) -> typing.Any: async def _create( self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None ) -> CreateTaskResponse: - ctx = FlyteContext.current_context() - grpc_ctx = _get_grpc_context() - - # Convert python inputs to literals - literals = {} - for k, v in inputs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) - inputs = LiteralMap(literals) if literals else None - output_prefix = ctx.file_access.get_random_local_directory() + inputs = self.get_input_literal_map(inputs) + output_prefix = self._ctx.file_access.get_random_local_directory() if self._agent.asynchronous: - res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) + res = await self._agent.async_create(self._grpc_ctx, output_prefix, task_template, inputs) else: - res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) + res = self._agent.create(self._grpc_ctx, output_prefix, task_template, inputs) 
signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore return res async def _get(self, resource_meta: bytes) -> GetTaskResponse: state = RUNNING - grpc_ctx = _get_grpc_context() progress = Progress(transient=True) task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) @@ -270,28 +240,35 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: progress.start_task(task) time.sleep(1) if self._agent.asynchronous: - res = await self._agent.async_get(grpc_ctx, resource_meta) - if self._is_canceled: - await self._is_canceled + res = await self._agent.async_get(self._grpc_ctx, resource_meta) + if self._clean_up_task: + await self._clean_up_task sys.exit(1) else: - res = self._agent.get(grpc_ctx, resource_meta) + res = self._agent.get(self._grpc_ctx, resource_meta) state = res.resource.state logger.info(f"Task state: {state}") return res + async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): + inputs = self.get_input_literal_map(inputs) + if self._agent.asynchronous: + res = self._agent.async_do(self._grpc_ctx, task_template, inputs) + else: + res = self._agent.do(self._grpc_ctx, task_template, inputs) + return await res + def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: - grpc_ctx = _get_grpc_context() if self._agent.asynchronous: - if self._is_canceled is None: - self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta)) + if self._clean_up_task is None: + self._clean_up_task = asyncio.create_task(self._agent.async_delete(self._grpc_ctx, resource_meta)) else: - self._agent.delete(grpc_ctx, resource_meta) + self._agent.delete(self._grpc_ctx, resource_meta) sys.exit(1) - -def _get_grpc_context() -> grpc.ServicerContext: - from unittest.mock import MagicMock - - grpc_ctx = MagicMock(spec=grpc.ServicerContext) - return grpc_ctx + def get_input_literal_map(self, inputs: typing.Dict[str, 
typing.Any] = None) -> typing.Optional[LiteralMap]: + # Convert python inputs to literals + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(self._ctx, v, type(v), self._entity.interface.inputs[k].type) + return LiteralMap(literals) if literals else None diff --git a/flytekit/dispatcher/dispatcher_engine.py b/flytekit/extend/backend/task_executor.py similarity index 54% rename from flytekit/dispatcher/dispatcher_engine.py rename to flytekit/extend/backend/task_executor.py index c45e4a35bf..d2d8a6cd45 100644 --- a/flytekit/dispatcher/dispatcher_engine.py +++ b/flytekit/extend/backend/task_executor.py @@ -7,16 +7,19 @@ from flyteidl.admin.agent_pb2 import DoTaskResponse from flytekit import FlyteContextManager +from flytekit.core.external_api_task import TASK_CONFIG_PKL, TASK_MODULE, TASK_NAME, TASK_TYPE from flytekit.core.type_engine import TypeEngine -from flytekit.dispatcher.base_dispatcher import DISPATCHER_CONFIG_PKL, DISPATCHER_MODULE, DISPATCHER_NAME, INPUTS -from flytekit.extend.backend.base_agent import AgentRegistry, DispatcherAgent +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate T = typing.TypeVar("T") -class DispatcherEngine(DispatcherAgent): +class TaskExecutor(AgentBase): + def __init__(self): + super().__init__(task_type=TASK_TYPE, asynchronous=True) + async def async_do( self, context: grpc.ServicerContext, @@ -28,19 +31,16 @@ async def async_do( name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } ctx = FlyteContextManager.current_context() + native_inputs = {} if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) - task_template.custom[INPUTS] = native_inputs - else: - raise ValueError("Dispatcher needs a input!") meta = task_template.custom - dispatcher_module = 
importlib.import_module(name=meta[DISPATCHER_MODULE]) - dispatcher_def = getattr(dispatcher_module, meta[DISPATCHER_NAME]) - dispatcher_config = jsonpickle.decode(meta[DISPATCHER_CONFIG_PKL]) if meta.get(DISPATCHER_CONFIG_PKL) else None - inputs = meta.get(INPUTS, {}) - return await dispatcher_def("dispatcher", config=dispatcher_config).do(**inputs) + task_module = importlib.import_module(name=meta[TASK_MODULE]) + task_def = getattr(task_module, meta[TASK_NAME]) + config = jsonpickle.decode(meta[TASK_CONFIG_PKL]) if meta.get(TASK_CONFIG_PKL) else None + return await task_def(TASK_TYPE, config=config).do(**native_inputs) -AgentRegistry.register(DispatcherEngine()) +AgentRegistry.register(TaskExecutor()) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index a49a89d9f1..27e75b82b4 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,17 +1,16 @@ import json from typing import Any, Dict -import aiohttp from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager from flytekit.core.type_engine import TypeEngine +from flytekit.core.external_api_task import ExternalApiTask from flytekit.extend.backend.base_agent import get_agent_secret from flytekit.models.literals import LiteralMap -from flytekit.dispatcher.base_dispatcher import BaseDispatcher -class ChatGPTDispatcher(BaseDispatcher): +class ChatGPTTask(ExternalApiTask): """ TODO: Write the docstring """ From 4adf029bb806ef7ccfbf6e15f9ec4fd2d299107a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sun, 15 Oct 2023 00:38:13 -0700 Subject: [PATCH 16/64] refactor Signed-off-by: Kevin Su --- flytekit/__init__.py | 2 +- .../flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/__init__.py 
b/flytekit/__init__.py index 96dc7b8ac1..7235506412 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -228,7 +228,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extend.backend.task_executor import DispatcherEngine +from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.image_spec import ImageSpec from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger from flytekit.models.common import Annotations, AuthRole, Labels diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py index ca0bb68c6f..ed6a8fc8b5 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -10,4 +10,4 @@ ChatGPTDispatcher """ -from .task import ChatGPTDispatcher +from .task import ChatGPTTask From 420cbf572498c6d2c912307891c4b090f635986a Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sun, 15 Oct 2023 17:26:39 +0800 Subject: [PATCH 17/64] chatgpt sync plugin succeed v1 Signed-off-by: Future Outlier --- flytekit/core/external_api_task.py | 6 +++--- flytekit/extend/backend/base_agent.py | 16 ++++++---------- flytekit/tools/translator.py | 4 +++- .../flytekitplugins/chatgpt/__init__.py | 2 +- .../flytekitplugins/chatgpt/task.py | 4 ++-- 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 6d07ae7cff..33979dcb69 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -32,13 +32,13 @@ def __init__( name: str, config: Optional[T] = None, task_type: str = TASK_TYPE, - return_type: Optional[T] = None, + return_type: Optional[Any] = None, **kwargs, ): type_hints = get_type_hints(self.do, 
include_extras=True) signature = inspect.signature(self.do) inputs = collections.OrderedDict() - outputs = collections.OrderedDict({"o0": return_type}) + outputs = collections.OrderedDict({"o0": return_type}) if return_type else collections.OrderedDict() for k, _ in signature.parameters.items(): # type: ignore annotation = type_hints.get(k, None) @@ -47,7 +47,7 @@ def __init__( super().__init__( task_type=task_type, name=name, - task_config=None, + task_config=config, interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index d735d92426..6e34dad762 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -23,12 +23,11 @@ from flyteidl.core.tasks_pb2 import TaskTemplate import flytekit -from flytekit import FlyteContext, logger +from flytekit import FlyteContext from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.type_engine import TypeEngine from flytekit.exceptions.system import FlyteAgentNotFound -from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.exceptions.user import FlyteUserException from flytekit.models.literals import LiteralMap @@ -145,7 +144,6 @@ def register(agent: AgentBase): if agent.task_type in AgentRegistry._REGISTRY: raise ValueError(f"Duplicate agent for task type {agent.task_type}") AgentRegistry._REGISTRY[agent.task_type] = agent - logger.info(f"Registering an agent for task type {agent.task_type}") @staticmethod def get_agent(task_type: str) -> typing.Optional[AgentBase]: @@ -199,6 +197,7 @@ class AsyncAgentExecutorMixin: _grpc_ctx: grpc.ServicerContext = _get_grpc_context() def execute(self, **kwargs) -> typing.Any: + from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.tools.translator import get_serializable self._entity = typing.cast(PythonTask, self) @@ -232,19 
+231,16 @@ async def _create( async def _get(self, resource_meta: bytes) -> GetTaskResponse: state = RUNNING - grpc_ctx = _get_grpc_context() - while not is_terminal_state(state): time.sleep(1) if self._agent.asynchronous: - res = await self._agent.async_get(grpc_ctx, resource_meta) - if self._is_canceled: - await self._is_canceled + res = await self._agent.async_get(self._grpc_ctx, resource_meta) + if self._clean_up_task: + await self._clean_up_task sys.exit(1) else: - res = self._agent.get(grpc_ctx, resource_meta) + res = self._agent.get(self._grpc_ctx, resource_meta) state = res.resource.state - logger.info(f"Task state: {state}") return res async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 87ccd2f534..7619bb4d7a 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask, SourceCode +from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.array_node_map_task import ArrayNodeMapTask @@ -728,6 +728,8 @@ def get_serializable( raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): + from flytekit import SourceCode + # 1. Check if the size of long description exceeds 16KB # 2. 
Extract the repo URL from the git config, and assign it to the link of the source code of the description entity if entity.docs and entity.docs.long_description: diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py index ed6a8fc8b5..7a47fd2ffb 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -7,7 +7,7 @@ :template: custom.rst :toctree: generated/ - ChatGPTDispatcher + ChatGPTTask """ from .task import ChatGPTTask diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 27e75b82b4..e2401b9b2b 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -4,8 +4,8 @@ from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine from flytekit.core.external_api_task import ExternalApiTask +from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import get_agent_secret from flytekit.models.literals import LiteralMap @@ -20,7 +20,7 @@ class ChatGPTTask(ExternalApiTask): # TODO, Add Value Error def __init__(self, name: str, config: Dict[str, Any], **kwargs): - super().__init__(name=name, dispatcher_config=config, return_type=str, **kwargs) + super().__init__(name=name, config=config, return_type=str, **kwargs) self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] From 120826872fd3a19a45fe0649a37ca7bed1bf7305 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 12:40:15 +0800 Subject: [PATCH 18/64] chatgpt syncplugin succeed Signed-off-by: Future Outlier --- flytekit/core/base_task.py | 9 +++++++-- 
flytekit/core/external_api_task.py | 9 ++++++--- .../flytekitplugins/chatgpt/task.py | 3 +-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dc46e6bc4f..a7a0aa09df 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -103,6 +103,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None + use_sync_plugin: bool = False def __post_init__(self): if self.timeout: @@ -128,7 +129,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: return _task_model.TaskMetadata( discoverable=self.cache, runtime=_task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python" + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "sync_plugin" if self.use_sync_plugin else "python" ), timeout=self.timeout, retries=self.retry_strategy, @@ -168,12 +169,13 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, + use_sync_plugin: bool = False, **kwargs, ): self._task_type = task_type self._name = name self._interface = interface - self._metadata = metadata if metadata else TaskMetadata() + self._metadata = metadata if metadata else TaskMetadata(use_sync_plugin=use_sync_plugin) self._task_type_version = task_type_version self._security_ctx = security_ctx self._docs = docs @@ -410,6 +412,7 @@ def __init__( interface: Optional[Interface] = None, environment: Optional[Dict[str, str]] = None, disable_deck: bool = True, + use_sync_plugin: bool = False, **kwargs, ): """ @@ -424,11 +427,13 @@ def __init__( environment (Optional[Dict[str, str]]): Any environment variables that should be supplied during the execution of the task. 
Supplied as a dictionary of key/value pairs disable_deck (bool): If true, this task will not output deck html file + use_sync_plugin (bool): If true, this task will invoke sync plugin in flytepropeller and flyteplugin """ super().__init__( task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), + use_sync_plugin=use_sync_plugin, **kwargs, ) self._python_interface = interface if interface else Interface() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 33979dcb69..571ed81f6d 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -8,7 +8,7 @@ from typing_extensions import get_type_hints from flytekit.configuration import SerializationSettings -from flytekit.core.base_task import PythonTask +from flytekit.core.base_task import PythonTask, TaskMetadata from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @@ -17,7 +17,7 @@ TASK_NAME = "task_name" TASK_CONFIG_PKL = "task_config_pkl" TASK_TYPE = "api_task" -USE_SYNC_PLUGIN = "use_sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used ro run this task +USE_SYNC_PLUGIN = "use_sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): @@ -49,8 +49,10 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), + use_sync_plugin=True, **kwargs, ) + self._config = config @abstractmethod @@ -64,8 +66,9 @@ def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { TASK_MODULE: type(self).__module__, TASK_NAME: type(self).__name__, - USE_SYNC_PLUGIN: True, } + if self._config is not None: cfg[TASK_CONFIG_PKL] = jsonpickle.encode(self._config) + return cfg diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py 
b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index e2401b9b2b..75c03464bf 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -60,9 +60,8 @@ async def do( def get_header(openai_organization: str): - token = get_agent_secret(secret_key="OPENAI_ACCESS_TOKEN") return { "OpenAI-Organization": openai_organization, - "Authorization": f"Bearer {token}", + "Authorization": f"Bearer {get_agent_secret(secret_key='OPENAI_ACCESS_TOKEN')}", "content-type": "application/json", } From a64f6ee8e7d3ae383444f2894285a3053b3ba490 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 14:05:12 +0800 Subject: [PATCH 19/64] remove unused task metadata Signed-off-by: Future Outlier --- flytekit/core/base_task.py | 4 +++- flytekit/core/external_api_task.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index a7a0aa09df..9c9607e071 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -129,7 +129,9 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: return _task_model.TaskMetadata( discoverable=self.cache, runtime=_task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "sync_plugin" if self.use_sync_plugin else "python" + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "sync_plugin" if self.use_sync_plugin else "python", ), timeout=self.timeout, retries=self.retry_strategy, diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 571ed81f6d..2799f3ae74 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -8,7 +8,7 @@ from typing_extensions import get_type_hints from flytekit.configuration import SerializationSettings -from flytekit.core.base_task import PythonTask, TaskMetadata +from flytekit.core.base_task import 
PythonTask from flytekit.core.interface import Interface from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin From 815dde4a284bd0302cbe2c9d60381ceca68b2f38 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 15:54:09 +0800 Subject: [PATCH 20/64] add base_task RuntimeMetadata test and push for gitsha Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 6 +-- .../unit/core/test_external_api_task.py | 0 .../flytekit/unit/core/test_task_metadata.py | 49 +++++++++++++++++++ tests/flytekit/unit/extend/test_agent.py | 3 ++ 4 files changed, 55 insertions(+), 3 deletions(-) create mode 100644 tests/flytekit/unit/core/test_external_api_task.py create mode 100644 tests/flytekit/unit/core/test_task_metadata.py diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 6e34dad762..5adb18ff79 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -89,7 +89,7 @@ def do( inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: """ - Return the result of executing a task. It should return error code if the task creation failed. + Return the result of executing a task. It should return error code if the task execution failed. """ raise NotImplementedError @@ -126,7 +126,7 @@ async def async_do( inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: """ - Return the result of executing a task. It should return error code if the task creation failed. + Return the result of executing a task. It should return error code if the task execution failed. """ raise NotImplementedError @@ -157,7 +157,7 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. 
""" state = state.lower() - if state in ["failed", "timeout", "canceled"]: + if state in ["failed", "timeout", "timedout", "canceled"]: return RETRYABLE_FAILURE elif state in ["done", "succeeded", "success"]: return SUCCEEDED diff --git a/tests/flytekit/unit/core/test_external_api_task.py b/tests/flytekit/unit/core/test_external_api_task.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py new file mode 100644 index 0000000000..fe9cc39d7d --- /dev/null +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -0,0 +1,49 @@ +from flytekit.core.base_task import TaskMetadata +import datetime +from flytekit.models import task as _task_model +from flytekit.models import literals as _literal_models +import pytest +from flytekit import __version__ + +def test_post_init_conditions(): + with pytest.raises(ValueError, match="Caching is enabled ``cache=True`` but ``cache_version`` is not set."): + TaskMetadata(cache=True, cache_version="") + + with pytest.raises(ValueError, match="Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled."): + TaskMetadata(cache=False, cache_serialize=True) + + with pytest.raises(ValueError, match="timeout should be duration represented as either a datetime.timedelta or int seconds"): + TaskMetadata(timeout="invalid_timeout") + + tm = TaskMetadata(timeout=3600) + assert isinstance(tm.timeout, datetime.timedelta) + +def test_retry_strategy(): + tm = TaskMetadata(retries=5) + assert tm.retry_strategy.retries == 5 + +def test_to_taskmetadata_model(): + tm = TaskMetadata(cache=True, + cache_serialize=True, + cache_version="v1", + interruptible=True, + deprecated="TEST DEPRECATED ERROR MESSAGE", + retries=3, + timeout=3600, + pod_template_name="TEST POD TEMPLATE NAME", + use_sync_plugin=True,) + model = tm.to_taskmetadata_model() + + assert model.discoverable == True + assert model.runtime == 
_task_model.RuntimeMetadata( + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "sync_plugin", + ) + assert model.retries == _literal_models.RetryStrategy(3) + assert model.timeout == datetime.timedelta(seconds=3600) + assert model.interruptible == True + assert model.discovery_version == "v1" + assert model.deprecated_error_message == "TEST DEPRECATED ERROR MESSAGE" + assert model.cache_serializable == True + assert model.pod_template_name == "TEST POD TEMPLATE NAME" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index e9555b2026..b40499ac64 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -230,3 +230,6 @@ def test_convert_to_flyte_state(): def test_get_agent_secret(mocked_context): mocked_context.return_value.secrets.get.return_value = "mocked token" assert get_agent_secret("mocked key") == "mocked token" + +# TODO: TEST TASK EXECUTOR IN HERE + From 9f3072e2fc2f9054308e5fd6982892684cadd2bb Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 16:29:08 +0800 Subject: [PATCH 21/64] fix ExternalApiTask import Signed-off-by: Future Outlier --- flytekit/__init__.py | 1 + .../unit/core/test_external_api_task.py | 31 +++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 7235506412..05c5f054c2 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -212,6 +212,7 @@ from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic +from flytekit.core.external_api_task import ExternalApiTask from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan diff --git a/tests/flytekit/unit/core/test_external_api_task.py 
b/tests/flytekit/unit/core/test_external_api_task.py index e69de29bb2..43ef41dab7 100644 --- a/tests/flytekit/unit/core/test_external_api_task.py +++ b/tests/flytekit/unit/core/test_external_api_task.py @@ -0,0 +1,31 @@ +import json +import collections +from flytekit.core.external_api_task import ExternalApiTask, TASK_MODULE, TASK_NAME, TASK_CONFIG_PKL # replace "your_module" with the actual module name where ExternalApiTask is defined + + +# Mocking ExternalApiTask to make it instantiable +class MockExternalApiTask(ExternalApiTask): + + async def do(self, **kwargs): + pass + + +# Test for the __init__ method +def test_init(): + task = MockExternalApiTask(name="test_task", return_type=str) + assert task.name == "test_task" + assert task.interface.inputs == collections.OrderedDict() + assert task.interface.outputs == collections.OrderedDict({"o0": str}) + +# Test for the get_custom method +def test_get_custom(): + task = MockExternalApiTask(name="test_task", config={"key": "value"}) + custom = task.get_custom() + + expected_config = json.loads('{"key": "value"}') # replace with the expected serialized config + assert custom[TASK_MODULE] == MockExternalApiTask.__module__ + assert custom[TASK_NAME] == MockExternalApiTask.__name__ + assert json.loads(custom[TASK_CONFIG_PKL]) == expected_config # you might need to adjust this depending on how you expect the config to be serialized + + +# Run this with `pytest test_external_api_task.py` From 713db18d0dec693d680b2f5b3bd6b8bd6f5daa54 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 16:41:47 +0800 Subject: [PATCH 22/64] add for flyteidl version remove for building image Signed-off-by: Future Outlier --- setup.py | 2 +- .../unit/core/test_external_api_task.py | 27 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index 6828cb8661..45ca37f16e 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - 
"flyteidl>=1.5.16", + # "flyteidl>=1.5.16", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/core/test_external_api_task.py b/tests/flytekit/unit/core/test_external_api_task.py index 43ef41dab7..c337ff8140 100644 --- a/tests/flytekit/unit/core/test_external_api_task.py +++ b/tests/flytekit/unit/core/test_external_api_task.py @@ -1,31 +1,30 @@ import json import collections -from flytekit.core.external_api_task import ExternalApiTask, TASK_MODULE, TASK_NAME, TASK_CONFIG_PKL # replace "your_module" with the actual module name where ExternalApiTask is defined +from flytekit.core.external_api_task import ExternalApiTask, TASK_MODULE, TASK_NAME, TASK_CONFIG_PKL -# Mocking ExternalApiTask to make it instantiable class MockExternalApiTask(ExternalApiTask): - async def do(self, **kwargs): - pass + async def do(self, test_int_input : int, **kwargs) -> int: + return test_int_input - -# Test for the __init__ method def test_init(): - task = MockExternalApiTask(name="test_task", return_type=str) + task = MockExternalApiTask(name="test_task", return_type=int) assert task.name == "test_task" - assert task.interface.inputs == collections.OrderedDict() - assert task.interface.outputs == collections.OrderedDict({"o0": str}) + assert task.interface.inputs == collections.OrderedDict({"test_int_input": int}) + assert task.interface.outputs == collections.OrderedDict({"o0": int}) + +# use asyncio +def test_do(): + task = MockExternalApiTask(name="test_task", return_type=str) + assert task.interface.inputs == collections.OrderedDict({"test_int_input": int}) -# Test for the get_custom method def test_get_custom(): task = MockExternalApiTask(name="test_task", config={"key": "value"}) custom = task.get_custom() - expected_config = json.loads('{"key": "value"}') # replace with the expected serialized config + expected_config = json.loads('{"key": "value"}') assert custom[TASK_MODULE] == MockExternalApiTask.__module__ assert 
custom[TASK_NAME] == MockExternalApiTask.__name__ - assert json.loads(custom[TASK_CONFIG_PKL]) == expected_config # you might need to adjust this depending on how you expect the config to be serialized - + assert json.loads(custom[TASK_CONFIG_PKL]) == expected_config -# Run this with `pytest test_external_api_task.py` From 0ecebdfc3a018d888748e4d9adf715c72c761180 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 16:56:45 +0800 Subject: [PATCH 23/64] circular import Signed-off-by: Future Outlier --- flytekit/tools/translator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index 7619bb4d7a..3e97ad7892 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union -from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.array_node_map_task import ArrayNodeMapTask @@ -163,6 +162,8 @@ def get_serializable_task( settings: SerializationSettings, entity: FlyteLocalEntity, ) -> TaskSpec: + from flytekit import PythonFunctionTask + task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, settings.project, From a058734046092c33a9a93c4951ef76fd61ca408e Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 16 Oct 2023 18:00:13 +0800 Subject: [PATCH 24/64] add tests and lints Signed-off-by: Future Outlier --- .../flytekitplugins/chatgpt/task.py | 19 ++++--- setup.py | 2 +- .../unit/core/test_external_api_task.py | 32 +++++++---- .../flytekit/unit/core/test_task_metadata.py | 53 +++++++++++-------- tests/flytekit/unit/extend/test_agent.py | 49 +++++++++++++++-- 5 files changed, 107 insertions(+), 48 deletions(-) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py
b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 75c03464bf..7b7700b9da 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,6 +1,7 @@ import json from typing import Any, Dict +import aiohttp from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager @@ -32,17 +33,15 @@ async def do( openai_url = "https://api.openai.com/v1/chat/completions" data = json.dumps(self._chatgpt_conf) - message = "TEST SYNC PLUGIN" + async with aiohttp.ClientSession() as session: + async with session.post( + openai_url, headers=get_header(openai_organization=self._openai_organization), data=data + ) as resp: + if resp.status != 200: + raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") + response = await resp.json() - # async with aiohttp.ClientSession() as session: - # async with session.post( - # openai_url, headers=get_header(openai_organization=self._openai_organization), data=data - # ) as resp: - # if resp.status != 200: - # raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") - # response = await resp.json() - - # message = response["choices"][0]["message"]["content"] + message = response["choices"][0]["message"]["content"] ctx = FlyteContextManager.current_context() outputs = LiteralMap( diff --git a/setup.py b/setup.py index 45ca37f16e..6828cb8661 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - # "flyteidl>=1.5.16", + "flyteidl>=1.5.16", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", diff --git a/tests/flytekit/unit/core/test_external_api_task.py b/tests/flytekit/unit/core/test_external_api_task.py index c337ff8140..dde494c822 100644 --- a/tests/flytekit/unit/core/test_external_api_task.py +++ b/tests/flytekit/unit/core/test_external_api_task.py @@ -1,23 +1,34 @@ 
-import json import collections -from flytekit.core.external_api_task import ExternalApiTask, TASK_MODULE, TASK_NAME, TASK_CONFIG_PKL +import json +import pytest -class MockExternalApiTask(ExternalApiTask): +from flytekit.core.external_api_task import TASK_CONFIG_PKL, TASK_MODULE, TASK_NAME, ExternalApiTask +from flytekit.core.interface import Interface, transform_interface_to_typed_interface - async def do(self, test_int_input : int, **kwargs) -> int: + +class MockExternalApiTask(ExternalApiTask): + async def do(self, test_int_input: int, **kwargs) -> int: return test_int_input + def test_init(): task = MockExternalApiTask(name="test_task", return_type=int) assert task.name == "test_task" - assert task.interface.inputs == collections.OrderedDict({"test_int_input": int}) - assert task.interface.outputs == collections.OrderedDict({"o0": int}) -# use asyncio -def test_do(): - task = MockExternalApiTask(name="test_task", return_type=str) - assert task.interface.inputs == collections.OrderedDict({"test_int_input": int}) + interface = Interface( + inputs=collections.OrderedDict({"test_int_input": int, "kwargs": None}), + outputs=collections.OrderedDict({"o0": int}), + ) + assert task.interface == transform_interface_to_typed_interface(interface) + + +@pytest.mark.asyncio +async def test_do(): + input_num = 100 + task = MockExternalApiTask(name="test_task", return_type=int) + assert input_num == await task.do(test_int_input=input_num) + def test_get_custom(): task = MockExternalApiTask(name="test_task", config={"key": "value"}) @@ -27,4 +38,3 @@ def test_get_custom(): assert custom[TASK_MODULE] == MockExternalApiTask.__module__ assert custom[TASK_NAME] == MockExternalApiTask.__name__ assert json.loads(custom[TASK_CONFIG_PKL]) == expected_config - diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index fe9cc39d7d..003753dbfc 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ 
b/tests/flytekit/unit/core/test_task_metadata.py @@ -1,49 +1,60 @@ -from flytekit.core.base_task import TaskMetadata import datetime -from flytekit.models import task as _task_model -from flytekit.models import literals as _literal_models + import pytest + from flytekit import __version__ +from flytekit.core.base_task import TaskMetadata +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_model + def test_post_init_conditions(): with pytest.raises(ValueError, match="Caching is enabled ``cache=True`` but ``cache_version`` is not set."): TaskMetadata(cache=True, cache_version="") - with pytest.raises(ValueError, match="Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled."): + with pytest.raises( + ValueError, match="Cache serialize is enabled ``cache_serialize=True`` but ``cache`` is not enabled." + ): TaskMetadata(cache=False, cache_serialize=True) - with pytest.raises(ValueError, match="timeout should be duration represented as either a datetime.timedelta or int seconds"): + with pytest.raises( + ValueError, match="timeout should be duration represented as either a datetime.timedelta or int seconds" + ): TaskMetadata(timeout="invalid_timeout") tm = TaskMetadata(timeout=3600) assert isinstance(tm.timeout, datetime.timedelta) + def test_retry_strategy(): tm = TaskMetadata(retries=5) assert tm.retry_strategy.retries == 5 + def test_to_taskmetadata_model(): - tm = TaskMetadata(cache=True, - cache_serialize=True, - cache_version="v1", - interruptible=True, - deprecated="TEST DEPRECATED ERROR MESSAGE", - retries=3, - timeout=3600, - pod_template_name="TEST POD TEMPLATE NAME", - use_sync_plugin=True,) + tm = TaskMetadata( + cache=True, + cache_serialize=True, + cache_version="v1", + interruptible=True, + deprecated="TEST DEPRECATED ERROR MESSAGE", + retries=3, + timeout=3600, + pod_template_name="TEST POD TEMPLATE NAME", + use_sync_plugin=True, + ) model = tm.to_taskmetadata_model() - assert 
model.discoverable == True + assert model.discoverable is True assert model.runtime == _task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "sync_plugin", - ) + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "sync_plugin", + ) assert model.retries == _literal_models.RetryStrategy(3) assert model.timeout == datetime.timedelta(seconds=3600) - assert model.interruptible == True + assert model.interruptible is True assert model.discovery_version == "v1" assert model.deprecated_error_message == "TEST DEPRECATED ERROR MESSAGE" - assert model.cache_serializable == True + assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index b40499ac64..12f0768151 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -16,6 +16,8 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, + DoTaskRequest, + DoTaskResponse, GetTaskRequest, GetTaskResponse, Resource, @@ -65,6 +67,15 @@ def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskRes def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() + def do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> DoTaskResponse: + return DoTaskResponse(resource=Resource(state=SUCCEEDED)) + class AsyncDummyAgent(AgentBase): def __init__(self): @@ -85,6 +96,15 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() + async def async_do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: 
typing.Optional[LiteralMap] = None, + ) -> DoTaskResponse: + return DoTaskResponse(resource=Resource(state=SUCCEEDED)) + AgentRegistry.register(DummyAgent()) AgentRegistry.register(AsyncDummyAgent()) @@ -139,6 +159,7 @@ def test_dummy_agent(): assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() + assert agent.do(ctx, "/tmp", dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): @@ -166,34 +187,46 @@ async def test_async_dummy_agent(): assert res.resource.state == SUCCEEDED res = await agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse() + res = await agent.async_do(ctx, "/tmp", async_dummy_template, task_inputs) + assert res == DoTaskResponse(resource=Resource(state=SUCCEEDED)) @pytest.mark.asyncio async def run_agent_server(): service = AsyncAgentService() ctx = MagicMock(spec=grpc.ServicerContext) - request = CreateTaskRequest( + create_request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) + async_create_request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() + ) + do_request = DoTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() ) - async_request = CreateTaskRequest( + async_do_request = DoTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await service.CreateTask(request, ctx) + res = await service.CreateTask(create_request, ctx) assert res.resource_meta == metadata_bytes 
res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) assert isinstance(res, DeleteTaskResponse) + res = await service.DoTask(do_request, ctx) + assert res.resource.state == SUCCEEDED - res = await service.CreateTask(async_request, ctx) + res = await service.CreateTask(async_create_request, ctx) assert res.resource_meta == metadata_bytes res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert isinstance(res, DeleteTaskResponse) + res = await service.DoTask(async_do_request, ctx) + assert res.resource.state == SUCCEEDED res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) assert res is None @@ -205,7 +238,7 @@ def test_agent_server(): def test_is_terminal_state(): assert is_terminal_state(SUCCEEDED) - assert is_terminal_state(PERMANENT_FAILURE) + assert is_terminal_state(RETRYABLE_FAILURE) assert is_terminal_state(PERMANENT_FAILURE) assert not is_terminal_state(RUNNING) @@ -231,5 +264,11 @@ def test_get_agent_secret(mocked_context): mocked_context.return_value.secrets.get.return_value = "mocked token" assert get_agent_secret("mocked key") == "mocked token" + # TODO: TEST TASK EXECUTOR IN HERE +""" +refer the task up +t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") +t.execute() +""" From 719ae32d2d8c95df776d5751517600e2a4b99952 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 00:07:03 +0800 Subject: [PATCH 25/64] final all tests Signed-off-by: Future Outlier --- plugins/flytekit-chatgpt/dev-requirements.in | 2 + plugins/flytekit-chatgpt/dev-requirements.txt | 44 ++++++++++++ 
.../flytekitplugins/chatgpt/__init__.py | 0 .../flytekitplugins/chatgpt/task.py | 14 ++-- plugins/flytekit-chatgpt/requirements.in | 2 + .../requirements.txt | 0 .../setup.py | 2 +- .../tests/__init__.py | 0 .../tests/test_chatgpt_task.py | 71 +++++++++++++++++++ .../flytekit-openai-chatgpt/requirements.in | 2 - .../tests/test_chatgpt.py | 0 tests/flytekit/unit/extend/test_agent.py | 35 +++++++-- .../unit/extend/test_task_executor.py | 45 ++++++++++++ 13 files changed, 204 insertions(+), 13 deletions(-) create mode 100644 plugins/flytekit-chatgpt/dev-requirements.in create mode 100644 plugins/flytekit-chatgpt/dev-requirements.txt rename plugins/{flytekit-openai-chatgpt => flytekit-chatgpt}/flytekitplugins/chatgpt/__init__.py (100%) rename plugins/{flytekit-openai-chatgpt => flytekit-chatgpt}/flytekitplugins/chatgpt/task.py (81%) create mode 100644 plugins/flytekit-chatgpt/requirements.in rename plugins/{flytekit-openai-chatgpt => flytekit-chatgpt}/requirements.txt (100%) rename plugins/{flytekit-openai-chatgpt => flytekit-chatgpt}/setup.py (95%) rename plugins/{flytekit-openai-chatgpt => flytekit-chatgpt}/tests/__init__.py (100%) create mode 100644 plugins/flytekit-chatgpt/tests/test_chatgpt_task.py delete mode 100644 plugins/flytekit-openai-chatgpt/requirements.in delete mode 100644 plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py create mode 100644 tests/flytekit/unit/extend/test_task_executor.py diff --git a/plugins/flytekit-chatgpt/dev-requirements.in b/plugins/flytekit-chatgpt/dev-requirements.in new file mode 100644 index 0000000000..78d0eca127 --- /dev/null +++ b/plugins/flytekit-chatgpt/dev-requirements.in @@ -0,0 +1,2 @@ +aioresponses +pytest-asyncio diff --git a/plugins/flytekit-chatgpt/dev-requirements.txt b/plugins/flytekit-chatgpt/dev-requirements.txt new file mode 100644 index 0000000000..f8255ed178 --- /dev/null +++ b/plugins/flytekit-chatgpt/dev-requirements.txt @@ -0,0 +1,44 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# 
by the following command: +# +# pip-compile dev-requirements.in +# +aiohttp==3.8.6 + # via aioresponses +aioresponses==0.7.4 + # via -r dev-requirements.in +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via aiohttp +charset-normalizer==3.3.0 + # via aiohttp +exceptiongroup==1.1.3 + # via pytest +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +idna==3.4 + # via yarl +iniconfig==2.0.0 + # via pytest +multidict==6.0.4 + # via + # aiohttp + # yarl +packaging==23.2 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.2 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest +yarl==1.9.2 + # via aiohttp diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py similarity index 100% rename from plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py rename to plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py similarity index 81% rename from plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py rename to plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py index 7b7700b9da..fd9d687a2b 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py @@ -13,18 +13,24 @@ class ChatGPTTask(ExternalApiTask): """ - TODO: Write the docstring + This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. 
""" _openai_organization: str = None _chatgpt_conf: Dict[str, Any] = None - # TODO, Add Value Error def __init__(self, name: str, config: Dict[str, Any], **kwargs): - super().__init__(name=name, config=config, return_type=str, **kwargs) + if "openai_organization" not in config: + raise ValueError("The 'openai_organization' configuration variable is required") + + if "chatgpt_conf" not in config: + raise ValueError("The 'chatgpt_conf' configuration variable is required") + self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] + super().__init__(name=name, config=config, return_type=str, **kwargs) + async def do( self, message: str = None, @@ -35,7 +41,7 @@ async def do( async with aiohttp.ClientSession() as session: async with session.post( - openai_url, headers=get_header(openai_organization=self._openai_organization), data=data + url=openai_url, headers=get_header(openai_organization=self._openai_organization), data=data ) as resp: if resp.status != 200: raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") diff --git a/plugins/flytekit-chatgpt/requirements.in b/plugins/flytekit-chatgpt/requirements.in new file mode 100644 index 0000000000..35f1aae56e --- /dev/null +++ b/plugins/flytekit-chatgpt/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-chatgpt diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-chatgpt/requirements.txt similarity index 100% rename from plugins/flytekit-openai-chatgpt/requirements.txt rename to plugins/flytekit-chatgpt/requirements.txt diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-chatgpt/setup.py similarity index 95% rename from plugins/flytekit-openai-chatgpt/setup.py rename to plugins/flytekit-chatgpt/setup.py index a85fbec0fa..da0cfc64c6 100644 --- a/plugins/flytekit-openai-chatgpt/setup.py +++ b/plugins/flytekit-chatgpt/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-openai-chatgpt/tests/__init__.py b/plugins/flytekit-chatgpt/tests/__init__.py similarity index 100% rename from plugins/flytekit-openai-chatgpt/tests/__init__.py rename to plugins/flytekit-chatgpt/tests/__init__.py diff --git a/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py b/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py new file mode 100644 index 0000000000..008f14d1bc --- /dev/null +++ b/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py @@ -0,0 +1,71 @@ +from unittest import mock + +import aioresponses +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from flytekitplugins.chatgpt import ChatGPTTask + +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import LiteralMap + + +@pytest.mark.asyncio +async def test_chatgpt_task(): + message = "TEST MESSAGE" + response_message = "Hello! How can I assist you today?" 
+ organization = "TEST ORGANIZATION" + chatgpt_job = ChatGPTTask( + name="chatgpt", + config={ + "openai_organization": organization, + "chatgpt_conf": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "temperature": 0.7, + }, + }, + ) + ctx = FlyteContextManager.current_context() + + assert chatgpt_job._chatgpt_conf == { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "temperature": 0.7, + } + assert chatgpt_job._openai_organization == organization + + mocked_token = "mocked_chatgpt_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + openai_url = "https://api.openai.com/v1/chat/completions" + mock_do_response = { + "id": "chatcmpl-8AJDGV3GdDsTcdFJfc68OxuLqIJZr", + "object": "chat.completion", + "created": 1697467826, + "model": "gpt-3.5-turbo-0613", + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": response_message}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17}, + } + with aioresponses.aioresponses() as mocked: + mocked.post(openai_url, status=200, payload=mock_do_response) + res = await chatgpt_job.do(message=message) + + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs + == LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + response_message, + type(response_message), + TypeEngine.to_literal_type(type(response_message)), + ) + } + ).to_flyte_idl() + ) + + mock.patch.stopall() diff --git a/plugins/flytekit-openai-chatgpt/requirements.in b/plugins/flytekit-openai-chatgpt/requirements.in deleted file mode 100644 index 03afde6b3a..0000000000 --- a/plugins/flytekit-openai-chatgpt/requirements.in +++ /dev/null @@ -1,2 +0,0 @@ -. 
--e file:.#egg=flytekitplugins-openai-chatgpt diff --git a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 12f0768151..bf3197dfa9 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -265,10 +265,33 @@ def test_get_agent_secret(mocked_context): assert get_agent_secret("mocked key") == "mocked token" -# TODO: TEST TASK EXECUTOR IN HERE +def get_task_template(task_type: str) -> TaskTemplate: + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) -""" -refer the task up -t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") -t.execute() -""" + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(types.LiteralType(types.SimpleType.INTEGER), "description1"), + }, + {}, + ) + + return TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type=task_type, + custom={}, + ) diff --git a/tests/flytekit/unit/extend/test_task_executor.py b/tests/flytekit/unit/extend/test_task_executor.py new file mode 100644 index 0000000000..51213cc4ff --- /dev/null +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -0,0 +1,45 @@ +import collections +from unittest.mock import MagicMock + +import grpc +import pytest + +from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, ExternalApiTask +from flytekit.core.interface import Interface, transform_interface_to_typed_interface +from 
flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.models import literals +from tests.flytekit.unit.extend.test_agent import get_task_template + + +class MockExternalApiTask(ExternalApiTask): + async def do(self, input: str, **kwargs) -> str: + return input + + +@pytest.mark.asyncio +async def test_task_executor_engine(): + input = "TASK INPUT" + + interface = Interface( + inputs=collections.OrderedDict({"input": str, "kwargs": None}), + outputs=collections.OrderedDict({"o0": str}), + ) + tmp = get_task_template("api_task") + tmp._custom = { + TASK_MODULE: MockExternalApiTask.__module__, + TASK_NAME: MockExternalApiTask.__name__, + } + + tmp._interface = transform_interface_to_typed_interface(interface) + + task_inputs = literals.LiteralMap( + { + "input": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(string_value="TASK INPUT"))), + }, + ) + + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("api_task") + + res = await agent.async_do(ctx, tmp, task_inputs) + assert res == input From 5823cb17c1ffd38440500d1357efdafb23552013 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 09:21:09 +0800 Subject: [PATCH 26/64] move to flytekit-openai-chatgpt dir and add setup.py python 3.11 test Signed-off-by: Future Outlier --- .../dev-requirements.in | 2 + .../dev-requirements.txt | 44 ++++++++++++ .../flytekitplugins/chatgpt/__init__.py | 13 ++++ .../flytekitplugins/chatgpt/task.py | 72 +++++++++++++++++++ .../flytekit-openai-chatgpt/requirements.in | 2 + .../flytekit-openai-chatgpt/requirements.txt | 0 plugins/flytekit-openai-chatgpt/setup.py | 37 ++++++++++ .../flytekit-openai-chatgpt/tests/__init__.py | 0 .../tests/test_chatgpt_task.py | 71 ++++++++++++++++++ 9 files changed, 241 insertions(+) create mode 100644 plugins/flytekit-openai-chatgpt/dev-requirements.in create mode 100644 plugins/flytekit-openai-chatgpt/dev-requirements.txt create mode 100644 
plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py create mode 100644 plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py create mode 100644 plugins/flytekit-openai-chatgpt/requirements.in create mode 100644 plugins/flytekit-openai-chatgpt/requirements.txt create mode 100644 plugins/flytekit-openai-chatgpt/setup.py create mode 100644 plugins/flytekit-openai-chatgpt/tests/__init__.py create mode 100644 plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py diff --git a/plugins/flytekit-openai-chatgpt/dev-requirements.in b/plugins/flytekit-openai-chatgpt/dev-requirements.in new file mode 100644 index 0000000000..78d0eca127 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.in @@ -0,0 +1,2 @@ +aioresponses +pytest-asyncio diff --git a/plugins/flytekit-openai-chatgpt/dev-requirements.txt b/plugins/flytekit-openai-chatgpt/dev-requirements.txt new file mode 100644 index 0000000000..f8255ed178 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.txt @@ -0,0 +1,44 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +aiohttp==3.8.6 + # via aioresponses +aioresponses==0.7.4 + # via -r dev-requirements.in +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via aiohttp +charset-normalizer==3.3.0 + # via aiohttp +exceptiongroup==1.1.3 + # via pytest +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +idna==3.4 + # via yarl +iniconfig==2.0.0 + # via pytest +multidict==6.0.4 + # via + # aiohttp + # yarl +packaging==23.2 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.2 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest +yarl==1.9.2 + # via aiohttp diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py new file mode 100644 index 
0000000000..7a47fd2ffb --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/__init__.py @@ -0,0 +1,13 @@ +""" +.. currentmodule:: flytekitplugins.chatgpt + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + ChatGPTTask +""" + +from .task import ChatGPTTask diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py new file mode 100644 index 0000000000..fd9d687a2b --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -0,0 +1,72 @@ +import json +from typing import Any, Dict + +import aiohttp +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource + +from flytekit import FlyteContextManager +from flytekit.core.external_api_task import ExternalApiTask +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import get_agent_secret +from flytekit.models.literals import LiteralMap + + +class ChatGPTTask(ExternalApiTask): + """ + This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. 
+ """ + + _openai_organization: str = None + _chatgpt_conf: Dict[str, Any] = None + + def __init__(self, name: str, config: Dict[str, Any], **kwargs): + if "openai_organization" not in config: + raise ValueError("The 'openai_organization' configuration variable is required") + + if "chatgpt_conf" not in config: + raise ValueError("The 'chatgpt_conf' configuration variable is required") + + self._openai_organization = config["openai_organization"] + self._chatgpt_conf = config["chatgpt_conf"] + + super().__init__(name=name, config=config, return_type=str, **kwargs) + + async def do( + self, + message: str = None, + ) -> DoTaskResponse: + self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] + openai_url = "https://api.openai.com/v1/chat/completions" + data = json.dumps(self._chatgpt_conf) + + async with aiohttp.ClientSession() as session: + async with session.post( + url=openai_url, headers=get_header(openai_organization=self._openai_organization), data=data + ) as resp: + if resp.status != 200: + raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") + response = await resp.json() + + message = response["choices"][0]["message"]["content"] + + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + message, + type(message), + TypeEngine.to_literal_type(type(message)), + ) + } + ).to_flyte_idl() + + return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + + +def get_header(openai_organization: str): + return { + "OpenAI-Organization": openai_organization, + "Authorization": f"Bearer {get_agent_secret(secret_key='OPENAI_ACCESS_TOKEN')}", + "content-type": "application/json", + } diff --git a/plugins/flytekit-openai-chatgpt/requirements.in b/plugins/flytekit-openai-chatgpt/requirements.in new file mode 100644 index 0000000000..35f1aae56e --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/requirements.in @@ -0,0 +1,2 @@ +. 
+-e file:.#egg=flytekitplugins-chatgpt diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py new file mode 100644 index 0000000000..401b0df201 --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -0,0 +1,37 @@ +from setuptools import setup + +PLUGIN_NAME = "chatgpt" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Bigquery plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-openai-chatgpt/tests/__init__.py b/plugins/flytekit-openai-chatgpt/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py new file mode 100644 index 
0000000000..008f14d1bc --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py @@ -0,0 +1,71 @@ +from unittest import mock + +import aioresponses +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from flytekitplugins.chatgpt import ChatGPTTask + +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.models.literals import LiteralMap + + +@pytest.mark.asyncio +async def test_chatgpt_task(): + message = "TEST MESSAGE" + response_message = "Hello! How can I assist you today?" + organization = "TEST ORGANIZATION" + chatgpt_job = ChatGPTTask( + name="chatgpt", + config={ + "openai_organization": organization, + "chatgpt_conf": { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "temperature": 0.7, + }, + }, + ) + ctx = FlyteContextManager.current_context() + + assert chatgpt_job._chatgpt_conf == { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": message}], + "temperature": 0.7, + } + assert chatgpt_job._openai_organization == organization + + mocked_token = "mocked_chatgpt_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + openai_url = "https://api.openai.com/v1/chat/completions" + mock_do_response = { + "id": "chatcmpl-8AJDGV3GdDsTcdFJfc68OxuLqIJZr", + "object": "chat.completion", + "created": 1697467826, + "model": "gpt-3.5-turbo-0613", + "choices": [ + {"index": 0, "message": {"role": "assistant", "content": response_message}, "finish_reason": "stop"} + ], + "usage": {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17}, + } + with aioresponses.aioresponses() as mocked: + mocked.post(openai_url, status=200, payload=mock_do_response) + res = await chatgpt_job.do(message=message) + + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs + == LiteralMap( + { + "o0": 
TypeEngine.to_literal( + ctx, + response_message, + type(response_message), + TypeEngine.to_literal_type(type(response_message)), + ) + } + ).to_flyte_idl() + ) + + mock.patch.stopall() From c6058fcbc284514b49335a04539daa133b98e3c8 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 11:58:19 +0800 Subject: [PATCH 27/64] delete flytekit-chatgpt Signed-off-by: Future Outlier --- plugins/flytekit-chatgpt/dev-requirements.in | 2 - plugins/flytekit-chatgpt/dev-requirements.txt | 44 ------------ .../flytekitplugins/chatgpt/__init__.py | 13 ---- .../flytekitplugins/chatgpt/task.py | 72 ------------------- plugins/flytekit-chatgpt/requirements.in | 2 - plugins/flytekit-chatgpt/requirements.txt | 0 plugins/flytekit-chatgpt/setup.py | 36 ---------- plugins/flytekit-chatgpt/tests/__init__.py | 0 .../tests/test_chatgpt_task.py | 71 ------------------ 9 files changed, 240 deletions(-) delete mode 100644 plugins/flytekit-chatgpt/dev-requirements.in delete mode 100644 plugins/flytekit-chatgpt/dev-requirements.txt delete mode 100644 plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py delete mode 100644 plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py delete mode 100644 plugins/flytekit-chatgpt/requirements.in delete mode 100644 plugins/flytekit-chatgpt/requirements.txt delete mode 100644 plugins/flytekit-chatgpt/setup.py delete mode 100644 plugins/flytekit-chatgpt/tests/__init__.py delete mode 100644 plugins/flytekit-chatgpt/tests/test_chatgpt_task.py diff --git a/plugins/flytekit-chatgpt/dev-requirements.in b/plugins/flytekit-chatgpt/dev-requirements.in deleted file mode 100644 index 78d0eca127..0000000000 --- a/plugins/flytekit-chatgpt/dev-requirements.in +++ /dev/null @@ -1,2 +0,0 @@ -aioresponses -pytest-asyncio diff --git a/plugins/flytekit-chatgpt/dev-requirements.txt b/plugins/flytekit-chatgpt/dev-requirements.txt deleted file mode 100644 index f8255ed178..0000000000 --- a/plugins/flytekit-chatgpt/dev-requirements.txt +++ /dev/null @@ -1,44 
+0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile dev-requirements.in -# -aiohttp==3.8.6 - # via aioresponses -aioresponses==0.7.4 - # via -r dev-requirements.in -aiosignal==1.3.1 - # via aiohttp -async-timeout==4.0.3 - # via aiohttp -attrs==23.1.0 - # via aiohttp -charset-normalizer==3.3.0 - # via aiohttp -exceptiongroup==1.1.3 - # via pytest -frozenlist==1.4.0 - # via - # aiohttp - # aiosignal -idna==3.4 - # via yarl -iniconfig==2.0.0 - # via pytest -multidict==6.0.4 - # via - # aiohttp - # yarl -packaging==23.2 - # via pytest -pluggy==1.3.0 - # via pytest -pytest==7.4.2 - # via pytest-asyncio -pytest-asyncio==0.21.1 - # via -r dev-requirements.in -tomli==2.0.1 - # via pytest -yarl==1.9.2 - # via aiohttp diff --git a/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py b/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py deleted file mode 100644 index 7a47fd2ffb..0000000000 --- a/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -.. currentmodule:: flytekitplugins.chatgpt - -This package contains things that are useful when extending Flytekit. - -.. 
autosummary:: - :template: custom.rst - :toctree: generated/ - - ChatGPTTask -""" - -from .task import ChatGPTTask diff --git a/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py deleted file mode 100644 index fd9d687a2b..0000000000 --- a/plugins/flytekit-chatgpt/flytekitplugins/chatgpt/task.py +++ /dev/null @@ -1,72 +0,0 @@ -import json -from typing import Any, Dict - -import aiohttp -from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource - -from flytekit import FlyteContextManager -from flytekit.core.external_api_task import ExternalApiTask -from flytekit.core.type_engine import TypeEngine -from flytekit.extend.backend.base_agent import get_agent_secret -from flytekit.models.literals import LiteralMap - - -class ChatGPTTask(ExternalApiTask): - """ - This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. - """ - - _openai_organization: str = None - _chatgpt_conf: Dict[str, Any] = None - - def __init__(self, name: str, config: Dict[str, Any], **kwargs): - if "openai_organization" not in config: - raise ValueError("The 'openai_organization' configuration variable is required") - - if "chatgpt_conf" not in config: - raise ValueError("The 'chatgpt_conf' configuration variable is required") - - self._openai_organization = config["openai_organization"] - self._chatgpt_conf = config["chatgpt_conf"] - - super().__init__(name=name, config=config, return_type=str, **kwargs) - - async def do( - self, - message: str = None, - ) -> DoTaskResponse: - self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - openai_url = "https://api.openai.com/v1/chat/completions" - data = json.dumps(self._chatgpt_conf) - - async with aiohttp.ClientSession() as session: - async with session.post( - url=openai_url, headers=get_header(openai_organization=self._openai_organization), data=data - ) as resp: - if resp.status != 200: - raise Exception(f"Failed 
to execute chatgpt job with error: {resp.reason}") - response = await resp.json() - - message = response["choices"][0]["message"]["content"] - - ctx = FlyteContextManager.current_context() - outputs = LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - message, - type(message), - TypeEngine.to_literal_type(type(message)), - ) - } - ).to_flyte_idl() - - return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) - - -def get_header(openai_organization: str): - return { - "OpenAI-Organization": openai_organization, - "Authorization": f"Bearer {get_agent_secret(secret_key='OPENAI_ACCESS_TOKEN')}", - "content-type": "application/json", - } diff --git a/plugins/flytekit-chatgpt/requirements.in b/plugins/flytekit-chatgpt/requirements.in deleted file mode 100644 index 35f1aae56e..0000000000 --- a/plugins/flytekit-chatgpt/requirements.in +++ /dev/null @@ -1,2 +0,0 @@ -. --e file:.#egg=flytekitplugins-chatgpt diff --git a/plugins/flytekit-chatgpt/requirements.txt b/plugins/flytekit-chatgpt/requirements.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/plugins/flytekit-chatgpt/setup.py b/plugins/flytekit-chatgpt/setup.py deleted file mode 100644 index da0cfc64c6..0000000000 --- a/plugins/flytekit-chatgpt/setup.py +++ /dev/null @@ -1,36 +0,0 @@ -from setuptools import setup - -PLUGIN_NAME = "chatgpt" - -microlib_name = f"flytekitplugins-{PLUGIN_NAME}" - -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] - -__version__ = "0.0.0+develop" - -setup( - name=microlib_name, - version=__version__, - author="flyteorg", - author_email="admin@flyte.org", - description="This package holds the Bigquery plugins for flytekit", - namespace_packages=["flytekitplugins"], - packages=[f"flytekitplugins.{PLUGIN_NAME}"], - install_requires=plugin_requires, - license="apache2", - python_requires=">=3.8", - classifiers=[ - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software 
License", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development", - "Topic :: Software Development :: Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - ], - entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, -) diff --git a/plugins/flytekit-chatgpt/tests/__init__.py b/plugins/flytekit-chatgpt/tests/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py b/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py deleted file mode 100644 index 008f14d1bc..0000000000 --- a/plugins/flytekit-chatgpt/tests/test_chatgpt_task.py +++ /dev/null @@ -1,71 +0,0 @@ -from unittest import mock - -import aioresponses -import pytest -from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource -from flytekitplugins.chatgpt import ChatGPTTask - -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine -from flytekit.models.literals import LiteralMap - - -@pytest.mark.asyncio -async def test_chatgpt_task(): - message = "TEST MESSAGE" - response_message = "Hello! How can I assist you today?" 
- organization = "TEST ORGANIZATION" - chatgpt_job = ChatGPTTask( - name="chatgpt", - config={ - "openai_organization": organization, - "chatgpt_conf": { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": message}], - "temperature": 0.7, - }, - }, - ) - ctx = FlyteContextManager.current_context() - - assert chatgpt_job._chatgpt_conf == { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": message}], - "temperature": 0.7, - } - assert chatgpt_job._openai_organization == organization - - mocked_token = "mocked_chatgpt_token" - mocked_context = mock.patch("flytekit.current_context", autospec=True).start() - mocked_context.return_value.secrets.get.return_value = mocked_token - openai_url = "https://api.openai.com/v1/chat/completions" - mock_do_response = { - "id": "chatcmpl-8AJDGV3GdDsTcdFJfc68OxuLqIJZr", - "object": "chat.completion", - "created": 1697467826, - "model": "gpt-3.5-turbo-0613", - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": response_message}, "finish_reason": "stop"} - ], - "usage": {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17}, - } - with aioresponses.aioresponses() as mocked: - mocked.post(openai_url, status=200, payload=mock_do_response) - res = await chatgpt_job.do(message=message) - - assert res.resource.state == SUCCEEDED - assert ( - res.resource.outputs - == LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - response_message, - type(response_message), - TypeEngine.to_literal_type(type(response_message)), - ) - } - ).to_flyte_idl() - ) - - mock.patch.stopall() From e6482ae3c06debc960165a44d57d5c7788d401d5 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 17:18:15 +0800 Subject: [PATCH 28/64] databricks api bug Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 5adb18ff79..e1d2de37a2 
100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -157,6 +157,7 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() + # timedout is weird but correct, you can refer here: https://github.com/databricks/databricks-sdk-py/pull/407 if state in ["failed", "timeout", "timedout", "canceled"]: return RETRYABLE_FAILURE elif state in ["done", "succeeded", "success"]: From c6343ecc3ebabb7452e7dbfba6cf071d47a17e24 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 17:19:43 +0800 Subject: [PATCH 29/64] remove output_prefix Signed-off-by: Future Outlier --- flytekit/extend/backend/task_executor.py | 1 - tests/flytekit/unit/extend/test_agent.py | 10 ++++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index d2d8a6cd45..7608c08878 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -25,7 +25,6 @@ async def async_do( context: grpc.ServicerContext, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - output_prefix: Optional[str] = None, ) -> DoTaskResponse: python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index bf3197dfa9..cba4b25f68 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -70,7 +70,6 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT def do( self, context: grpc.ServicerContext, - output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -99,7 +98,6 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async def 
async_do( self, context: grpc.ServicerContext, - output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -159,7 +157,7 @@ def test_dummy_agent(): assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() - assert agent.do(ctx, "/tmp", dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) + assert agent.do(ctx, dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): @@ -187,7 +185,7 @@ async def test_async_dummy_agent(): assert res.resource.state == SUCCEEDED res = await agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse() - res = await agent.async_do(ctx, "/tmp", async_dummy_template, task_inputs) + res = await agent.async_do(ctx, async_dummy_template, task_inputs) assert res == DoTaskResponse(resource=Resource(state=SUCCEEDED)) @@ -202,10 +200,10 @@ async def run_agent_server(): inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + inputs=task_inputs.to_flyte_idl(), template=dummy_template.to_flyte_idl() ) async_do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() + inputs=task_inputs.to_flyte_idl(), template=async_dummy_template.to_flyte_idl() ) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") From ee0b829cc15fc930e6f62aec500210d431710c10 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 17 Oct 2023 17:20:03 +0800 Subject: [PATCH 30/64] make lint will make logger 
can't be import Signed-off-by: Future Outlier --- flytekit/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 05c5f054c2..04bf9d19a8 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -212,7 +212,6 @@ from flytekit.core.container_task import ContainerTask from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager from flytekit.core.dynamic_workflow_task import dynamic -from flytekit.core.external_api_task import ExternalApiTask from flytekit.core.gate import approve, sleep, wait_for_input from flytekit.core.hash import HashMethod from flytekit.core.launch_plan import LaunchPlan, reference_launch_plan @@ -229,7 +228,6 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.image_spec import ImageSpec from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -239,6 +237,7 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.sensor.sensor_engine import SensorEngine +from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.types import directory, file, iterator from flytekit.types.structured.structured_dataset import ( StructuredDataset, From a9b16c8d6418c948392957134f46f9e17cd26575 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 18 Oct 2023 12:18:45 +0800 Subject: [PATCH 31/64] improve excutor task Signed-off-by: Future Outlier --- flytekit/__init__.py | 2 +- tests/flytekit/unit/extend/test_agent.py | 8 ++--- .../unit/extend/test_task_executor.py | 36 +++++++++++++++++-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/flytekit/__init__.py 
b/flytekit/__init__.py index 04bf9d19a8..7235506412 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -228,6 +228,7 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck +from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.image_spec import ImageSpec from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -237,7 +238,6 @@ from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType from flytekit.sensor.sensor_engine import SensorEngine -from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.types import directory, file, iterator from flytekit.types.structured.structured_dataset import ( StructuredDataset, diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index cba4b25f68..4133215fd7 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -199,12 +199,8 @@ async def run_agent_server(): async_create_request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) - do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), template=dummy_template.to_flyte_idl() - ) - async_do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), template=async_dummy_template.to_flyte_idl() - ) + do_request = DoTaskRequest(inputs=task_inputs.to_flyte_idl(), template=dummy_template.to_flyte_idl()) + async_do_request = DoTaskRequest(inputs=task_inputs.to_flyte_idl(), template=async_dummy_template.to_flyte_idl()) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") diff --git a/tests/flytekit/unit/extend/test_task_executor.py 
b/tests/flytekit/unit/extend/test_task_executor.py index 51213cc4ff..702eede92b 100644 --- a/tests/flytekit/unit/extend/test_task_executor.py +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -3,17 +3,32 @@ import grpc import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from flytekit import FlyteContextManager from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, ExternalApiTask from flytekit.core.interface import Interface, transform_interface_to_typed_interface +from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.models import literals +from flytekit.models.literals import LiteralMap from tests.flytekit.unit.extend.test_agent import get_task_template class MockExternalApiTask(ExternalApiTask): - async def do(self, input: str, **kwargs) -> str: - return input + async def do(self, input: str, **kwargs) -> DoTaskResponse: + ctx = FlyteContextManager.current_context() + outputs = LiteralMap( + { + "o0": TypeEngine.to_literal( + ctx, + input, + type(input), + TypeEngine.to_literal_type(type(input)), + ) + } + ).to_flyte_idl() + return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs, message=input)) @pytest.mark.asyncio @@ -42,4 +57,19 @@ async def test_task_executor_engine(): agent = AgentRegistry.get_agent("api_task") res = await agent.async_do(ctx, tmp, task_inputs) - assert res == input + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs + == literals.LiteralMap( + { + "o0": literals.Literal( + scalar=literals.Scalar( + primitive=literals.Primitive( + string_value=input, + ) + ) + ) + } + ).to_flyte_idl() + ) + assert res.resource.message == input From a660fc1bae45f1da314eb1d87389d8e7701d6a92 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 20 Oct 2023 22:47:21 +0800 Subject: [PATCH 32/64] FLYTE_OPENAI_ACCESS_TOKEN Signed-off-by: Future Outlier --- 
plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index fd9d687a2b..6c29702c9f 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -67,6 +67,6 @@ async def do( def get_header(openai_organization: str): return { "OpenAI-Organization": openai_organization, - "Authorization": f"Bearer {get_agent_secret(secret_key='OPENAI_ACCESS_TOKEN')}", + "Authorization": f"Bearer {get_agent_secret(secret_key='FLYTE_OPENAI_ACCESS_TOKEN')}", "content-type": "application/json", } From a3b0ecc6e21d03a184e402a47984216686930e20 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Sat, 28 Oct 2023 12:38:12 +0800 Subject: [PATCH 33/64] add requirements Signed-off-by: Future Outlier --- .../flytekit-openai-chatgpt/requirements.txt | 362 ++++++++++++++++++ 1 file changed, 362 insertions(+) diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt index e69de29bb2..d3544bb32d 100644 --- a/plugins/flytekit-openai-chatgpt/requirements.txt +++ b/plugins/flytekit-openai-chatgpt/requirements.txt @@ -0,0 +1,362 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-chatgpt + # via -r requirements.in +adlfs==2023.10.0 + # via flytekit +aiobotocore==2.5.4 + # via s3fs +aiohttp==3.8.6 + # via + # adlfs + # aiobotocore + # flytekitplugins-chatgpt + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp +arrow==1.3.0 + # via cookiecutter +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.29.5 + # via + # adlfs + # azure-identity + # azure-storage-blob 
+azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.15.0 + # via adlfs +azure-storage-blob==12.18.3 + # via adlfs +binaryornot==0.4.4 + # via cookiecutter +botocore==1.31.17 + # via aiobotocore +cachetools==5.3.2 + # via google-auth +certifi==2023.7.22 + # via + # kubernetes + # requests +cffi==1.16.0 + # via + # azure-datalake-store + # cryptography +chardet==5.2.0 + # via binaryornot +charset-normalizer==3.3.1 + # via + # aiohttp + # requests +click==8.1.7 + # via + # cookiecutter + # flytekit + # rich-click +cloudpickle==3.0.0 + # via flytekit +cookiecutter==2.4.0 + # via flytekit +croniter==2.0.1 + # via flytekit +cryptography==41.0.5 + # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl + # secretstorage +dataclasses-json==0.5.9 + # via flytekit +decorator==5.1.1 + # via gcsfs +deprecated==1.2.14 + # via flytekit +diskcache==5.6.3 + # via flytekit +docker==6.1.3 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.10.0 + # via flytekit +flytekit==1.10.1b0 + # via flytekitplugins-chatgpt +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +fsspec==2023.9.2 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.9.2 + # via flytekit +gitdb==4.0.11 + # via gitpython +gitpython==3.1.40 + # via flytekit +google-api-core==2.12.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.23.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.1.0 + # via gcsfs +google-cloud-core==2.3.3 + # via google-cloud-storage +google-cloud-storage==2.12.0 + # via gcsfs +google-crc32c==1.5.0 + # via + # google-cloud-storage + # google-resumable-media +google-resumable-media==2.6.0 + # via google-cloud-storage +googleapis-common-protos==1.61.0 + # via + # flyteidl + # flytekit + # google-api-core + # grpcio-status +grpcio==1.59.0 + # via + # flytekit + # 
grpcio-status +grpcio-status==1.59.0 + # via flytekit +idna==3.4 + # via + # requests + # yarl +importlib-metadata==6.8.0 + # via + # flytekit + # keyring +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.3.0 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.1.2 + # via cookiecutter +jmespath==1.0.1 + # via botocore +joblib==1.3.2 + # via flytekit +jsonpickle==3.0.2 + # via flytekit +keyring==24.2.0 + # via flytekit +kubernetes==28.1.0 + # via flytekit +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 + # via jinja2 +marshmallow==3.20.1 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via + # dataclasses-json + # flytekit +marshmallow-jsonschema==0.13.0 + # via flytekit +mashumaro==3.10 + # via flytekit +mdurl==0.1.2 + # via markdown-it-py +more-itertools==10.1.0 + # via jaraco-classes +msal==1.24.1 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.4.0 + # via flytekit +numpy==1.26.1 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via + # kubernetes + # requests-oauthlib +packaging==23.2 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +portalocker==2.8.2 + # via msal-extensions +protobuf==4.24.4 + # via + # flyteidl + # google-api-core + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +pyarrow==10.0.1 + # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pygments==2.16.1 + # via rich +pyjwt[crypto]==2.8.0 + # via + # msal + # pyjwt +pyopenssl==23.3.0 + # via flytekit +python-dateutil==2.8.2 + # via + # arrow + # botocore + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via 
flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3.post1 + # via + # croniter + # flytekit + # pandas +pyyaml==6.0.1 + # via + # cookiecutter + # flytekit + # kubernetes +regex==2023.10.3 + # via docker-image-py +requests==2.31.0 + # via + # azure-core + # azure-datalake-store + # cookiecutter + # docker + # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +rich==13.6.0 + # via + # cookiecutter + # flytekit + # rich-click +rich-click==1.7.0 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.9.2 + # via flytekit +secretstorage==3.3.3 + # via keyring +six==1.16.0 + # via + # azure-core + # isodate + # kubernetes + # python-dateutil +smmap==5.0.1 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-python-dateutil==2.8.19.14 + # via arrow +typing-extensions==4.8.0 + # via + # aioitertools + # azure-core + # azure-storage-blob + # flytekit + # mashumaro + # rich-click + # typing-inspect +typing-inspect==0.9.0 + # via dataclasses-json +urllib3==1.26.18 + # via + # botocore + # docker + # flytekit + # kubernetes + # requests +websocket-client==1.6.4 + # via + # docker + # kubernetes +wheel==0.41.2 + # via flytekit +wrapt==1.15.0 + # via + # aiobotocore + # deprecated + # flytekit +yarl==1.9.2 + # via aiohttp +zipp==3.17.0 + # via importlib-metadata From f785a19ad09a7ca36937683da779919e0349dfee Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 1 Nov 2023 20:41:21 +0800 Subject: [PATCH 34/64] use openai async create function and change test Signed-off-by: Future Outlier --- .../dev-requirements.in | 1 - .../dev-requirements.txt | 24 ------- .../flytekitplugins/chatgpt/task.py | 31 +++------ .../flytekit-openai-chatgpt/requirements.txt | 13 ++-- plugins/flytekit-openai-chatgpt/setup.py | 4 
+- .../tests/test_chatgpt_task.py | 63 +++++-------------- 6 files changed, 35 insertions(+), 101 deletions(-) diff --git a/plugins/flytekit-openai-chatgpt/dev-requirements.in b/plugins/flytekit-openai-chatgpt/dev-requirements.in index 78d0eca127..2d73dba5b4 100644 --- a/plugins/flytekit-openai-chatgpt/dev-requirements.in +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.in @@ -1,2 +1 @@ -aioresponses pytest-asyncio diff --git a/plugins/flytekit-openai-chatgpt/dev-requirements.txt b/plugins/flytekit-openai-chatgpt/dev-requirements.txt index f8255ed178..1c37cda90d 100644 --- a/plugins/flytekit-openai-chatgpt/dev-requirements.txt +++ b/plugins/flytekit-openai-chatgpt/dev-requirements.txt @@ -4,32 +4,10 @@ # # pip-compile dev-requirements.in # -aiohttp==3.8.6 - # via aioresponses -aioresponses==0.7.4 - # via -r dev-requirements.in -aiosignal==1.3.1 - # via aiohttp -async-timeout==4.0.3 - # via aiohttp -attrs==23.1.0 - # via aiohttp -charset-normalizer==3.3.0 - # via aiohttp exceptiongroup==1.1.3 # via pytest -frozenlist==1.4.0 - # via - # aiohttp - # aiosignal -idna==3.4 - # via yarl iniconfig==2.0.0 # via pytest -multidict==6.0.4 - # via - # aiohttp - # yarl packaging==23.2 # via pytest pluggy==1.3.0 @@ -40,5 +18,3 @@ pytest-asyncio==0.21.1 # via -r dev-requirements.in tomli==2.0.1 # via pytest -yarl==1.9.2 - # via aiohttp diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 6c29702c9f..02cfdf16f7 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,7 +1,6 @@ -import json from typing import Any, Dict -import aiohttp +import openai from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager @@ -26,6 +25,9 @@ def __init__(self, name: str, config: Dict[str, Any], **kwargs): if "chatgpt_conf" not in config: raise 
ValueError("The 'chatgpt_conf' configuration variable is required") + if "model" not in config["chatgpt_conf"]: + raise ValueError("The 'model' configuration variable in 'chatgpt_conf' is required") + self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] @@ -35,19 +37,13 @@ async def do( self, message: str = None, ) -> DoTaskResponse: - self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - openai_url = "https://api.openai.com/v1/chat/completions" - data = json.dumps(self._chatgpt_conf) + openai.organization = self._openai_organization + openai.api_key = get_agent_secret(secret_key="FLYTE_OPENAI_ACCESS_TOKEN") - async with aiohttp.ClientSession() as session: - async with session.post( - url=openai_url, headers=get_header(openai_organization=self._openai_organization), data=data - ) as resp: - if resp.status != 200: - raise Exception(f"Failed to execute chatgpt job with error: {resp.reason}") - response = await resp.json() + self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - message = response["choices"][0]["message"]["content"] + completion = await openai.ChatCompletion.acreate(**self._chatgpt_conf) + message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() outputs = LiteralMap( @@ -60,13 +56,4 @@ async def do( ) } ).to_flyte_idl() - return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) - - -def get_header(openai_organization: str): - return { - "OpenAI-Organization": openai_organization, - "Authorization": f"Bearer {get_agent_secret(secret_key='FLYTE_OPENAI_ACCESS_TOKEN')}", - "content-type": "application/json", - } diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt index d3544bb32d..37662e62f3 100644 --- a/plugins/flytekit-openai-chatgpt/requirements.txt +++ b/plugins/flytekit-openai-chatgpt/requirements.txt @@ -14,8 +14,8 @@ aiohttp==3.8.6 # via # adlfs # 
aiobotocore - # flytekitplugins-chatgpt # gcsfs + # openai # s3fs aioitertools==0.11.0 # via aiobotocore @@ -141,11 +141,11 @@ googleapis-common-protos==1.61.0 # flytekit # google-api-core # grpcio-status -grpcio==1.59.0 +grpcio==1.59.2 # via # flytekit # grpcio-status -grpcio-status==1.59.0 +grpcio-status==1.59.2 # via flytekit idna==3.4 # via @@ -220,6 +220,8 @@ oauthlib==3.2.2 # via # kubernetes # requests-oauthlib +openai==0.28.1 + # via flytekitplugins-chatgpt packaging==23.2 # via # docker @@ -293,6 +295,7 @@ requests==2.31.0 # google-cloud-storage # kubernetes # msal + # openai # requests-oauthlib requests-oauthlib==1.3.1 # via @@ -325,6 +328,8 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify +tqdm==4.66.1 + # via openai types-python-dateutil==2.8.19.14 # via arrow typing-extensions==4.8.0 @@ -349,7 +354,7 @@ websocket-client==1.6.4 # via # docker # kubernetes -wheel==0.41.2 +wheel==0.41.3 # via flytekit wrapt==1.15.0 # via diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py index 401b0df201..f395dc5c80 100644 --- a/plugins/flytekit-openai-chatgpt/setup.py +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "aiohttp"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "openai>=0.28.1"] __version__ = "0.0.0+develop" @@ -13,7 +13,7 @@ version=__version__, author="flyteorg", author_email="admin@flyte.org", - description="This package holds the Bigquery plugins for flytekit", + description="This package holds the ChatGPT plugins for flytekit", namespace_packages=["flytekitplugins"], packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, diff --git a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py index 008f14d1bc..c7ba185b79 100644 --- a/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py 
+++ b/plugins/flytekit-openai-chatgpt/tests/test_chatgpt_task.py @@ -1,20 +1,23 @@ from unittest import mock -import aioresponses import pytest -from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from flyteidl.admin.agent_pb2 import SUCCEEDED from flytekitplugins.chatgpt import ChatGPTTask -from flytekit import FlyteContextManager -from flytekit.core.type_engine import TypeEngine -from flytekit.models.literals import LiteralMap + +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response @pytest.mark.asyncio -async def test_chatgpt_task(): +async def test_chatgpt_task_do(): message = "TEST MESSAGE" - response_message = "Hello! How can I assist you today?" organization = "TEST ORGANIZATION" + chatgpt_job = ChatGPTTask( name="chatgpt", config={ @@ -26,46 +29,10 @@ async def test_chatgpt_task(): }, }, ) - ctx = FlyteContextManager.current_context() - - assert chatgpt_job._chatgpt_conf == { - "model": "gpt-3.5-turbo", - "messages": [{"role": "user", "content": message}], - "temperature": 0.7, - } - assert chatgpt_job._openai_organization == organization - - mocked_token = "mocked_chatgpt_token" - mocked_context = mock.patch("flytekit.current_context", autospec=True).start() - mocked_context.return_value.secrets.get.return_value = mocked_token - openai_url = "https://api.openai.com/v1/chat/completions" - mock_do_response = { - "id": "chatcmpl-8AJDGV3GdDsTcdFJfc68OxuLqIJZr", - "object": "chat.completion", - "created": 1697467826, - "model": "gpt-3.5-turbo-0613", - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": response_message}, "finish_reason": "stop"} - ], - "usage": {"prompt_tokens": 8, "completion_tokens": 9, "total_tokens": 17}, - } - with aioresponses.aioresponses() as mocked: - mocked.post(openai_url, status=200, payload=mock_do_response) - res = 
await chatgpt_job.do(message=message) - assert res.resource.state == SUCCEEDED - assert ( - res.resource.outputs - == LiteralMap( - { - "o0": TypeEngine.to_literal( - ctx, - response_message, - type(response_message), - TypeEngine.to_literal_type(type(response_message)), - ) - } - ).to_flyte_idl() - ) + with mock.patch("openai.ChatCompletion.acreate", new=mock_acreate): + with mock.patch("flytekit.extend.backend.base_agent.get_agent_secret", return_value="mocked_secret"): + response = await chatgpt_job.do(message=message) - mock.patch.stopall() + assert response.resource.state == SUCCEEDED + assert "mocked_message" in str(response.resource.outputs) From 8892dee66d7eca0b883c9cb789935c45bdd1397e Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 3 Nov 2023 23:43:15 +0800 Subject: [PATCH 35/64] change the runtime flavor by specifying str value Signed-off-by: Future Outlier --- flytekit/core/base_task.py | 14 +++++++------- flytekit/core/external_api_task.py | 4 ++-- tests/flytekit/unit/core/test_task_metadata.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 9262987e72..5027012c14 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -105,7 +105,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None - use_sync_plugin: bool = False + runtime_flavor: str = "python" def __post_init__(self): if self.timeout: @@ -133,7 +133,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: runtime=_task_model.RuntimeMetadata( _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - "sync_plugin" if self.use_sync_plugin else "python", + self.runtime_flavor, ), timeout=self.timeout, retries=self.retry_strategy, @@ -173,13 +173,13 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, - 
use_sync_plugin: bool = False, + runtime_flavor: str = "python", **kwargs, ): self._task_type = task_type self._name = name self._interface = interface - self._metadata = metadata if metadata else TaskMetadata(use_sync_plugin=use_sync_plugin) + self._metadata = metadata if metadata else TaskMetadata(runtime_flavor=runtime_flavor) self._task_type_version = task_type_version self._security_ctx = security_ctx self._docs = docs @@ -423,7 +423,7 @@ def __init__( environment: Optional[Dict[str, str]] = None, disable_deck: Optional[bool] = None, enable_deck: Optional[bool] = None, - use_sync_plugin: bool = False, + runtime_flavor: str = "python", **kwargs, ): """ @@ -439,13 +439,13 @@ def __init__( execution of the task. Supplied as a dictionary of key/value pairs disable_deck (bool): (deprecated) If true, this task will not output deck html file enable_deck (bool): If true, this task will output deck html file - use_sync_plugin (bool): If true, this task will invoke sync plugin in flytepropeller and flyteplugin + runtime_flavor (str): default is "python", we can set it to "sync_plugin" for flytepropeller to execute sync plugin task """ super().__init__( task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), - use_sync_plugin=use_sync_plugin, + runtime_flavor=runtime_flavor, **kwargs, ) self._python_interface = interface if interface else Interface() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 2799f3ae74..a0d0173590 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -17,7 +17,7 @@ TASK_NAME = "task_name" TASK_CONFIG_PKL = "task_config_pkl" TASK_TYPE = "api_task" -USE_SYNC_PLUGIN = "use_sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task +USE_SYNC_PLUGIN = "sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task class ExternalApiTask(AsyncAgentExecutorMixin, 
PythonTask): @@ -49,7 +49,7 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), - use_sync_plugin=True, + runtime_flavor=USE_SYNC_PLUGIN, **kwargs, ) diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 003753dbfc..73fd5eef55 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -41,7 +41,7 @@ def test_to_taskmetadata_model(): retries=3, timeout=3600, pod_template_name="TEST POD TEMPLATE NAME", - use_sync_plugin=True, + runtime_flavor="sync_plugin", ) model = tm.to_taskmetadata_model() From 1810a7be2db199c377aac0c829bd878f160ddffe Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 6 Nov 2023 16:10:38 +0800 Subject: [PATCH 36/64] add depedencies Signed-off-by: Future Outlier --- plugins/flytekit-openai-chatgpt/requirements.txt | 4 +++- plugins/flytekit-openai-chatgpt/setup.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt index 37662e62f3..3afcb68d04 100644 --- a/plugins/flytekit-openai-chatgpt/requirements.txt +++ b/plugins/flytekit-openai-chatgpt/requirements.txt @@ -92,7 +92,9 @@ docker-image-py==0.1.12 docstring-parser==0.15 # via flytekit flyteidl==1.10.0 - # via flytekit + # via + # flytekit + # flytekitplugins-chatgpt flytekit==1.10.1b0 # via flytekitplugins-chatgpt frozenlist==1.4.0 diff --git a/plugins/flytekit-openai-chatgpt/setup.py b/plugins/flytekit-openai-chatgpt/setup.py index f395dc5c80..ba27a45a4b 100644 --- a/plugins/flytekit-openai-chatgpt/setup.py +++ b/plugins/flytekit-openai-chatgpt/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "openai>=0.28.1"] +plugin_requires = ["flytekit>=1.10.0", "openai>=0.28.1", "flyteidl>=1.10.0"] __version__ = "0.0.0+develop" From 
8d5bb61a57aa83113aa139792b4442e86d378de6 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 7 Nov 2023 22:29:39 +0800 Subject: [PATCH 37/64] add timeout seconds in ChatGPT Signed-off-by: Future Outlier --- plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 02cfdf16f7..00cd380d1f 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -9,6 +9,7 @@ from flytekit.extend.backend.base_agent import get_agent_secret from flytekit.models.literals import LiteralMap +TIMEOUT_SECONDS = 10 class ChatGPTTask(ExternalApiTask): """ @@ -41,6 +42,7 @@ async def do( openai.api_key = get_agent_secret(secret_key="FLYTE_OPENAI_ACCESS_TOKEN") self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] + self._chatgpt_conf["timeout"] = TIMEOUT_SECONDS completion = await openai.ChatCompletion.acreate(**self._chatgpt_conf) message = completion.choices[0].message.content From fd8dd5d91de385133c5e2f044d4883f1bf0bd9c4 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 8 Nov 2023 12:30:31 +0800 Subject: [PATCH 38/64] change await place in async do task function Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index e1d2de37a2..e91e0bb619 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -247,10 +247,10 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): inputs = self.get_input_literal_map(inputs) if self._agent.asynchronous: - res = 
self._agent.async_do(self._grpc_ctx, task_template, inputs) + res = await self._agent.async_do(self._grpc_ctx, task_template, inputs) else: res = self._agent.do(self._grpc_ctx, task_template, inputs) - return await res + return res def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: if self._agent.asynchronous: From e6b0ba0d2f9d178fdc44e6cf93b98d8d0036fd25 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 8 Nov 2023 00:13:34 +0800 Subject: [PATCH 39/64] fix circular import Signed-off-by: Future Outlier --- flytekit/__init__.py | 3 +- flytekit/extend/backend/base_agent.py | 4 +- .../flytekitplugins/chatgpt/task.py | 1 + .../flytekit-openai-chatgpt/requirements.in | 2 - .../flytekit-openai-chatgpt/requirements.txt | 369 ------------------ 5 files changed, 6 insertions(+), 373 deletions(-) delete mode 100644 plugins/flytekit-openai-chatgpt/requirements.in delete mode 100644 plugins/flytekit-openai-chatgpt/requirements.txt diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 7235506412..5d634d9801 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -228,7 +228,6 @@ from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck -from flytekit.extend.backend.task_executor import TaskExecutor from flytekit.image_spec import ImageSpec from flytekit.loggers import LOGGING_RICH_FMT_ENV_VAR, logger from flytekit.models.common import Annotations, AuthRole, Labels @@ -246,6 +245,8 @@ StructuredDatasetType, ) +from flytekit.extend.backend.task_executor import TaskExecutor # isort:skip. This is for circular import avoidance. 
+ __version__ = "0.0.0+develop" diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index e91e0bb619..d2ccc68c20 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -23,7 +23,7 @@ from flyteidl.core.tasks_pb2 import TaskTemplate import flytekit -from flytekit import FlyteContext +from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.type_engine import TypeEngine @@ -144,6 +144,7 @@ def register(agent: AgentBase): if agent.task_type in AgentRegistry._REGISTRY: raise ValueError(f"Duplicate agent for task type {agent.task_type}") AgentRegistry._REGISTRY[agent.task_type] = agent + logger.info(f"Registering an agent for task type {agent.task_type}") @staticmethod def get_agent(task_type: str) -> typing.Optional[AgentBase]: @@ -242,6 +243,7 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: else: res = self._agent.get(self._grpc_ctx, resource_meta) state = res.resource.state + logger.info(f"Task state: {state}") return res async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 00cd380d1f..faec04957b 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -11,6 +11,7 @@ TIMEOUT_SECONDS = 10 + class ChatGPTTask(ExternalApiTask): """ This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. diff --git a/plugins/flytekit-openai-chatgpt/requirements.in b/plugins/flytekit-openai-chatgpt/requirements.in deleted file mode 100644 index 35f1aae56e..0000000000 --- a/plugins/flytekit-openai-chatgpt/requirements.in +++ /dev/null @@ -1,2 +0,0 @@ -. 
--e file:.#egg=flytekitplugins-chatgpt diff --git a/plugins/flytekit-openai-chatgpt/requirements.txt b/plugins/flytekit-openai-chatgpt/requirements.txt deleted file mode 100644 index 3afcb68d04..0000000000 --- a/plugins/flytekit-openai-chatgpt/requirements.txt +++ /dev/null @@ -1,369 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile requirements.in -# --e file:.#egg=flytekitplugins-chatgpt - # via -r requirements.in -adlfs==2023.10.0 - # via flytekit -aiobotocore==2.5.4 - # via s3fs -aiohttp==3.8.6 - # via - # adlfs - # aiobotocore - # gcsfs - # openai - # s3fs -aioitertools==0.11.0 - # via aiobotocore -aiosignal==1.3.1 - # via aiohttp -arrow==1.3.0 - # via cookiecutter -async-timeout==4.0.3 - # via aiohttp -attrs==23.1.0 - # via aiohttp -azure-core==1.29.5 - # via - # adlfs - # azure-identity - # azure-storage-blob -azure-datalake-store==0.0.53 - # via adlfs -azure-identity==1.15.0 - # via adlfs -azure-storage-blob==12.18.3 - # via adlfs -binaryornot==0.4.4 - # via cookiecutter -botocore==1.31.17 - # via aiobotocore -cachetools==5.3.2 - # via google-auth -certifi==2023.7.22 - # via - # kubernetes - # requests -cffi==1.16.0 - # via - # azure-datalake-store - # cryptography -chardet==5.2.0 - # via binaryornot -charset-normalizer==3.3.1 - # via - # aiohttp - # requests -click==8.1.7 - # via - # cookiecutter - # flytekit - # rich-click -cloudpickle==3.0.0 - # via flytekit -cookiecutter==2.4.0 - # via flytekit -croniter==2.0.1 - # via flytekit -cryptography==41.0.5 - # via - # azure-identity - # azure-storage-blob - # msal - # pyjwt - # pyopenssl - # secretstorage -dataclasses-json==0.5.9 - # via flytekit -decorator==5.1.1 - # via gcsfs -deprecated==1.2.14 - # via flytekit -diskcache==5.6.3 - # via flytekit -docker==6.1.3 - # via flytekit -docker-image-py==0.1.12 - # via flytekit -docstring-parser==0.15 - # via flytekit -flyteidl==1.10.0 - # via - # flytekit - # flytekitplugins-chatgpt 
-flytekit==1.10.1b0 - # via flytekitplugins-chatgpt -frozenlist==1.4.0 - # via - # aiohttp - # aiosignal -fsspec==2023.9.2 - # via - # adlfs - # flytekit - # gcsfs - # s3fs -gcsfs==2023.9.2 - # via flytekit -gitdb==4.0.11 - # via gitpython -gitpython==3.1.40 - # via flytekit -google-api-core==2.12.0 - # via - # google-cloud-core - # google-cloud-storage -google-auth==2.23.3 - # via - # gcsfs - # google-api-core - # google-auth-oauthlib - # google-cloud-core - # google-cloud-storage - # kubernetes -google-auth-oauthlib==1.1.0 - # via gcsfs -google-cloud-core==2.3.3 - # via google-cloud-storage -google-cloud-storage==2.12.0 - # via gcsfs -google-crc32c==1.5.0 - # via - # google-cloud-storage - # google-resumable-media -google-resumable-media==2.6.0 - # via google-cloud-storage -googleapis-common-protos==1.61.0 - # via - # flyteidl - # flytekit - # google-api-core - # grpcio-status -grpcio==1.59.2 - # via - # flytekit - # grpcio-status -grpcio-status==1.59.2 - # via flytekit -idna==3.4 - # via - # requests - # yarl -importlib-metadata==6.8.0 - # via - # flytekit - # keyring -isodate==0.6.1 - # via azure-storage-blob -jaraco-classes==3.3.0 - # via keyring -jeepney==0.8.0 - # via - # keyring - # secretstorage -jinja2==3.1.2 - # via cookiecutter -jmespath==1.0.1 - # via botocore -joblib==1.3.2 - # via flytekit -jsonpickle==3.0.2 - # via flytekit -keyring==24.2.0 - # via flytekit -kubernetes==28.1.0 - # via flytekit -markdown-it-py==3.0.0 - # via rich -markupsafe==2.1.3 - # via jinja2 -marshmallow==3.20.1 - # via - # dataclasses-json - # marshmallow-enum - # marshmallow-jsonschema -marshmallow-enum==1.5.1 - # via - # dataclasses-json - # flytekit -marshmallow-jsonschema==0.13.0 - # via flytekit -mashumaro==3.10 - # via flytekit -mdurl==0.1.2 - # via markdown-it-py -more-itertools==10.1.0 - # via jaraco-classes -msal==1.24.1 - # via - # azure-datalake-store - # azure-identity - # msal-extensions -msal-extensions==1.0.0 - # via azure-identity -multidict==6.0.4 - # via - # 
aiohttp - # yarl -mypy-extensions==1.0.0 - # via typing-inspect -natsort==8.4.0 - # via flytekit -numpy==1.26.1 - # via - # flytekit - # pandas - # pyarrow -oauthlib==3.2.2 - # via - # kubernetes - # requests-oauthlib -openai==0.28.1 - # via flytekitplugins-chatgpt -packaging==23.2 - # via - # docker - # marshmallow -pandas==1.5.3 - # via flytekit -portalocker==2.8.2 - # via msal-extensions -protobuf==4.24.4 - # via - # flyteidl - # google-api-core - # googleapis-common-protos - # grpcio-status - # protoc-gen-swagger -protoc-gen-swagger==0.1.0 - # via flyteidl -pyarrow==10.0.1 - # via flytekit -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycparser==2.21 - # via cffi -pygments==2.16.1 - # via rich -pyjwt[crypto]==2.8.0 - # via - # msal - # pyjwt -pyopenssl==23.3.0 - # via flytekit -python-dateutil==2.8.2 - # via - # arrow - # botocore - # croniter - # flytekit - # kubernetes - # pandas -python-json-logger==2.0.7 - # via flytekit -python-slugify==8.0.1 - # via cookiecutter -pytimeparse==1.1.8 - # via flytekit -pytz==2023.3.post1 - # via - # croniter - # flytekit - # pandas -pyyaml==6.0.1 - # via - # cookiecutter - # flytekit - # kubernetes -regex==2023.10.3 - # via docker-image-py -requests==2.31.0 - # via - # azure-core - # azure-datalake-store - # cookiecutter - # docker - # flytekit - # gcsfs - # google-api-core - # google-cloud-storage - # kubernetes - # msal - # openai - # requests-oauthlib -requests-oauthlib==1.3.1 - # via - # google-auth-oauthlib - # kubernetes -rich==13.6.0 - # via - # cookiecutter - # flytekit - # rich-click -rich-click==1.7.0 - # via flytekit -rsa==4.9 - # via google-auth -s3fs==2023.9.2 - # via flytekit -secretstorage==3.3.3 - # via keyring -six==1.16.0 - # via - # azure-core - # isodate - # kubernetes - # python-dateutil -smmap==5.0.1 - # via gitdb -sortedcontainers==2.4.0 - # via flytekit -statsd==3.3.0 - # via flytekit -text-unidecode==1.3 - # via python-slugify -tqdm==4.66.1 - # via 
openai -types-python-dateutil==2.8.19.14 - # via arrow -typing-extensions==4.8.0 - # via - # aioitertools - # azure-core - # azure-storage-blob - # flytekit - # mashumaro - # rich-click - # typing-inspect -typing-inspect==0.9.0 - # via dataclasses-json -urllib3==1.26.18 - # via - # botocore - # docker - # flytekit - # kubernetes - # requests -websocket-client==1.6.4 - # via - # docker - # kubernetes -wheel==0.41.3 - # via flytekit -wrapt==1.15.0 - # via - # aiobotocore - # deprecated - # flytekit -yarl==1.9.2 - # via aiohttp -zipp==3.17.0 - # via importlib-metadata From dc87c81ffeec0bbec944eaf21293850998929ad1 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 8 Nov 2023 18:13:05 +0800 Subject: [PATCH 40/64] Update flytekit/extend/backend/base_agent.py Co-authored-by: Kevin Su Signed-off-by: Future Outlier --- flytekit/extend/backend/base_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index d2ccc68c20..169ff5f820 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -158,7 +158,7 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - # timedout is weird but correct, you can refer here: https://github.com/databricks/databricks-sdk-py/pull/407 + # timedout is the state of Databricks job. 
https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate if state in ["failed", "timeout", "timedout", "canceled"]: return RETRYABLE_FAILURE elif state in ["done", "succeeded", "success"]: From 8ea1a9d4253fdbba8214e6c81a45908595a24f56 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 8 Nov 2023 18:13:40 +0800 Subject: [PATCH 41/64] Update tests/flytekit/unit/core/test_task_metadata.py Co-authored-by: Kevin Su Signed-off-by: Future Outlier --- tests/flytekit/unit/core/test_task_metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 73fd5eef55..1ab695d2a3 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -31,7 +31,7 @@ def test_retry_strategy(): assert tm.retry_strategy.retries == 5 -def test_to_taskmetadata_model(): +def test_to_task_metadata_model(): tm = TaskMetadata( cache=True, cache_serialize=True, From 42b8d9008e83a57efc2577e94689e866ce8ea4e6 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 9 Nov 2023 18:53:07 +0800 Subject: [PATCH 42/64] add runtime flavor variavle in all async plugin and sync plugin test Signed-off-by: Future Outlier --- flytekit/core/base_sql_task.py | 2 ++ flytekit/core/base_task.py | 8 ++++---- flytekit/core/external_api_task.py | 5 ++--- flytekit/core/python_auto_container.py | 2 ++ flytekit/core/python_function_task.py | 2 ++ flytekit/extend/backend/base_agent.py | 7 ++++++- flytekit/extend/backend/task_executor.py | 2 +- flytekit/models/task.py | 4 ++-- flytekit/sensor/base_sensor.py | 3 ++- flytekit/tools/translator.py | 5 +---- plugins/flytekit-airflow/flytekitplugins/airflow/task.py | 3 ++- .../flytekit-aws-athena/flytekitplugins/athena/task.py | 2 ++ .../flytekit-aws-batch/flytekitplugins/awsbatch/task.py | 7 ++++++- .../flytekit-bigquery/flytekitplugins/bigquery/task.py | 3 ++- 
plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py | 2 ++ .../flytekit-snowflake/flytekitplugins/snowflake/task.py | 3 ++- plugins/flytekit-spark/flytekitplugins/spark/task.py | 2 ++ tests/flytekit/unit/core/test_task_metadata.py | 9 +++++++++ tests/flytekit/unit/extend/test_agent.py | 2 ++ 19 files changed, 53 insertions(+), 20 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 30b73223a9..0d6295db88 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -27,6 +27,7 @@ def __init__( inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, outputs: Optional[Dict[str, Type]] = None, + runtime_flavor: Optional[str] = None, **kwargs, ): """ @@ -39,6 +40,7 @@ def __init__( interface=Interface(inputs=inputs or {}, outputs=outputs or {}), metadata=metadata, task_config=task_config, + runtime_flavor=runtime_flavor, **kwargs, ) self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 5027012c14..50d29f2350 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -105,7 +105,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None - runtime_flavor: str = "python" + runtime_flavor: Optional[str] = None def __post_init__(self): if self.timeout: @@ -173,7 +173,7 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, - runtime_flavor: str = "python", + runtime_flavor: Optional[str] = None, **kwargs, ): self._task_type = task_type @@ -423,7 +423,7 @@ def __init__( environment: Optional[Dict[str, str]] = None, disable_deck: Optional[bool] = None, enable_deck: Optional[bool] = None, - runtime_flavor: str = "python", + runtime_flavor: Optional[str] = None, **kwargs, ): 
""" @@ -439,7 +439,7 @@ def __init__( execution of the task. Supplied as a dictionary of key/value pairs disable_deck (bool): (deprecated) If true, this task will not output deck html file enable_deck (bool): If true, this task will output deck html file - runtime_flavor (str): default is "python", we can set it to "sync_plugin" for flytepropeller to execute sync plugin task + runtime_flavor (Optional[str]): we can set it to "sync_plugin" or "async_plugin" for flytepropeller to execute plugin task """ super().__init__( task_type=task_type, diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index a0d0173590..aa11accf48 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -10,14 +10,13 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import SYNC_PLUGIN, AsyncAgentExecutorMixin T = TypeVar("T") TASK_MODULE = "task_module" TASK_NAME = "task_name" TASK_CONFIG_PKL = "task_config_pkl" TASK_TYPE = "api_task" -USE_SYNC_PLUGIN = "sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): @@ -49,7 +48,7 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), - runtime_flavor=USE_SYNC_PLUGIN, + runtime_flavor=SYNC_PLUGIN, **kwargs, ) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 1ad1de0216..87fc6f7088 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -48,6 +48,7 @@ def __init__( pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, + runtime_flavor: Optional[str] = None, 
**kwargs, ): """ @@ -92,6 +93,7 @@ def __init__( name=name, task_config=task_config, security_ctx=sec_ctx, + runtime_flavor=runtime_flavor, **kwargs, ) self._container_image = container_image diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index e1e80a4227..cba7f91629 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -102,6 +102,7 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, + runtime_flavor: Optional[str] = None, **kwargs, ): """ @@ -124,6 +125,7 @@ def __init__( interface=mutated_interface, task_config=task_config, task_resolver=task_resolver, + runtime_flavor=runtime_flavor, **kwargs, ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 169ff5f820..c24c5c80b9 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -31,6 +31,9 @@ from flytekit.exceptions.user import FlyteUserException from flytekit.models.literals import LiteralMap +SYNC_PLUGIN = "sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task +ASYNC_PLUGIN = "async_plugin" # Indicates that the async plugin in FlytePropeller should be used to run this task + class AgentBase(ABC): """ @@ -199,7 +202,7 @@ class AsyncAgentExecutorMixin: _grpc_ctx: grpc.ServicerContext = _get_grpc_context() def execute(self, **kwargs) -> typing.Any: - from flytekit.extend.backend.task_executor import TaskExecutor + from flytekit.extend.backend.task_executor import TaskExecutor # This is for circular import avoidance. 
from flytekit.tools.translator import get_serializable self._entity = typing.cast(PythonTask, self) @@ -263,6 +266,8 @@ def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> sys.exit(1) def get_input_literal_map(self, inputs: typing.Dict[str, typing.Any] = None) -> typing.Optional[LiteralMap]: + if inputs is None: + return None # Convert python inputs to literals literals = {} for k, v in inputs.items(): diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index 7608c08878..a36c0dab9b 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -15,7 +15,7 @@ T = typing.TypeVar("T") - +# TODO: ADD COMMENTS LIKE SENSOR ENGINE class TaskExecutor(AgentBase): def __init__(self): super().__init__(task_type=TASK_TYPE, asynchronous=True) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 48a8abfde1..4d8d386c86 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -121,7 +121,7 @@ def __init__(self, type, version, flavor): :param int type: Enum type from RuntimeMetadata.RuntimeType :param Text version: Version string for SDK version. Can be used for metrics or managing breaking changes in Admin or Propeller - :param Text flavor: Optional extra information about runtime environment (e.g. Python, GoLang, etc.) + :param Text flavor: Optional extra information about the plugin type (e.g. async plugin, sync plugin... etc.). """ self._type = type self._version = version @@ -146,7 +146,7 @@ def version(self): @property def flavor(self): """ - Optional extra information about runtime environment (e.g. Python, GoLang, etc.) + Optional extra information about the plugin type (e.g. async plugin, sync plugin... etc.). 
:rtype: Text """ return self._flavor diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 0e40055ea5..23f2ecb955 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -9,7 +9,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin T = TypeVar("T") SENSOR_MODULE = "sensor_module" @@ -44,6 +44,7 @@ def __init__( name=name, task_config=None, interface=Interface(inputs=inputs), + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) self._sensor_config = sensor_config diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index ef0de87768..cfe43544f3 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Union +from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants from flytekit.core.array_node_map_task import ArrayNodeMapTask @@ -162,8 +163,6 @@ def get_serializable_task( settings: SerializationSettings, entity: FlyteLocalEntity, ) -> TaskSpec: - from flytekit import PythonFunctionTask - task_id = _identifier_model.Identifier( _identifier_model.ResourceType.TASK, settings.project, @@ -732,8 +731,6 @@ def get_serializable( raise Exception(f"Non serializable type found {type(entity)} Entity {entity}") if isinstance(entity, TaskSpec) or isinstance(entity, WorkflowSpec): - from flytekit import SourceCode - # 1. Check if the size of long description exceeds 16KB # 2. 
Extract the repo URL from the git config, and assign it to the link of the source code of the description entity if entity.docs and entity.docs.long_description: diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index a25a46cf1e..ebd1ba43c2 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -12,7 +12,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin @dataclass @@ -39,6 +39,7 @@ def __init__( query_template=query_template, interface=Interface(inputs=inputs or {}), task_type=self._TASK_TYPE, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) diff --git a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py index 1ae47339b3..2efca76e99 100644 --- a/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py +++ b/plugins/flytekit-aws-athena/flytekitplugins/athena/task.py @@ -5,6 +5,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.models.presto import PrestoQuery from flytekit.types.schema import FlyteSchema @@ -65,6 +66,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index e0326f112b..2438613d8a 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ 
b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -8,6 +8,7 @@ from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN @dataclass @@ -41,7 +42,11 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg if task_config is None: task_config = AWSBatchConfig() super(AWSBatchFunctionTask, self).__init__( - task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, **kwargs + task_config=task_config, + task_type=self._AWS_BATCH_TASK_TYPE, + task_function=task_function, + runtime_flavor=ASYNC_PLUGIN, + **kwargs ) self._task_config = task_config diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index bcc707da5a..8268602022 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -7,7 +7,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -63,6 +63,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) self._output_structured_dataset_type = output_structured_dataset_type diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py index 3a61d590d7..d3d682b964 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py @@ -10,6 +10,7 @@ from flytekit.core.python_function_task import 
PythonFunctionTask from flytekit.core.resources import Resources from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.image_spec.image_spec import ImageSpec @@ -40,6 +41,7 @@ def __init__( task_type=self._TASK_TYPE, task_function=task_function, container_image=container_image, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 9ac9980a88..e3eaa9699a 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,7 +3,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -77,6 +77,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 17099350e4..b6c84fa687 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,6 +9,7 @@ from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins +from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -128,6 +129,7 @@ def __init__( task_type=self._SPARK_TASK_TYPE, task_function=task_function, container_image=container_image, + 
runtime_flavor=ASYNC_PLUGIN, **kwargs, ) diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 1ab695d2a3..6214c3f9a6 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -58,3 +58,12 @@ def test_to_task_metadata_model(): assert model.deprecated_error_message == "TEST DEPRECATED ERROR MESSAGE" assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" + + # Since the default value is not "python" anymore, add a test to test the default value + tm = TaskMetadata() + model = tm.to_taskmetadata_model() + assert model.runtime == _task_model.RuntimeMetadata( + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + None, + ) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 4013e025c6..16dbac3da3 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -27,6 +27,7 @@ from flytekit import PythonFunctionTask from flytekit.extend.backend.agent_service import AsyncAgentService from flytekit.extend.backend.base_agent import ( + ASYNC_PLUGIN, AgentBase, AgentRegistry, AsyncAgentExecutorMixin, @@ -160,6 +161,7 @@ class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="dummy", + runtime_flavor=ASYNC_PLUGIN, **kwargs, ) From 14a269820cd0fba162e78e919a01b2440bf2cec4 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Thu, 9 Nov 2023 19:17:57 +0800 Subject: [PATCH 43/64] add TaskExecutor comment Signed-off-by: Future Outlier --- flytekit/extend/backend/task_executor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index a36c0dab9b..e8510362fa 100644 --- a/flytekit/extend/backend/task_executor.py +++ 
b/flytekit/extend/backend/task_executor.py @@ -15,8 +15,16 @@ T = typing.TypeVar("T") -# TODO: ADD COMMENTS LIKE SENSOR ENGINE + class TaskExecutor(AgentBase): + """ + TaskExecutor is an agent responsible for executing external API tasks. + + This class is meant to be subclassed when implementing plugins that require + an external API to perform the task execution. It provides a routing mechanism + to direct the task to the appropriate handler based on the task's specifications. + """ + def __init__(self): super().__init__(task_type=TASK_TYPE, asynchronous=True) From 67582fd333b0cacd93c7f95d57a319d645967ce7 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 10 Nov 2023 11:21:07 +0800 Subject: [PATCH 44/64] change the argument task type used in sensor engine and api_task engine (task executor) Signed-off-by: Future Outlier --- flytekit/sensor/base_sensor.py | 1 + flytekit/sensor/sensor_engine.py | 4 ++-- tests/flytekit/unit/extend/test_task_executor.py | 6 +++--- tests/flytekit/unit/sensor/test_sensor_engine.py | 6 +++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 23f2ecb955..54c1ab3eac 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -15,6 +15,7 @@ SENSOR_MODULE = "sensor_module" SENSOR_NAME = "sensor_name" SENSOR_CONFIG_PKL = "sensor_config_pkl" +SENSOR_TYPE = "sensor" INPUTS = "inputs" diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 79d2e0f4b4..02edea96bb 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -19,7 +19,7 @@ from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate -from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, 
SENSOR_NAME, SENSOR_TYPE T = typing.TypeVar("T") @@ -52,7 +52,7 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None inputs = meta.get(INPUTS, {}) - cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING + cur_state = SUCCEEDED if await sensor_def(SENSOR_TYPE, config=sensor_config).poke(**inputs) else RUNNING return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: diff --git a/tests/flytekit/unit/extend/test_task_executor.py b/tests/flytekit/unit/extend/test_task_executor.py index 702eede92b..b1fa93354c 100644 --- a/tests/flytekit/unit/extend/test_task_executor.py +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -6,7 +6,7 @@ from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource from flytekit import FlyteContextManager -from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, ExternalApiTask +from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, TASK_TYPE, ExternalApiTask from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentRegistry @@ -39,7 +39,7 @@ async def test_task_executor_engine(): inputs=collections.OrderedDict({"input": str, "kwargs": None}), outputs=collections.OrderedDict({"o0": str}), ) - tmp = get_task_template("api_task") + tmp = get_task_template(TASK_TYPE) tmp._custom = { TASK_MODULE: MockExternalApiTask.__module__, TASK_NAME: MockExternalApiTask.__name__, @@ -54,7 +54,7 @@ async def test_task_executor_engine(): ) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("api_task") + agent = AgentRegistry.get_agent(TASK_TYPE) res = await 
agent.async_do(ctx, tmp, task_inputs) assert res.resource.state == SUCCEEDED diff --git a/tests/flytekit/unit/sensor/test_sensor_engine.py b/tests/flytekit/unit/sensor/test_sensor_engine.py index dbb81c3f47..3078f2e6a0 100644 --- a/tests/flytekit/unit/sensor/test_sensor_engine.py +++ b/tests/flytekit/unit/sensor/test_sensor_engine.py @@ -10,7 +10,7 @@ from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.models import literals, types from flytekit.sensor import FileSensor -from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME +from flytekit.sensor.base_sensor import SENSOR_MODULE, SENSOR_NAME, SENSOR_TYPE from tests.flytekit.unit.extend.test_agent import get_task_template @@ -22,7 +22,7 @@ async def test_sensor_engine(): }, {}, ) - tmp = get_task_template("sensor") + tmp = get_task_template(SENSOR_TYPE) tmp._custom = { SENSOR_MODULE: FileSensor.__module__, SENSOR_NAME: FileSensor.__name__, @@ -37,7 +37,7 @@ async def test_sensor_engine(): }, ) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("sensor") + agent = AgentRegistry.get_agent(SENSOR_TYPE) res = await agent.async_create(ctx, "/tmp", tmp, task_inputs) From c273689c64576355093454654806d5f4bb6c8011 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 10 Nov 2023 16:35:02 +0800 Subject: [PATCH 45/64] support more arguments in do task agent Signed-off-by: Future Outlier --- flytekit/extend/backend/agent_service.py | 14 ++++++++++++-- flytekit/extend/backend/base_agent.py | 8 ++++++-- flytekit/extend/backend/task_executor.py | 3 ++- .../flytekitplugins/chatgpt/task.py | 4 ++-- tests/flytekit/unit/extend/test_agent.py | 8 ++++++-- tests/flytekit/unit/extend/test_task_executor.py | 5 +++-- 6 files changed, 31 insertions(+), 11 deletions(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index c22e3e38fd..03f54a25ec 100644 --- a/flytekit/extend/backend/agent_service.py +++ 
b/flytekit/extend/backend/agent_service.py @@ -136,7 +136,17 @@ async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) + logger.info(f"{tmp.type} agent start doing the job") if agent.asynchronous: - return await agent.async_do(context=context, inputs=inputs, task_template=tmp) - return await asyncio.get_running_loop().run_in_executor(None, agent.do, context, "", inputs, tmp) + return await agent.async_do( + context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + ) + return await asyncio.get_running_loop().run_in_executor( + None, + agent.do, + context, + request.output_prefix, + tmp, + inputs, + ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index c24c5c80b9..43a30f7c80 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -88,6 +88,7 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT def do( self, context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -125,6 +126,7 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes async def async_do( self, context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -251,10 +253,12 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): inputs = self.get_input_literal_map(inputs) + output_prefix = self._ctx.file_access.get_random_local_directory() + if self._agent.asynchronous: - res = await self._agent.async_do(self._grpc_ctx, task_template, inputs) + res = await 
self._agent.async_do(self._grpc_ctx, output_prefix, task_template, inputs) else: - res = self._agent.do(self._grpc_ctx, task_template, inputs) + res = self._agent.do(self._grpc_ctx, output_prefix, task_template, inputs) return res def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index e8510362fa..a6357daad4 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -31,8 +31,9 @@ def __init__(self): async def async_do( self, context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, - inputs: Optional[LiteralMap] = None, + inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index faec04957b..197639bc2f 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict import openai @@ -43,9 +44,8 @@ async def do( openai.api_key = get_agent_secret(secret_key="FLYTE_OPENAI_ACCESS_TOKEN") self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - self._chatgpt_conf["timeout"] = TIMEOUT_SECONDS - completion = await openai.ChatCompletion.acreate(**self._chatgpt_conf) + completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 16dbac3da3..c9eeacd2b3 100644 --- 
a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -199,8 +199,12 @@ async def run_agent_server(): async_create_request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) - do_request = DoTaskRequest(inputs=task_inputs.to_flyte_idl(), template=dummy_template.to_flyte_idl()) - async_do_request = DoTaskRequest(inputs=task_inputs.to_flyte_idl(), template=async_dummy_template.to_flyte_idl()) + do_request = DoTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) + async_do_request = DoTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + ) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") diff --git a/tests/flytekit/unit/extend/test_task_executor.py b/tests/flytekit/unit/extend/test_task_executor.py index b1fa93354c..a491023384 100644 --- a/tests/flytekit/unit/extend/test_task_executor.py +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -5,7 +5,7 @@ import pytest from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource -from flytekit import FlyteContextManager +from flytekit import FlyteContext, FlyteContextManager from flytekit.core.external_api_task import TASK_MODULE, TASK_NAME, TASK_TYPE, ExternalApiTask from flytekit.core.interface import Interface, transform_interface_to_typed_interface from flytekit.core.type_engine import TypeEngine @@ -52,11 +52,12 @@ async def test_task_executor_engine(): "input": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(string_value="TASK INPUT"))), }, ) + output_prefix = FlyteContext.current_context().file_access.get_random_local_directory() ctx = MagicMock(spec=grpc.ServicerContext) agent = AgentRegistry.get_agent(TASK_TYPE) - res = await agent.async_do(ctx, tmp, task_inputs) + res = await agent.async_do(ctx, 
output_prefix, tmp, task_inputs) assert res.resource.state == SUCCEEDED assert ( res.resource.outputs From 232c80cbd96fa948228428f7d2aa95793d3f7dd5 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 15 Nov 2023 18:15:13 +0800 Subject: [PATCH 46/64] add annotations Signed-off-by: Future Outlier --- .../flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 197639bc2f..38150cd21e 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -16,6 +16,10 @@ class ChatGPTTask(ExternalApiTask): """ This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. + + Args: + openai_organization: OpenAI Organization. Config string can be found here. https://platform.openai.com/docs/api-reference/organization-optional + chatgpt_conf: ChatGPT job configuration. Config structure can be found here. 
https://platform.openai.com/docs/api-reference/completions/create """ _openai_organization: str = None From f842fdd97c5f4a6005ec8060e2a1bb955c59ff26 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 15 Nov 2023 23:42:38 +0800 Subject: [PATCH 47/64] move argument openai_organization to optional Signed-off-by: Future Outlier --- flytekit/extend/backend/task_executor.py | 1 - .../flytekitplugins/chatgpt/task.py | 13 ++++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index a6357daad4..c7b994ac9b 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -1,6 +1,5 @@ import importlib import typing -from typing import Optional import grpc import jsonpickle diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 38150cd21e..3b5a95d809 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict +from typing import Any, Dict, Optional import openai from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource @@ -18,24 +18,23 @@ class ChatGPTTask(ExternalApiTask): This is the simplest form of a ChatGPTTask Task, you can define the model and the input you want. Args: - openai_organization: OpenAI Organization. Config string can be found here. https://platform.openai.com/docs/api-reference/organization-optional + openai_organization: OpenAI Organization. String can be found here. https://platform.openai.com/docs/api-reference/organization-optional chatgpt_conf: ChatGPT job configuration. Config structure can be found here. 
https://platform.openai.com/docs/api-reference/completions/create """ - _openai_organization: str = None + _openai_organization: Optional[str] = None _chatgpt_conf: Dict[str, Any] = None def __init__(self, name: str, config: Dict[str, Any], **kwargs): - if "openai_organization" not in config: - raise ValueError("The 'openai_organization' configuration variable is required") - if "chatgpt_conf" not in config: raise ValueError("The 'chatgpt_conf' configuration variable is required") if "model" not in config["chatgpt_conf"]: raise ValueError("The 'model' configuration variable in 'chatgpt_conf' is required") - self._openai_organization = config["openai_organization"] + if "openai_organization" in config: + self._openai_organization = config["openai_organization"] + self._chatgpt_conf = config["chatgpt_conf"] super().__init__(name=name, config=config, return_type=str, **kwargs) From 835c48af724910976ac3ebbd56ac6a6b7711354d Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 17 Nov 2023 13:20:05 +0800 Subject: [PATCH 48/64] SyncAgentServiceServicer Signed-off-by: Future Outlier --- flytekit/clis/sdk_in_container/serve.py | 8 +- flytekit/core/external_api_task.py | 4 +- flytekit/extend/backend/agent_service.py | 4 +- flytekit/extend/backend/base_agent.py | 55 +++++++++-- tests/flytekit/unit/extend/test_agent.py | 115 ++++++++++++++--------- 5 files changed, 130 insertions(+), 56 deletions(-) diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 145dc90212..6167664165 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,7 +1,10 @@ from concurrent import futures import click -from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server +from flyteidl.service.agent_pb2_grpc import ( + add_AsyncAgentServiceServicer_to_server, + add_SyncAgentServiceServicer_to_server, +) from grpc import aio _serve_help = """Start a grpc server for the agent 
service.""" @@ -42,7 +45,7 @@ def serve(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") - from flytekit.extend.backend.agent_service import AsyncAgentService + from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService try: from prometheus_client import start_http_server @@ -54,6 +57,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) + add_SyncAgentServiceServicer_to_server(SyncAgentService(), server) server.add_insecure_port(f"[::]:{port}") await server.start() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index aa11accf48..2550af6527 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -10,7 +10,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import SYNC_PLUGIN, AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import SYNC_PLUGIN, SyncAgentExecutorMixin T = TypeVar("T") TASK_MODULE = "task_module" @@ -19,7 +19,7 @@ TASK_TYPE = "api_task" -class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): +class ExternalApiTask(SyncAgentExecutorMixin, PythonTask): """ Base class for all external API tasks. External API tasks are tasks that are designed to run until they receive a response from an external service. When the response is received, the task will complete. 
External API tasks are diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 03f54a25ec..9b54597e88 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -12,7 +12,7 @@ GetTaskRequest, GetTaskResponse, ) -from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer +from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer, SyncAgentServiceServicer from prometheus_client import Counter, Summary from flytekit import logger @@ -131,6 +131,8 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon return await agent.async_delete(context=context, resource_meta=request.resource_meta) return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta) + +class SyncAgentService(SyncAgentServiceServicer): @agent_exception_handler async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: tmp = TaskTemplate.from_flyte_idl(request.template) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 43a30f7c80..4f0339310e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -194,7 +194,7 @@ def _get_grpc_context() -> grpc.ServicerContext: class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. - Task should inherit from this class if the task can be run in the agent. + Asynchronous task should inherit from this class if the task can be run in the agent. """ _clean_up_task: coroutine = None @@ -204,18 +204,14 @@ class AsyncAgentExecutorMixin: _grpc_ctx: grpc.ServicerContext = _get_grpc_context() def execute(self, **kwargs) -> typing.Any: - from flytekit.extend.backend.task_executor import TaskExecutor # This is for circular import avoidance. 
from flytekit.tools.translator import get_serializable self._entity = typing.cast(PythonTask, self) task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template self._agent = AgentRegistry.get_agent(task_template.type) - if isinstance(self._agent, TaskExecutor): - res = asyncio.run(self._do(task_template, kwargs)) - else: - res = asyncio.run(self._create(task_template, kwargs)) - res = asyncio.run(self._get(resource_meta=res.resource_meta)) + res = asyncio.run(self._create(task_template, kwargs)) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) if res.resource.state != SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self._entity.name}") @@ -277,3 +273,48 @@ def get_input_literal_map(self, inputs: typing.Dict[str, typing.Any] = None) -> for k, v in inputs.items(): literals[k] = TypeEngine.to_literal(self._ctx, v, type(v), self._entity.interface.inputs[k].type) return LiteralMap(literals) if literals else None + + +class SyncAgentExecutorMixin: + """ + This mixin class is used to run the agent task locally, and it's only used for local execution. + Synchronous task should inherit from this class if the task can be run in the agent. 
+ """ + + _agent: AgentBase = None + _entity: PythonTask = None + _ctx: FlyteContext = FlyteContext.current_context() + _grpc_ctx: grpc.ServicerContext = _get_grpc_context() + + def execute(self, **kwargs) -> typing.Any: + from flytekit.tools.translator import get_serializable + + self._entity = typing.cast(PythonTask, self) + task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template + self._agent = AgentRegistry.get_agent(task_template.type) + + res = asyncio.run(self._do(task_template, kwargs)) + + if res.resource.state != SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self._entity.name}") + + return LiteralMap.from_flyte_idl(res.resource.outputs) + + async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): + inputs = self.get_input_literal_map(inputs) + output_prefix = self._ctx.file_access.get_random_local_directory() + + if self._agent.asynchronous: + res = await self._agent.async_do(self._grpc_ctx, output_prefix, task_template, inputs) + else: + res = self._agent.do(self._grpc_ctx, output_prefix, task_template, inputs) + return res + + def get_input_literal_map(self, inputs: typing.Dict[str, typing.Any] = None) -> typing.Optional[LiteralMap]: + if inputs is None: + return None + # Convert python inputs to literals + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(self._ctx, v, type(v), self._entity.interface.inputs[k].type) + return LiteralMap(literals) if literals else None diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index c9eeacd2b3..cbd7d5f5cb 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -25,12 +25,14 @@ import flytekit.models.interface as interface_models from flytekit import PythonFunctionTask -from flytekit.extend.backend.agent_service import AsyncAgentService +from flytekit.extend.backend.agent_service 
import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( ASYNC_PLUGIN, + SYNC_PLUGIN, AgentBase, AgentRegistry, AsyncAgentExecutorMixin, + SyncAgentExecutorMixin, convert_to_flyte_state, get_agent_secret, is_terminal_state, @@ -49,28 +51,23 @@ class Metadata: job_id: str -class DummyAgent(AgentBase): +class SyncDummyAgent(AgentBase): def __init__(self): - super().__init__(task_type="dummy", asynchronous=False) + super().__init__(task_type="sync_dummy", asynchronous=True) - def create( + async def async_do( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> CreateTaskResponse: - return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) - - def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: - return GetTaskResponse(resource=Resource(state=SUCCEEDED)) - - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: - return DeleteTaskResponse() + ) -> DoTaskResponse: + return DoTaskResponse(resource=Resource(state=SUCCEEDED)) def do( self, context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> DoTaskResponse: @@ -96,13 +93,20 @@ async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) - async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: return DeleteTaskResponse() - async def async_do( + def create( self, context: grpc.ServicerContext, + output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> DoTaskResponse: - return DoTaskResponse(resource=Resource(state=SUCCEEDED)) + ) -> CreateTaskResponse: + return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + + def get(self, context: grpc.ServicerContext, 
resource_meta: bytes) -> GetTaskResponse: + return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + return DeleteTaskResponse() def get_task_template(task_type: str) -> TaskTemplate: @@ -143,29 +147,42 @@ def get_task_template(task_type: str) -> TaskTemplate: }, ) -dummy_template = get_task_template("dummy") + async_dummy_template = get_task_template("async_dummy") +sync_dummy_template = get_task_template("sync_dummy") def test_dummy_agent(): - AgentRegistry.register(DummyAgent()) + ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("dummy") + async_agent = AgentRegistry.get_agent("async_dummy") + sync_agent = AgentRegistry.get_agent("sync_dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes - assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED - assert agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() - assert agent.do(ctx, dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) + assert async_agent.create(ctx, "/tmp", async_dummy_template, task_inputs).resource_meta == metadata_bytes + assert async_agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED + assert async_agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() + assert sync_agent.do(ctx, sync_dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) - class DummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): + class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( - task_type="dummy", + task_type="async_dummy", runtime_flavor=ASYNC_PLUGIN, **kwargs, ) - t = DummyTask(task_config={}, task_function=lambda: None, container_image="dummy") + t = AsyncDummyTask(task_config={}, task_function=lambda: None, 
container_image="dummy") + t.execute() + + class SyncDummyTask(SyncAgentExecutorMixin, PythonFunctionTask): + def __init__(self, **kwargs): + super().__init__( + task_type="sync_dummy", + runtime_flavor=SYNC_PLUGIN, + **kwargs, + ) + + t = SyncDummyTask(task_config={}, task_function=lambda: None, container_image="sync_dummy") t.execute() t._task_type = "non-exist-type" @@ -175,58 +192,64 @@ def __init__(self, **kwargs): @pytest.mark.asyncio async def test_async_dummy_agent(): - AgentRegistry.register(AsyncDummyAgent()) ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent("async_dummy") + async_agent = AgentRegistry.get_agent("async_dummy") + sync_agent = AgentRegistry.get_agent("sync_dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) + res = await async_agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) assert res.resource_meta == metadata_bytes - res = await agent.async_get(ctx, metadata_bytes) + res = await async_agent.async_get(ctx, metadata_bytes) assert res.resource.state == SUCCEEDED - res = await agent.async_delete(ctx, metadata_bytes) + res = await async_agent.async_delete(ctx, metadata_bytes) assert res == DeleteTaskResponse() - res = await agent.async_do(ctx, async_dummy_template, task_inputs) + res = await sync_agent.async_do(ctx, "/tmp", sync_dummy_template, task_inputs) assert res == DoTaskResponse(resource=Resource(state=SUCCEEDED)) @pytest.mark.asyncio async def run_agent_server(): - service = AsyncAgentService() + async_agent_service = AsyncAgentService() + sync_agent_service = SyncAgentService() + ctx = MagicMock(spec=grpc.ServicerContext) create_request = CreateTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) 
async_create_request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl() ) async_do_request = DoTaskRequest( - inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=sync_dummy_template.to_flyte_idl() ) fake_agent = "fake" metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - res = await service.CreateTask(create_request, ctx) + res = await async_agent_service.CreateTask(create_request, ctx) assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED - res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.DeleteTask( + DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx + ) assert isinstance(res, DeleteTaskResponse) - res = await service.DoTask(do_request, ctx) + res = await sync_agent_service.DoTask(do_request, ctx) assert res.resource.state == SUCCEEDED - res = await service.CreateTask(async_create_request, ctx) + res = await async_agent_service.CreateTask(async_create_request, ctx) assert res.resource_meta == metadata_bytes - res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) assert res.resource.state == SUCCEEDED - res = 
await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + res = await async_agent_service.DeleteTask( + DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx + ) assert isinstance(res, DeleteTaskResponse) - res = await service.DoTask(async_do_request, ctx) + res = await sync_agent_service.DoTask(async_do_request, ctx) assert res.resource.state == SUCCEEDED - res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) + res = await async_agent_service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) assert res is None @@ -293,3 +316,7 @@ def get_task_template(task_type: str) -> TaskTemplate: type=task_type, custom={}, ) + + +AgentRegistry.register(AsyncDummyAgent()) +AgentRegistry.register(SyncDummyAgent()) From d1e99be220bfc4211d5ce9a9ca34134b5988479d Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Mon, 20 Nov 2023 14:25:56 +0800 Subject: [PATCH 49/64] make organization as a required arguement back Signed-off-by: Future Outlier --- .../flytekitplugins/chatgpt/task.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 3b5a95d809..0bac5c9c70 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, Optional +from typing import Any, Dict import openai from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource @@ -22,19 +22,20 @@ class ChatGPTTask(ExternalApiTask): chatgpt_conf: ChatGPT job configuration. Config structure can be found here. 
https://platform.openai.com/docs/api-reference/completions/create """ - _openai_organization: Optional[str] = None + _openai_organization: str = None _chatgpt_conf: Dict[str, Any] = None def __init__(self, name: str, config: Dict[str, Any], **kwargs): + if "openai_organization" not in config: + raise ValueError("The 'openai_organization' configuration variable is required") + if "chatgpt_conf" not in config: raise ValueError("The 'chatgpt_conf' configuration variable is required") if "model" not in config["chatgpt_conf"]: raise ValueError("The 'model' configuration variable in 'chatgpt_conf' is required") - if "openai_organization" in config: - self._openai_organization = config["openai_organization"] - + self._openai_organization = config["openai_organization"] self._chatgpt_conf = config["chatgpt_conf"] super().__init__(name=name, config=config, return_type=str, **kwargs) From b0670d33833be1d207f26b3fd289f9dd08537463 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 21 Nov 2023 22:31:51 +0800 Subject: [PATCH 50/64] use is_sync_plugin varaible Signed-off-by: Future Outlier --- flytekit/core/base_sql_task.py | 4 +- flytekit/core/base_task.py | 15 ++++--- flytekit/core/external_api_task.py | 4 +- flytekit/core/python_auto_container.py | 4 +- flytekit/core/python_function_task.py | 4 +- flytekit/extend/backend/base_agent.py | 3 -- flytekit/models/task.py | 25 +++++++++-- flytekit/sensor/base_sensor.py | 4 +- .../flytekitplugins/airflow/task.py | 4 +- .../flytekitplugins/awsbatch/task.py | 3 +- .../flytekitplugins/bigquery/task.py | 4 +- .../flytekitplugins/mmcloud/task.py | 3 +- .../flytekitplugins/snowflake/task.py | 4 +- .../flytekitplugins/spark/task.py | 3 +- .../flytekit/unit/core/test_task_metadata.py | 10 +++-- tests/flytekit/unit/extend/test_agent.py | 44 +++---------------- .../unit/extend/test_task_executor.py | 2 +- 17 files changed, 60 insertions(+), 80 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 
0d6295db88..954846d7a7 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -27,7 +27,7 @@ def __init__( inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, outputs: Optional[Dict[str, Type]] = None, - runtime_flavor: Optional[str] = None, + is_sync_plugin: bool = False, **kwargs, ): """ @@ -40,7 +40,7 @@ def __init__( interface=Interface(inputs=inputs or {}, outputs=outputs or {}), metadata=metadata, task_config=task_config, - runtime_flavor=runtime_flavor, + is_sync_plugin=is_sync_plugin, **kwargs, ) self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 50d29f2350..dc54344dd1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -105,7 +105,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None - runtime_flavor: Optional[str] = None + is_sync_plugin: bool = False def __post_init__(self): if self.timeout: @@ -133,7 +133,8 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: runtime=_task_model.RuntimeMetadata( _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - self.runtime_flavor, + "python", + self.is_sync_plugin, ), timeout=self.timeout, retries=self.retry_strategy, @@ -173,13 +174,13 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, - runtime_flavor: Optional[str] = None, + is_sync_plugin: bool = False, **kwargs, ): self._task_type = task_type self._name = name self._interface = interface - self._metadata = metadata if metadata else TaskMetadata(runtime_flavor=runtime_flavor) + self._metadata = metadata if metadata else TaskMetadata(is_sync_plugin=is_sync_plugin) self._task_type_version = task_type_version self._security_ctx = security_ctx self._docs = 
docs @@ -423,7 +424,7 @@ def __init__( environment: Optional[Dict[str, str]] = None, disable_deck: Optional[bool] = None, enable_deck: Optional[bool] = None, - runtime_flavor: Optional[str] = None, + is_sync_plugin: bool = False, **kwargs, ): """ @@ -439,13 +440,13 @@ def __init__( execution of the task. Supplied as a dictionary of key/value pairs disable_deck (bool): (deprecated) If true, this task will not output deck html file enable_deck (bool): If true, this task will output deck html file - runtime_flavor (Optional[str]): we can set it to "sync_plugin" or "async_plugin" for flytepropeller to execute plugin task + is_sync_plugin (bool): If true, plugin task will execute synchronously. """ super().__init__( task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), - runtime_flavor=runtime_flavor, + is_sync_plugin=is_sync_plugin, **kwargs, ) self._python_interface = interface if interface else Interface() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 2550af6527..056d0fc3ce 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -10,7 +10,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import SYNC_PLUGIN, SyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin T = TypeVar("T") TASK_MODULE = "task_module" @@ -48,7 +48,7 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), - runtime_flavor=SYNC_PLUGIN, + is_sync_plugin=True, **kwargs, ) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 87fc6f7088..fe6a15efb1 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -48,7 +48,7 @@ def __init__( pod_template: Optional[PodTemplate] = 
None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, - runtime_flavor: Optional[str] = None, + is_sync_plugin: bool = False, **kwargs, ): """ @@ -93,7 +93,7 @@ def __init__( name=name, task_config=task_config, security_ctx=sec_ctx, - runtime_flavor=runtime_flavor, + is_sync_plugin=is_sync_plugin, **kwargs, ) self._container_image = container_image diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index cba7f91629..b0e35803d0 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -102,7 +102,7 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, - runtime_flavor: Optional[str] = None, + is_sync_plugin: bool = False, **kwargs, ): """ @@ -125,7 +125,7 @@ def __init__( interface=mutated_interface, task_config=task_config, task_resolver=task_resolver, - runtime_flavor=runtime_flavor, + is_sync_plugin=is_sync_plugin, **kwargs, ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 4f0339310e..1c14a6920a 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -31,9 +31,6 @@ from flytekit.exceptions.user import FlyteUserException from flytekit.models.literals import LiteralMap -SYNC_PLUGIN = "sync_plugin" # Indicates that the sync plugin in FlytePropeller should be used to run this task -ASYNC_PLUGIN = "async_plugin" # Indicates that the async plugin in FlytePropeller should be used to run this task - class AgentBase(ABC): """ diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 4d8d386c86..314d42489f 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -116,16 +116,18 @@ class RuntimeType(object): OTHER = 0 FLYTE_SDK = 1 - def __init__(self, type, version, flavor): + def __init__(self, type, version, flavor, 
is_sync_plugin): """ :param int type: Enum type from RuntimeMetadata.RuntimeType :param Text version: Version string for SDK version. Can be used for metrics or managing breaking changes in Admin or Propeller - :param Text flavor: Optional extra information about the plugin type (e.g. async plugin, sync plugin... etc.). + :param Text flavor: Optional extra information about runtime environment (e.g. Python, GoLang, etc.) + :param Boolean is_sync_plugin: Boolean to indicate if the plugin is sync or async """ self._type = type self._version = version self._flavor = flavor + self._is_sync_plugin = is_sync_plugin @property def type(self): @@ -151,11 +153,21 @@ def flavor(self): """ return self._flavor + @property + def is_sync_plugin(self): + """ + Boolean to indicate if the plugin is sync or async + :rtype: Boolean + """ + return self._is_sync_plugin + def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata """ - return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) + return _core_task.RuntimeMetadata( + type=self.type, version=self.version, flavor=self.flavor, is_sync_plugin=self.is_sync_plugin + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -163,7 +175,12 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.tasks_pb2.RuntimeMetadata pb2_object: :rtype: RuntimeMetadata """ - return cls(type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor) + return cls( + type=pb2_object.type, + version=pb2_object.version, + flavor=pb2_object.flavor, + is_sync_plugin=pb2_object.is_sync_plugin, + ) class TaskMetadata(_common.FlyteIdlEntity): diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 54c1ab3eac..19439250a1 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -9,7 +9,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface 
-from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin T = TypeVar("T") SENSOR_MODULE = "sensor_module" @@ -45,7 +45,7 @@ def __init__( name=name, task_config=None, interface=Interface(inputs=inputs), - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) self._sensor_config = sensor_config diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index ebd1ba43c2..14499ad894 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -12,7 +12,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @dataclass @@ -39,7 +39,7 @@ def __init__( query_template=query_template, interface=Interface(inputs=inputs or {}), task_type=self._TASK_TYPE, - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 2438613d8a..7beb59de46 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -8,7 +8,6 @@ from flytekit import PythonFunctionTask from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN @dataclass @@ -45,7 +44,7 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, - runtime_flavor=ASYNC_PLUGIN, + 
is_sync_plugin=False, **kwargs ) self._task_config = task_config diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 8268602022..2b686d01a2 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -7,7 +7,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -63,7 +63,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) self._output_structured_dataset_type = output_structured_dataset_type diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py index d3d682b964..e826819b69 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py @@ -10,7 +10,6 @@ from flytekit.core.python_function_task import PythonFunctionTask from flytekit.core.resources import Resources from flytekit.extend import TaskPlugins -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.image_spec.image_spec import ImageSpec @@ -41,7 +40,7 @@ def __init__( task_type=self._TASK_TYPE, task_function=task_function, container_image=container_image, - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index e3eaa9699a..174650d8fa 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ 
b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,7 +3,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN, AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -77,7 +77,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index b6c84fa687..eb991bd008 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,7 +9,6 @@ from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins -from flytekit.extend.backend.base_agent import ASYNC_PLUGIN from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -129,7 +128,7 @@ def __init__( task_type=self._SPARK_TASK_TYPE, task_function=task_function, container_image=container_image, - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 6214c3f9a6..35fa1f00df 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -41,7 +41,7 @@ def test_to_task_metadata_model(): retries=3, timeout=3600, pod_template_name="TEST POD TEMPLATE NAME", - runtime_flavor="sync_plugin", + is_sync_plugin=True, ) model = tm.to_taskmetadata_model() @@ -49,7 +49,8 @@ def test_to_task_metadata_model(): assert model.runtime == 
_task_model.RuntimeMetadata( _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - "sync_plugin", + "python", + True, ) assert model.retries == _literal_models.RetryStrategy(3) assert model.timeout == datetime.timedelta(seconds=3600) @@ -59,11 +60,12 @@ def test_to_task_metadata_model(): assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" - # Since the default value is not "python" anymore, add a test to test the default value + # Test the default value of is_sync_plugin is False tm = TaskMetadata() model = tm.to_taskmetadata_model() assert model.runtime == _task_model.RuntimeMetadata( _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - None, + "python", + False, ) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index cbd7d5f5cb..dafffdb911 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -27,8 +27,6 @@ from flytekit import PythonFunctionTask from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( - ASYNC_PLUGIN, - SYNC_PLUGIN, AgentBase, AgentRegistry, AsyncAgentExecutorMixin, @@ -109,13 +107,13 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT return DeleteTaskResponse() -def get_task_template(task_type: str) -> TaskTemplate: +def get_task_template(task_type: str, is_sync_plugin: bool = False) -> TaskTemplate: task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version" ) task_metadata = task.TaskMetadata( True, - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python", is_sync_plugin), timedelta(days=1), literals.RetryStrategy(3), True, @@ -149,7 +147,7 @@ def get_task_template(task_type: str) -> 
TaskTemplate: async_dummy_template = get_task_template("async_dummy") -sync_dummy_template = get_task_template("sync_dummy") +sync_dummy_template = get_task_template("sync_dummy", True) def test_dummy_agent(): @@ -167,7 +165,7 @@ class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="async_dummy", - runtime_flavor=ASYNC_PLUGIN, + is_sync_plugin=False, **kwargs, ) @@ -178,7 +176,7 @@ class SyncDummyTask(SyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="sync_dummy", - runtime_flavor=SYNC_PLUGIN, + is_sync_plugin=True, **kwargs, ) @@ -286,37 +284,5 @@ def test_get_agent_secret(mocked_context): assert get_agent_secret("mocked key") == "mocked token" -def get_task_template(task_type: str) -> TaskTemplate: - task_id = Identifier( - resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version" - ) - task_metadata = task.TaskMetadata( - True, - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), - timedelta(days=1), - literals.RetryStrategy(3), - True, - "0.1.1b0", - "This is deprecated!", - True, - "A", - ) - - interfaces = interface_models.TypedInterface( - { - "a": interface_models.Variable(types.LiteralType(types.SimpleType.INTEGER), "description1"), - }, - {}, - ) - - return TaskTemplate( - id=task_id, - metadata=task_metadata, - interface=interfaces, - type=task_type, - custom={}, - ) - - AgentRegistry.register(AsyncDummyAgent()) AgentRegistry.register(SyncDummyAgent()) diff --git a/tests/flytekit/unit/extend/test_task_executor.py b/tests/flytekit/unit/extend/test_task_executor.py index a491023384..a1a9a277dd 100644 --- a/tests/flytekit/unit/extend/test_task_executor.py +++ b/tests/flytekit/unit/extend/test_task_executor.py @@ -39,7 +39,7 @@ async def test_task_executor_engine(): inputs=collections.OrderedDict({"input": str, "kwargs": None}), 
outputs=collections.OrderedDict({"o0": str}), ) - tmp = get_task_template(TASK_TYPE) + tmp = get_task_template(TASK_TYPE, True) tmp._custom = { TASK_MODULE: MockExternalApiTask.__module__, TASK_NAME: MockExternalApiTask.__name__, From 104ba6f5f9ec3c598473838e7eb633599e1cea43 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 21 Nov 2023 22:36:48 +0800 Subject: [PATCH 51/64] ruff fmt Signed-off-by: Future Outlier --- plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py | 2 +- tests/flytekit/unit/extend/test_agent.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 7beb59de46..33328e5bfb 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -45,7 +45,7 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, is_sync_plugin=False, - **kwargs + **kwargs, ) self._task_config = task_config diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index dafffdb911..764c9d40b8 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -151,7 +151,6 @@ def get_task_template(task_type: str, is_sync_plugin: bool = False) -> TaskTempl def test_dummy_agent(): - ctx = MagicMock(spec=grpc.ServicerContext) async_agent = AgentRegistry.get_agent("async_dummy") sync_agent = AgentRegistry.get_agent("sync_dummy") From ec357144d94e7be4b119b9e121c2dc33f37b3a55 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 22 Nov 2023 13:01:03 +0800 Subject: [PATCH 52/64] add todo Signed-off-by: Future Outlier --- flytekit/core/base_task.py | 2 ++ flytekit/extend/backend/agent_service.py | 4 ++++ flytekit/models/task.py | 9 +++++++-- .../flytekitplugins/chatgpt/task.py | 4 ++-- 
4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index fd8f70b75b..426b603278 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -181,6 +181,8 @@ def __init__( self._task_type = task_type self._name = name self._interface = interface + # agent_metadata = agent_metadata(is_sync_plugin=is_sync_plugin) + # self._metadata = metadata if metadata else TaskMetadata(agent_metadata=agent_metadata) self._metadata = metadata if metadata else TaskMetadata(is_sync_plugin=is_sync_plugin) self._task_type_version = task_type_version self._security_ctx = security_ctx diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 9b54597e88..cf96d77dcf 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -97,7 +97,9 @@ async def wrapper( class AsyncAgentService(AsyncAgentServiceServicer): @agent_exception_handler async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: + print("create task request", request) tmp = TaskTemplate.from_flyte_idl(request.template) + print("create task template", tmp) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) @@ -135,7 +137,9 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon class SyncAgentService(SyncAgentServiceServicer): @agent_exception_handler async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: + print("do task request", request) tmp = TaskTemplate.from_flyte_idl(request.template) + print("do task template", tmp) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 314d42489f..90498bff22 100644 --- a/flytekit/models/task.py +++ 
b/flytekit/models/task.py @@ -168,7 +168,7 @@ def to_flyte_idl(self): return _core_task.RuntimeMetadata( type=self.type, version=self.version, flavor=self.flavor, is_sync_plugin=self.is_sync_plugin ) - + # TODO: use hasField to check agent metadata @classmethod def from_flyte_idl(cls, pb2_object): """ @@ -179,7 +179,9 @@ def from_flyte_idl(cls, pb2_object): type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor, - is_sync_plugin=pb2_object.is_sync_plugin, + # is_sync_plugin=True, + is_sync_plugin=pb2_object.is_sync_plugin if pb2_object.agent_metadata else False, + # is_sync_plugin=pb2_object.is_sync_plugin, ) @@ -326,6 +328,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.task_pb2.TaskMetadata pb2_object: :rtype: TaskMetadata """ + print("@@@ pb2_object.runtime", pb2_object.runtime) + print(pb2_object.runtime.is_sync_plugin) + print("@@@ RuntimeMetadata.from_flyte_idl(pb2_object.runtime)", RuntimeMetadata.from_flyte_idl(pb2_object.runtime)) return cls( discoverable=pb2_object.discoverable, runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 0bac5c9c70..666013ce0d 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -49,8 +49,8 @@ async def do( self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) - message = completion.choices[0].message.content + # completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) + # message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() outputs = LiteralMap( From c46399240e1eef877a29a81dcabe467081bf28da Mon Sep 17 00:00:00 2001 From: Future Outlier 
Date: Fri, 24 Nov 2023 12:17:59 +0800 Subject: [PATCH 53/64] use plugin_metadata by default Signed-off-by: Future Outlier --- flytekit/core/base_sql_task.py | 2 +- flytekit/core/base_task.py | 14 ++++++------- flytekit/core/python_auto_container.py | 2 +- flytekit/core/python_function_task.py | 2 +- flytekit/extend/backend/agent_service.py | 4 ---- flytekit/models/task.py | 21 +++++++++---------- .../flytekitplugins/chatgpt/task.py | 4 ++-- .../flytekit/unit/core/test_task_metadata.py | 20 ++++++++++++++---- 8 files changed, 38 insertions(+), 31 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index 954846d7a7..aecbdb88c5 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -27,7 +27,7 @@ def __init__( inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, outputs: Optional[Dict[str, Type]] = None, - is_sync_plugin: bool = False, + is_sync_plugin: Optional[bool] = None, **kwargs, ): """ diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 426b603278..64fdddc59f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -26,6 +26,7 @@ from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flyteidl.core import tasks_pb2 +from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ( @@ -106,7 +107,7 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None - is_sync_plugin: bool = False + plugin_metadata: Optional[PluginMetadata] = None def __post_init__(self): if self.timeout: @@ -135,7 +136,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", - self.is_sync_plugin, + self.plugin_metadata, ), 
timeout=self.timeout, retries=self.retry_strategy, @@ -175,15 +176,14 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, - is_sync_plugin: bool = False, + is_sync_plugin: Optional[bool] = None, **kwargs, ): self._task_type = task_type self._name = name self._interface = interface - # agent_metadata = agent_metadata(is_sync_plugin=is_sync_plugin) - # self._metadata = metadata if metadata else TaskMetadata(agent_metadata=agent_metadata) - self._metadata = metadata if metadata else TaskMetadata(is_sync_plugin=is_sync_plugin) + plugin_metadata = PluginMetadata(is_sync_plugin=is_sync_plugin) if is_sync_plugin is not None else None + self._metadata = metadata if metadata else TaskMetadata(plugin_metadata=plugin_metadata) self._task_type_version = task_type_version self._security_ctx = security_ctx self._docs = docs @@ -427,7 +427,7 @@ def __init__( environment: Optional[Dict[str, str]] = None, disable_deck: Optional[bool] = None, enable_deck: Optional[bool] = None, - is_sync_plugin: bool = False, + is_sync_plugin: Optional[bool] = None, **kwargs, ): """ diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 55788b6333..bad94beead 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -48,7 +48,7 @@ def __init__( pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, - is_sync_plugin: bool = False, + is_sync_plugin: Optional[bool] = None, **kwargs, ): """ diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index b0e35803d0..8edcbad578 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -102,7 +102,7 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: 
Optional[TaskResolverMixin] = None, - is_sync_plugin: bool = False, + is_sync_plugin: Optional[bool] = None, **kwargs, ): """ diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index cf96d77dcf..9b54597e88 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -97,9 +97,7 @@ async def wrapper( class AsyncAgentService(AsyncAgentServiceServicer): @agent_exception_handler async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: - print("create task request", request) tmp = TaskTemplate.from_flyte_idl(request.template) - print("create task template", tmp) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) @@ -137,9 +135,7 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon class SyncAgentService(SyncAgentServiceServicer): @agent_exception_handler async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: - print("do task request", request) tmp = TaskTemplate.from_flyte_idl(request.template) - print("do task template", tmp) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 90498bff22..e9368e41ea 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -116,7 +116,7 @@ class RuntimeType(object): OTHER = 0 FLYTE_SDK = 1 - def __init__(self, type, version, flavor, is_sync_plugin): + def __init__(self, type, version, flavor, plugin_metadata): """ :param int type: Enum type from RuntimeMetadata.RuntimeType :param Text version: Version string for SDK version. 
Can be used for metrics or managing breaking changes in @@ -127,7 +127,7 @@ def __init__(self, type, version, flavor, is_sync_plugin): self._type = type self._version = version self._flavor = flavor - self._is_sync_plugin = is_sync_plugin + self._plugin_metadata = plugin_metadata @property def type(self): @@ -154,21 +154,21 @@ def flavor(self): return self._flavor @property - def is_sync_plugin(self): + def plugin_metadata(self): """ Boolean to indicate if the plugin is sync or async :rtype: Boolean """ - return self._is_sync_plugin + return self._plugin_metadata def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata """ return _core_task.RuntimeMetadata( - type=self.type, version=self.version, flavor=self.flavor, is_sync_plugin=self.is_sync_plugin + type=self.type, version=self.version, flavor=self.flavor, plugin_metadata=self._plugin_metadata ) - # TODO: use hasField to check agent metadata + @classmethod def from_flyte_idl(cls, pb2_object): """ @@ -179,9 +179,7 @@ def from_flyte_idl(cls, pb2_object): type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor, - # is_sync_plugin=True, - is_sync_plugin=pb2_object.is_sync_plugin if pb2_object.agent_metadata else False, - # is_sync_plugin=pb2_object.is_sync_plugin, + plugin_metadata=pb2_object.plugin_metadata if pb2_object.HasField("plugin_metadata") else None, ) @@ -329,8 +327,9 @@ def from_flyte_idl(cls, pb2_object): :rtype: TaskMetadata """ print("@@@ pb2_object.runtime", pb2_object.runtime) - print(pb2_object.runtime.is_sync_plugin) - print("@@@ RuntimeMetadata.from_flyte_idl(pb2_object.runtime)", RuntimeMetadata.from_flyte_idl(pb2_object.runtime)) + print( + "@@@ RuntimeMetadata.from_flyte_idl(pb2_object.runtime)", RuntimeMetadata.from_flyte_idl(pb2_object.runtime) + ) return cls( discoverable=pb2_object.discoverable, runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py 
b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 666013ce0d..0bac5c9c70 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -49,8 +49,8 @@ async def do( self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - # completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) - # message = completion.choices[0].message.content + completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) + message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() outputs = LiteralMap( diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 35fa1f00df..6f869f78d2 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -4,6 +4,7 @@ from flytekit import __version__ from flytekit.core.base_task import TaskMetadata +from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model @@ -32,6 +33,7 @@ def test_retry_strategy(): def test_to_task_metadata_model(): + # Test the value of is_sync_plugin is True tm = TaskMetadata( cache=True, cache_serialize=True, @@ -41,7 +43,7 @@ def test_to_task_metadata_model(): retries=3, timeout=3600, pod_template_name="TEST POD TEMPLATE NAME", - is_sync_plugin=True, + plugin_metadata=PluginMetadata(is_sync_plugin=True), ) model = tm.to_taskmetadata_model() @@ -50,7 +52,7 @@ def test_to_task_metadata_model(): _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", - True, + plugin_metadata=PluginMetadata(is_sync_plugin=True), ) assert model.retries == _literal_models.RetryStrategy(3) assert model.timeout == datetime.timedelta(seconds=3600) @@ -60,12 +62,22 @@ def 
test_to_task_metadata_model(): assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" - # Test the default value of is_sync_plugin is False + # Test the value of is_sync_plugin is False + tm = TaskMetadata(plugin_metadata=PluginMetadata(is_sync_plugin=False)) + model = tm.to_taskmetadata_model() + assert model.runtime == _task_model.RuntimeMetadata( + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "python", + plugin_metadata=PluginMetadata(is_sync_plugin=False), + ) + + # Test the default value of is_sync_plugin is None tm = TaskMetadata() model = tm.to_taskmetadata_model() assert model.runtime == _task_model.RuntimeMetadata( _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", - False, + None, ) From 868f81b5ab81c3107742ff1d7cac85e0048668ef Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Fri, 24 Nov 2023 12:31:12 +0800 Subject: [PATCH 54/64] remove print debug Signed-off-by: Future Outlier --- flytekit/models/task.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index e9368e41ea..ab37dad5dd 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -326,10 +326,6 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.task_pb2.TaskMetadata pb2_object: :rtype: TaskMetadata """ - print("@@@ pb2_object.runtime", pb2_object.runtime) - print( - "@@@ RuntimeMetadata.from_flyte_idl(pb2_object.runtime)", RuntimeMetadata.from_flyte_idl(pb2_object.runtime) - ) return cls( discoverable=pb2_object.discoverable, runtime=RuntimeMetadata.from_flyte_idl(pb2_object.runtime), From 6d15fdc15c3e0479d1a071a3aec6cbf059a73ffb Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 13 Dec 2023 16:29:59 +0800 Subject: [PATCH 55/64] move to rich click Signed-off-by: Future Outlier --- flytekit/clis/sdk_in_container/serve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/clis/sdk_in_container/serve.py 
b/flytekit/clis/sdk_in_container/serve.py index 4005ded9a6..d82398ea81 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,6 +1,6 @@ from concurrent import futures -import click +import rich_click as click from flyteidl.service.agent_pb2_grpc import ( add_AsyncAgentServiceServicer_to_server, add_SyncAgentServiceServicer_to_server, From 510711ed90fc96f268043c0c7ba769ca4d92004b Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 13 Dec 2023 23:26:48 +0800 Subject: [PATCH 56/64] change to agentServicer, agentExecutorMixin, and use is_sync attribute Signed-off-by: Future Outlier --- flytekit/__init__.py | 2 - flytekit/clis/sdk_in_container/serve.py | 10 +--- flytekit/core/external_api_task.py | 12 ++-- flytekit/extend/backend/agent_service.py | 40 +++---------- flytekit/extend/backend/base_agent.py | 58 +++++-------------- flytekit/extend/backend/task_executor.py | 7 ++- flytekit/models/literals.py | 3 +- flytekit/sensor/base_sensor.py | 8 ++- flytekit/sensor/sensor_engine.py | 2 + .../flytekitplugins/airflow/task.py | 7 ++- .../flytekitplugins/awsbatch/task.py | 3 +- .../flytekitplugins/bigquery/task.py | 7 ++- .../flytekitplugins/mmcloud/task.py | 3 +- .../flytekitplugins/chatgpt/task.py | 6 +- .../flytekitplugins/snowflake/task.py | 7 ++- .../flytekitplugins/spark/task.py | 9 +-- .../flytekit/unit/core/test_task_metadata.py | 2 +- tests/flytekit/unit/extend/test_agent.py | 10 +--- 18 files changed, 73 insertions(+), 123 deletions(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index d070e1a0e7..bcb1d2c54e 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -248,8 +248,6 @@ from flytekit.extend.backend.task_executor import TaskExecutor # isort:skip. This is for circular import avoidance. 
-__version__ = "0.0.0+develop" - def current_context() -> ExecutionParameters: """ diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index d82398ea81..9d48ff4c4d 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,10 +1,7 @@ from concurrent import futures import rich_click as click -from flyteidl.service.agent_pb2_grpc import ( - add_AsyncAgentServiceServicer_to_server, - add_SyncAgentServiceServicer_to_server, -) +from flyteidl.service.agent_pb2_grpc import add_AgentServiceServicer_to_server from grpc import aio @@ -52,7 +49,7 @@ def agent(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") - from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService + from flytekit.extend.backend.agent_service import AgentService try: from prometheus_client import start_http_server @@ -63,8 +60,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting the agent service...", fg="blue") server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) - add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) - add_SyncAgentServiceServicer_to_server(SyncAgentService(), server) + add_AgentServiceServicer_to_server(AgentService(), server) server.add_insecure_port(f"[::]:{port}") await server.start() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 056d0fc3ce..c006b3b9ff 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -4,13 +4,13 @@ from typing import Any, Dict, Optional, TypeVar import jsonpickle -from flyteidl.admin.agent_pb2 import DoTaskResponse +from flyteidl.admin.agent_pb2 import CreateTaskResponse from typing_extensions import get_type_hints from flytekit.configuration import 
SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import SyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin T = TypeVar("T") TASK_MODULE = "task_module" @@ -19,13 +19,15 @@ TASK_TYPE = "api_task" -class ExternalApiTask(SyncAgentExecutorMixin, PythonTask): +class ExternalApiTask(AgentExecutorMixin, PythonTask): """ Base class for all external API tasks. External API tasks are tasks that are designed to run until they receive a response from an external service. When the response is received, the task will complete. External API tasks are designed to be run by the flyte agent. """ + is_sync = True + def __init__( self, name: str, @@ -48,14 +50,14 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), - is_sync_plugin=True, + is_sync_plugin=self.is_sync, **kwargs, ) self._config = config @abstractmethod - async def do(self, **kwargs) -> DoTaskResponse: + async def do(self, **kwargs) -> CreateTaskResponse: """ Initiate an HTTP request to an external service such as OpenAI or Vertex AI and retrieve the response. """ diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 9b54597e88..c72869debf 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -7,12 +7,10 @@ CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, - DoTaskRequest, - DoTaskResponse, GetTaskRequest, GetTaskResponse, ) -from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer, SyncAgentServiceServicer +from flyteidl.service.agent_pb2_grpc import AgentServiceServicer from prometheus_client import Counter, Summary from flytekit import logger @@ -25,7 +23,6 @@ create_operation = "create" get_operation = "get" delete_operation = "delete" -do_operation = "do" # Follow the naming convention. 
https://prometheus.io/docs/practices/naming/ request_success_count = Counter( @@ -47,7 +44,11 @@ def agent_exception_handler(func): async def wrapper( self, - request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest, DoTaskRequest], + request: typing.Union[ + CreateTaskRequest, + GetTaskRequest, + DeleteTaskRequest, + ], context: grpc.ServicerContext, *args, **kwargs, @@ -63,11 +64,6 @@ async def wrapper( elif isinstance(request, DeleteTaskRequest): task_type = request.task_type operation = delete_operation - elif isinstance(request, DoTaskRequest): - task_type = request.template.type - operation = do_operation - if request.inputs: - input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize()) else: context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") @@ -94,7 +90,7 @@ async def wrapper( return wrapper -class AsyncAgentService(AsyncAgentServiceServicer): +class AgentService(AgentServiceServicer): @agent_exception_handler async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: tmp = TaskTemplate.from_flyte_idl(request.template) @@ -130,25 +126,3 @@ async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerCon if agent.asynchronous: return await agent.async_delete(context=context, resource_meta=request.resource_meta) return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta) - - -class SyncAgentService(SyncAgentServiceServicer): - @agent_exception_handler - async def DoTask(self, request: DoTaskRequest, context: grpc.ServicerContext) -> DoTaskResponse: - tmp = TaskTemplate.from_flyte_idl(request.template) - inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(tmp.type) - - logger.info(f"{tmp.type} agent start doing the job") - if agent.asynchronous: - return await agent.async_do( - context=context, 
inputs=inputs, output_prefix=request.output_prefix, task_template=tmp - ) - return await asyncio.get_running_loop().run_in_executor( - None, - agent.do, - context, - request.output_prefix, - tmp, - inputs, - ) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 252496b58a..82746d08c8 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -16,7 +16,6 @@ SUCCEEDED, CreateTaskResponse, DeleteTaskResponse, - DoTaskResponse, GetTaskResponse, State, ) @@ -85,18 +84,6 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT """ raise NotImplementedError - def do( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - ) -> DoTaskResponse: - """ - Return the result of executing a task. It should return error code if the task execution failed. - """ - raise NotImplementedError - async def async_create( self, context: grpc.ServicerContext, @@ -123,18 +110,6 @@ async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes """ raise NotImplementedError - async def async_do( - self, - context: grpc.ServicerContext, - output_prefix: str, - task_template: TaskTemplate, - inputs: typing.Optional[LiteralMap] = None, - ) -> DoTaskResponse: - """ - Return the result of executing a task. It should return error code if the task execution failed. - """ - raise NotImplementedError - class AgentRegistry(object): """ @@ -191,10 +166,13 @@ def _get_grpc_context() -> grpc.ServicerContext: return grpc_ctx -class AsyncAgentExecutorMixin: +class AgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. - Asynchronous task should inherit from this class if the task can be run in the agent. + Task should inherit from this class if the task can be run in the agent. 
+ It can handle asynchronous tasks and synchronous tasks. + Asynchronous tasks are for tasks running long, for example running query job. + Synchronous tasks are for tasks running quick, for example, you want to execute something really fast, or even retrieving some metadata from a backend service. """ _clean_up_task: coroutine = None @@ -215,6 +193,13 @@ def execute(self, **kwargs) -> typing.Any: self._agent = AgentRegistry.get_agent(task_template.type) res = asyncio.run(self._create(task_template, output_prefix, kwargs)) + + # If the task is synchronous, the agent will return the output from the resource literals. + if res.HasField("resource"): + if res.resource.state != SUCCEEDED: + raise FlyteUserException(f"Failed to run the task {self._entity.name}") + return LiteralMap.from_flyte_idl(res.resource.outputs) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) if res.resource.state != SUCCEEDED: @@ -233,7 +218,6 @@ async def _create( self, task_template: TaskTemplate, output_prefix: str, inputs: typing.Dict[str, typing.Any] = None ) -> CreateTaskResponse: ctx = FlyteContext.current_context() - grpc_ctx = _get_grpc_context() # Convert python inputs to literals literals = inputs or {} @@ -248,9 +232,9 @@ async def _create( task_template = render_task_template(task_template, output_prefix) if self._agent.asynchronous: - res = await self._agent.async_create(self._grpc_ctx, output_prefix, task_template, inputs) + res = await self._agent.async_create(self._grpc_ctx, output_prefix, task_template, literal_map) else: - res = self._agent.create(self._grpc_ctx, output_prefix, task_template, inputs) + res = self._agent.create(self._grpc_ctx, output_prefix, task_template, literal_map) signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore return res @@ -267,8 +251,8 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: time.sleep(1) if self._agent.asynchronous: res = await self._agent.async_get(grpc_ctx, 
resource_meta) - if self._is_canceled: - await self._is_canceled + if self._clean_up_task: + await self._clean_up_task sys.exit(1) else: res = self._agent.get(grpc_ctx, resource_meta) @@ -276,16 +260,6 @@ async def _get(self, resource_meta: bytes) -> GetTaskResponse: logger.info(f"Task state: {state}, State message: {res.resource.message}") return res - async def _do(self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None): - inputs = self.get_input_literal_map(inputs) - output_prefix = self._ctx.file_access.get_random_local_directory() - - if self._agent.asynchronous: - res = await self._agent.async_do(self._grpc_ctx, output_prefix, task_template, inputs) - else: - res = self._agent.do(self._grpc_ctx, output_prefix, task_template, inputs) - return res - def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: if self._agent.asynchronous: if self._clean_up_task is None: diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index c7b994ac9b..4e67e67a4c 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -3,7 +3,7 @@ import grpc import jsonpickle -from flyteidl.admin.agent_pb2 import DoTaskResponse +from flyteidl.admin.agent_pb2 import CreateTaskResponse from flytekit import FlyteContextManager from flytekit.core.external_api_task import TASK_CONFIG_PKL, TASK_MODULE, TASK_NAME, TASK_TYPE @@ -27,17 +27,18 @@ class TaskExecutor(AgentBase): def __init__(self): super().__init__(task_type=TASK_TYPE, asynchronous=True) - async def async_do( + async def async_create( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, - ) -> DoTaskResponse: + ) -> CreateTaskResponse: python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } ctx = FlyteContextManager.current_context() + 
native_inputs = {} if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index c39f8dea37..7ae03d37a6 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -8,11 +8,10 @@ from flytekit.exceptions import user as _user_exceptions from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from flytekit.models.types import Error +from flytekit.models.types import Error, StructuredDatasetType from flytekit.models.types import LiteralType as _LiteralType from flytekit.models.types import OutputReference as _OutputReference from flytekit.models.types import SchemaType as _SchemaType -from flytekit.models.types import StructuredDatasetType class RetryStrategy(_common.FlyteIdlEntity): diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 19439250a1..ff6566888b 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -9,7 +9,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin T = TypeVar("T") SENSOR_MODULE = "sensor_module" @@ -19,13 +19,15 @@ INPUTS = "inputs" -class BaseSensor(AsyncAgentExecutorMixin, PythonTask): +class BaseSensor(AgentExecutorMixin, PythonTask): """ Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the sensor agent, and not by the Flyte engine. 
""" + is_sync = False + def __init__( self, name: str, @@ -45,7 +47,7 @@ def __init__( name=name, task_config=None, interface=Interface(inputs=inputs), - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) self._sensor_config = sensor_config diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index 02edea96bb..3f7b1b7a69 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -39,9 +39,11 @@ async def async_create( name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } ctx = FlyteContextManager.current_context() + if inputs: native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) task_template.custom[INPUTS] = native_inputs + return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index 007716c279..cab445fdfa 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -19,7 +19,7 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin @dataclass @@ -106,12 +106,13 @@ def execute(self, **kwargs) -> Any: _get_airflow_instance(self.task_config).execute(context=Context()) -class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowObj]): +class AirflowTask(AgentExecutorMixin, PythonTask[AirflowObj]): """ This python task is used to wrap an Airflow task. It is used to run an Airflow task in Flyte agent. 
The airflow task module, name and parameters are stored in the task config. We run the Airflow task in the agent. """ + is_sync = False _TASK_TYPE = "airflow" def __init__( @@ -126,7 +127,7 @@ def __init__( task_config=task_config, interface=Interface(inputs=inputs or {}), task_type=self._TASK_TYPE, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 33328e5bfb..789151d7be 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -35,6 +35,7 @@ class AWSBatchFunctionTask(PythonFunctionTask): Actual Plugin that transforms the local python code for execution within AWS batch job """ + is_sync = False _AWS_BATCH_TASK_TYPE = "aws-batch" def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwargs): @@ -44,7 +45,7 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) self._task_config = task_config diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 2b686d01a2..ebe39deb81 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -7,7 +7,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -23,13 +23,14 @@ class BigQueryConfig(object): QueryJobConfig: 
Optional[bigquery.QueryJobConfig] = None -class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]): +class BigQueryTask(AgentExecutorMixin, SQLTask[BigQueryConfig]): """ This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output. """ # This task is executed using the BigQuery handler in the backend. # https://github.com/flyteorg/flyteplugins/blob/43623826fb189fa64dc4cb53e7025b517d911f22/go/tasks/plugins/webapi/bigquery/plugin.go#L34 + is_sync = False _TASK_TYPE = "bigquery_query_job_task" def __init__( @@ -63,7 +64,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) self._output_structured_dataset_type = output_structured_dataset_type diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py index e826819b69..b5a175be6a 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py @@ -24,6 +24,7 @@ class MMCloudConfig(object): class MMCloudTask(PythonFunctionTask): + is_sync = False _TASK_TYPE = "mmcloud_task" def __init__( @@ -40,7 +41,7 @@ def __init__( task_type=self._TASK_TYPE, task_function=task_function, container_image=container_image, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 0bac5c9c70..5b606bb824 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -2,7 +2,7 @@ from typing import Any, Dict import openai -from flyteidl.admin.agent_pb2 import SUCCEEDED, DoTaskResponse, Resource +from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource from flytekit import FlyteContextManager from 
flytekit.core.external_api_task import ExternalApiTask @@ -43,7 +43,7 @@ def __init__(self, name: str, config: Dict[str, Any], **kwargs): async def do( self, message: str = None, - ) -> DoTaskResponse: + ) -> CreateTaskResponse: openai.organization = self._openai_organization openai.api_key = get_agent_secret(secret_key="FLYTE_OPENAI_ACCESS_TOKEN") @@ -63,4 +63,4 @@ async def do( ) } ).to_flyte_idl() - return DoTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) + return CreateTaskResponse(resource=Resource(state=SUCCEEDED, outputs=outputs)) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 174650d8fa..86a9906fdc 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,7 +3,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -35,12 +35,13 @@ class SnowflakeConfig(object): table: Optional[str] = None -class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): +class SnowflakeTask(AgentExecutorMixin, SQLTask[SnowflakeConfig]): """ This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output. """ # This task is executed using the snowflake handler in the backend. 
+ is_sync = False _TASK_TYPE = "snowflake" def __init__( @@ -77,7 +78,7 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 7a802319bb..c448b6ff29 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,7 +9,7 @@ from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins -from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extend.backend.base_agent import AgentExecutorMixin from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -99,11 +99,12 @@ def new_spark_session(name: str, conf: Dict[str, str] = None): # sess.stop() -class PysparkFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): +class PysparkFunctionTask(AgentExecutorMixin, PythonFunctionTask[Spark]): """ Actual Plugin that transforms the local python code for execution within a spark context """ + is_sync = False _SPARK_TASK_TYPE = "spark" def __init__( @@ -131,7 +132,7 @@ def __init__( task_type=self._SPARK_TASK_TYPE, task_function=task_function, container_image=container_image, - is_sync_plugin=False, + is_sync_plugin=self.is_sync, **kwargs, ) @@ -185,7 +186,7 @@ def execute(self, **kwargs) -> Any: " please set --raw-output-data-prefix to a remote path. e.g. s3://, gcs//, etc." 
) if ctx.execution_state and ctx.execution_state.is_local_execution(): - return AsyncAgentExecutorMixin.execute(self, **kwargs) + return AgentExecutorMixin.execute(self, **kwargs) except Exception as e: logger.error(f"Agent failed to run the task with error: {e}") logger.info("Falling back to local execution") diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index 6f869f78d2..d551c76192 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -1,10 +1,10 @@ import datetime import pytest +from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit import __version__ from flytekit.core.base_task import TaskMetadata -from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit.models import literals as _literal_models from flytekit.models import task as _task_model diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 45bb8da76c..3d03c05ce2 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -28,9 +28,8 @@ from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( AgentBase, + AgentExecutorMixin, AgentRegistry, - AsyncAgentExecutorMixin, - SyncAgentExecutorMixin, convert_to_flyte_state, get_agent_secret, is_terminal_state, @@ -119,9 +118,6 @@ def simple_task(i: int): ) - - - async_dummy_template = get_task_template("async_dummy") sync_dummy_template = get_task_template("sync_dummy", True) @@ -136,7 +132,7 @@ def test_dummy_agent(): assert async_agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() assert sync_agent.do(ctx, sync_dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) - class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): + class AsyncDummyTask(AgentExecutorMixin, PythonFunctionTask): def __init__(self, 
**kwargs): super().__init__( task_type="async_dummy", @@ -147,7 +143,7 @@ def __init__(self, **kwargs): t = AsyncDummyTask(task_config={}, task_function=lambda: None, container_image="dummy") t.execute() - class SyncDummyTask(SyncAgentExecutorMixin, PythonFunctionTask): + class SyncDummyTask(AgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="sync_dummy", From 9d1e5bd2129faf7fe042fc683b52b74517981f2d Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 15:04:50 +0800 Subject: [PATCH 57/64] rename async agent service Signed-off-by: Future Outlier --- flytekit/clis/sdk_in_container/serve.py | 6 ++--- flytekit/core/base_task.py | 14 ++--------- flytekit/core/external_api_task.py | 4 +-- flytekit/extend/backend/agent_service.py | 6 ++--- flytekit/extend/backend/base_agent.py | 2 +- flytekit/models/task.py | 25 +++---------------- flytekit/sensor/base_sensor.py | 4 +-- .../flytekitplugins/airflow/task.py | 4 +-- .../flytekitplugins/bigquery/task.py | 4 +-- .../flytekitplugins/chatgpt/task.py | 4 +-- .../flytekitplugins/snowflake/task.py | 4 +-- .../flytekitplugins/spark/task.py | 6 ++--- tests/flytekit/unit/extend/test_agent.py | 6 ++--- 13 files changed, 31 insertions(+), 58 deletions(-) diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 9d48ff4c4d..6a262c91a4 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,7 +1,7 @@ from concurrent import futures import rich_click as click -from flyteidl.service.agent_pb2_grpc import add_AgentServiceServicer_to_server +from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server from grpc import aio @@ -49,7 +49,7 @@ def agent(_: click.Context, port, worker, timeout): async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") - from 
flytekit.extend.backend.agent_service import AgentService + from flytekit.extend.backend.agent_service import AsyncAgentService try: from prometheus_client import start_http_server @@ -60,7 +60,7 @@ async def _start_grpc_server(port: int, worker: int, timeout: int): click.secho("Starting the agent service...", fg="blue") server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) - add_AgentServiceServicer_to_server(AgentService(), server) + add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) server.add_insecure_port(f"[::]:{port}") await server.start() diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 64fdddc59f..21c7279c53 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -26,7 +26,6 @@ from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flyteidl.core import tasks_pb2 -from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ( @@ -107,7 +106,6 @@ class TaskMetadata(object): retries: int = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None - plugin_metadata: Optional[PluginMetadata] = None def __post_init__(self): if self.timeout: @@ -133,10 +131,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata: return _task_model.TaskMetadata( discoverable=self.cache, runtime=_task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "python", - self.plugin_metadata, + _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python" ), timeout=self.timeout, retries=self.retry_strategy, @@ -176,14 +171,12 @@ def __init__( task_type_version=0, security_ctx: Optional[SecurityContext] = None, docs: Optional[Documentation] = None, - is_sync_plugin: Optional[bool] = None, **kwargs, ): self._task_type = task_type self._name = name self._interface 
= interface - plugin_metadata = PluginMetadata(is_sync_plugin=is_sync_plugin) if is_sync_plugin is not None else None - self._metadata = metadata if metadata else TaskMetadata(plugin_metadata=plugin_metadata) + self._metadata = metadata if metadata else TaskMetadata() self._task_type_version = task_type_version self._security_ctx = security_ctx self._docs = docs @@ -427,7 +420,6 @@ def __init__( environment: Optional[Dict[str, str]] = None, disable_deck: Optional[bool] = None, enable_deck: Optional[bool] = None, - is_sync_plugin: Optional[bool] = None, **kwargs, ): """ @@ -443,13 +435,11 @@ def __init__( execution of the task. Supplied as a dictionary of key/value pairs disable_deck (bool): (deprecated) If true, this task will not output deck html file enable_deck (bool): If true, this task will output deck html file - is_sync_plugin (bool): If true, plugin task will execute synchronously. """ super().__init__( task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), - is_sync_plugin=is_sync_plugin, **kwargs, ) self._python_interface = interface if interface else Interface() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index c006b3b9ff..87da2a694e 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -10,7 +10,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin T = TypeVar("T") TASK_MODULE = "task_module" @@ -19,7 +19,7 @@ TASK_TYPE = "api_task" -class ExternalApiTask(AgentExecutorMixin, PythonTask): +class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): """ Base class for all external API tasks. External API tasks are tasks that are designed to run until they receive a response from an external service. 
When the response is received, the task will complete. External API tasks are diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index c72869debf..aecc000a82 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -10,7 +10,7 @@ GetTaskRequest, GetTaskResponse, ) -from flyteidl.service.agent_pb2_grpc import AgentServiceServicer +from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer from prometheus_client import Counter, Summary from flytekit import logger @@ -90,13 +90,13 @@ async def wrapper( return wrapper -class AgentService(AgentServiceServicer): +class AsyncAgentService(AsyncAgentServiceServicer): @agent_exception_handler async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) - + print("@@@ we are using agent server") logger.info(f"{tmp.type} agent start creating the job") if agent.asynchronous: return await agent.async_create( diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 82746d08c8..601bc3a1e3 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -166,7 +166,7 @@ def _get_grpc_context() -> grpc.ServicerContext: return grpc_ctx -class AgentExecutorMixin: +class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. 
diff --git a/flytekit/models/task.py b/flytekit/models/task.py index ab37dad5dd..48a8abfde1 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -116,18 +116,16 @@ class RuntimeType(object): OTHER = 0 FLYTE_SDK = 1 - def __init__(self, type, version, flavor, plugin_metadata): + def __init__(self, type, version, flavor): """ :param int type: Enum type from RuntimeMetadata.RuntimeType :param Text version: Version string for SDK version. Can be used for metrics or managing breaking changes in Admin or Propeller :param Text flavor: Optional extra information about runtime environment (e.g. Python, GoLang, etc.) - :param Boolean is_sync_plugin: Boolean to indicate if the plugin is sync or async """ self._type = type self._version = version self._flavor = flavor - self._plugin_metadata = plugin_metadata @property def type(self): @@ -148,26 +146,16 @@ def version(self): @property def flavor(self): """ - Optional extra information about the plugin type (e.g. async plugin, sync plugin... etc.). + Optional extra information about runtime environment (e.g. Python, GoLang, etc.) 
:rtype: Text """ return self._flavor - @property - def plugin_metadata(self): - """ - Boolean to indicate if the plugin is sync or async - :rtype: Boolean - """ - return self._plugin_metadata - def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata """ - return _core_task.RuntimeMetadata( - type=self.type, version=self.version, flavor=self.flavor, plugin_metadata=self._plugin_metadata - ) + return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) @classmethod def from_flyte_idl(cls, pb2_object): @@ -175,12 +163,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.tasks_pb2.RuntimeMetadata pb2_object: :rtype: RuntimeMetadata """ - return cls( - type=pb2_object.type, - version=pb2_object.version, - flavor=pb2_object.flavor, - plugin_metadata=pb2_object.plugin_metadata if pb2_object.HasField("plugin_metadata") else None, - ) + return cls(type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor) class TaskMetadata(_common.FlyteIdlEntity): diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index ff6566888b..7120b4f09e 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -9,7 +9,7 @@ from flytekit.configuration import SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.interface import Interface -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin T = TypeVar("T") SENSOR_MODULE = "sensor_module" @@ -19,7 +19,7 @@ INPUTS = "inputs" -class BaseSensor(AgentExecutorMixin, PythonTask): +class BaseSensor(AsyncAgentExecutorMixin, PythonTask): """ Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some condition to be met. When the condition is met, the sensor will complete. 
Sensors are designed to be run by the diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index cab445fdfa..ef60c6ecc3 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -19,7 +19,7 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.tracker import TrackedInstance from flytekit.core.utils import timeit -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin @dataclass @@ -106,7 +106,7 @@ def execute(self, **kwargs) -> Any: _get_airflow_instance(self.task_config).execute(context=Context()) -class AirflowTask(AgentExecutorMixin, PythonTask[AirflowObj]): +class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowObj]): """ This python task is used to wrap an Airflow task. It is used to run an Airflow task in Flyte agent. The airflow task module, name and parameters are stored in the task config. We run the Airflow task in the agent. 
diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index ebe39deb81..d20f04b993 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -7,7 +7,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -23,7 +23,7 @@ class BigQueryConfig(object): QueryJobConfig: Optional[bigquery.QueryJobConfig] = None -class BigQueryTask(AgentExecutorMixin, SQLTask[BigQueryConfig]): +class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]): """ This is the simplest form of a BigQuery Task, that can be used even for tasks that do not produce any output. 
""" diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 5b606bb824..018068bbd6 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -49,8 +49,8 @@ async def do( self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) - message = completion.choices[0].message.content + # completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) + # message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() outputs = LiteralMap( diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 86a9906fdc..e12c4e6d90 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,7 +3,7 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model from flytekit.types.structured import StructuredDataset @@ -35,7 +35,7 @@ class SnowflakeConfig(object): table: Optional[str] = None -class SnowflakeTask(AgentExecutorMixin, SQLTask[SnowflakeConfig]): +class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): """ This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output. 
""" diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index c448b6ff29..4dbf030bdf 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -9,7 +9,7 @@ from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.context_manager import ExecutionParameters from flytekit.extend import ExecutionState, TaskPlugins -from flytekit.extend.backend.base_agent import AgentExecutorMixin +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.image_spec import ImageSpec from .models import SparkJob, SparkType @@ -99,7 +99,7 @@ def new_spark_session(name: str, conf: Dict[str, str] = None): # sess.stop() -class PysparkFunctionTask(AgentExecutorMixin, PythonFunctionTask[Spark]): +class PysparkFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): """ Actual Plugin that transforms the local python code for execution within a spark context """ @@ -186,7 +186,7 @@ def execute(self, **kwargs) -> Any: " please set --raw-output-data-prefix to a remote path. e.g. s3://, gcs//, etc." 
) if ctx.execution_state and ctx.execution_state.is_local_execution(): - return AgentExecutorMixin.execute(self, **kwargs) + return AsyncAgentExecutorMixin.execute(self, **kwargs) except Exception as e: logger.error(f"Agent failed to run the task with error: {e}") logger.info("Falling back to local execution") diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 3d03c05ce2..dc11ed62c1 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -28,7 +28,7 @@ from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( AgentBase, - AgentExecutorMixin, + AsyncAgentExecutorMixin, AgentRegistry, convert_to_flyte_state, get_agent_secret, @@ -132,7 +132,7 @@ def test_dummy_agent(): assert async_agent.delete(ctx, metadata_bytes) == DeleteTaskResponse() assert sync_agent.do(ctx, sync_dummy_template, task_inputs) == DoTaskResponse(resource=Resource(state=SUCCEEDED)) - class AsyncDummyTask(AgentExecutorMixin, PythonFunctionTask): + class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="async_dummy", @@ -143,7 +143,7 @@ def __init__(self, **kwargs): t = AsyncDummyTask(task_config={}, task_function=lambda: None, container_image="dummy") t.execute() - class SyncDummyTask(AgentExecutorMixin, PythonFunctionTask): + class SyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="sync_dummy", From b09a627f340271f1dd65eb04f72da8fa2339bf4c Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 17:23:43 +0800 Subject: [PATCH 58/64] remove is_sync Signed-off-by: Future Outlier --- flytekit/core/base_sql_task.py | 2 -- flytekit/core/external_api_task.py | 3 --- .../flytekitplugins/mmcloud/task.py | 2 -- .../flytekitplugins/snowflake/task.py | 1 - 
.../flytekit/unit/core/test_task_metadata.py | 22 ------------------- tests/flytekit/unit/extend/test_agent.py | 2 -- 6 files changed, 32 deletions(-) diff --git a/flytekit/core/base_sql_task.py b/flytekit/core/base_sql_task.py index aecbdb88c5..30b73223a9 100644 --- a/flytekit/core/base_sql_task.py +++ b/flytekit/core/base_sql_task.py @@ -27,7 +27,6 @@ def __init__( inputs: Optional[Dict[str, Tuple[Type, Any]]] = None, metadata: Optional[TaskMetadata] = None, outputs: Optional[Dict[str, Type]] = None, - is_sync_plugin: Optional[bool] = None, **kwargs, ): """ @@ -40,7 +39,6 @@ def __init__( interface=Interface(inputs=inputs or {}, outputs=outputs or {}), metadata=metadata, task_config=task_config, - is_sync_plugin=is_sync_plugin, **kwargs, ) self._query_template = re.sub(r"\s+", " ", query_template.replace("\n", " ").replace("\t", " ")).strip() diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 87da2a694e..9bf3a99a67 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -26,8 +26,6 @@ class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): designed to be run by the flyte agent. 
""" - is_sync = True - def __init__( self, name: str, @@ -50,7 +48,6 @@ def __init__( name=name, task_config=config, interface=Interface(inputs=inputs, outputs=outputs), - is_sync_plugin=self.is_sync, **kwargs, ) diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py index b5a175be6a..3a61d590d7 100644 --- a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py @@ -24,7 +24,6 @@ class MMCloudConfig(object): class MMCloudTask(PythonFunctionTask): - is_sync = False _TASK_TYPE = "mmcloud_task" def __init__( @@ -41,7 +40,6 @@ def __init__( task_type=self._TASK_TYPE, task_function=task_function, container_image=container_image, - is_sync_plugin=self.is_sync, **kwargs, ) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index e12c4e6d90..f6861e46db 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -41,7 +41,6 @@ class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): """ # This task is executed using the snowflake handler in the backend. 
- is_sync = False _TASK_TYPE = "snowflake" def __init__( diff --git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index d551c76192..d4edae8752 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -33,7 +33,6 @@ def test_retry_strategy(): def test_to_task_metadata_model(): - # Test the value of is_sync_plugin is True tm = TaskMetadata( cache=True, cache_serialize=True, @@ -43,7 +42,6 @@ def test_to_task_metadata_model(): retries=3, timeout=3600, pod_template_name="TEST POD TEMPLATE NAME", - plugin_metadata=PluginMetadata(is_sync_plugin=True), ) model = tm.to_taskmetadata_model() @@ -52,7 +50,6 @@ def test_to_task_metadata_model(): _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", - plugin_metadata=PluginMetadata(is_sync_plugin=True), ) assert model.retries == _literal_models.RetryStrategy(3) assert model.timeout == datetime.timedelta(seconds=3600) @@ -62,22 +59,3 @@ def test_to_task_metadata_model(): assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" - # Test the value of is_sync_plugin is False - tm = TaskMetadata(plugin_metadata=PluginMetadata(is_sync_plugin=False)) - model = tm.to_taskmetadata_model() - assert model.runtime == _task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "python", - plugin_metadata=PluginMetadata(is_sync_plugin=False), - ) - - # Test the default value of is_sync_plugin is None - tm = TaskMetadata() - model = tm.to_taskmetadata_model() - assert model.runtime == _task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - "python", - None, - ) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index dc11ed62c1..6236c9bfe7 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ 
-136,7 +136,6 @@ class AsyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="async_dummy", - is_sync_plugin=False, **kwargs, ) @@ -147,7 +146,6 @@ class SyncDummyTask(AsyncAgentExecutorMixin, PythonFunctionTask): def __init__(self, **kwargs): super().__init__( task_type="sync_dummy", - is_sync_plugin=True, **kwargs, ) From 5a0348b983ad01b5728ff291c534ee7c4262e333 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 17:29:28 +0800 Subject: [PATCH 59/64] remove is_sync Signed-off-by: Future Outlier --- flytekit/core/python_auto_container.py | 2 -- flytekit/core/python_function_task.py | 2 -- flytekit/sensor/base_sensor.py | 3 --- plugins/flytekit-airflow/flytekitplugins/airflow/task.py | 2 -- plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py | 2 -- plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py | 2 -- .../flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py | 1 - plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py | 1 - plugins/flytekit-spark/flytekitplugins/spark/task.py | 2 -- tests/flytekit/unit/core/test_task_metadata.py | 2 -- tests/flytekit/unit/extend/test_agent.py | 2 +- 11 files changed, 1 insertion(+), 20 deletions(-) diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index bad94beead..2f9d8417fd 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -48,7 +48,6 @@ def __init__( pod_template: Optional[PodTemplate] = None, pod_template_name: Optional[str] = None, accelerator: Optional[BaseAccelerator] = None, - is_sync_plugin: Optional[bool] = None, **kwargs, ): """ @@ -93,7 +92,6 @@ def __init__( name=name, task_config=task_config, security_ctx=sec_ctx, - is_sync_plugin=is_sync_plugin, **kwargs, ) self._container_image = container_image diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 8edcbad578..e1e80a4227 
100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -102,7 +102,6 @@ def __init__( ignore_input_vars: Optional[List[str]] = None, execution_mode: ExecutionBehavior = ExecutionBehavior.DEFAULT, task_resolver: Optional[TaskResolverMixin] = None, - is_sync_plugin: Optional[bool] = None, **kwargs, ): """ @@ -125,7 +124,6 @@ def __init__( interface=mutated_interface, task_config=task_config, task_resolver=task_resolver, - is_sync_plugin=is_sync_plugin, **kwargs, ) diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py index 7120b4f09e..fed5f6493b 100644 --- a/flytekit/sensor/base_sensor.py +++ b/flytekit/sensor/base_sensor.py @@ -26,8 +26,6 @@ class BaseSensor(AsyncAgentExecutorMixin, PythonTask): sensor agent, and not by the Flyte engine. """ - is_sync = False - def __init__( self, name: str, @@ -47,7 +45,6 @@ def __init__( name=name, task_config=None, interface=Interface(inputs=inputs), - is_sync_plugin=self.is_sync, **kwargs, ) self._sensor_config = sensor_config diff --git a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py index ef60c6ecc3..a9bed9a580 100644 --- a/plugins/flytekit-airflow/flytekitplugins/airflow/task.py +++ b/plugins/flytekit-airflow/flytekitplugins/airflow/task.py @@ -112,7 +112,6 @@ class AirflowTask(AsyncAgentExecutorMixin, PythonTask[AirflowObj]): The airflow task module, name and parameters are stored in the task config. We run the Airflow task in the agent. 
""" - is_sync = False _TASK_TYPE = "airflow" def __init__( @@ -127,7 +126,6 @@ def __init__( task_config=task_config, interface=Interface(inputs=inputs or {}), task_type=self._TASK_TYPE, - is_sync_plugin=self.is_sync, **kwargs, ) diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 789151d7be..d5568ab041 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -35,7 +35,6 @@ class AWSBatchFunctionTask(PythonFunctionTask): Actual Plugin that transforms the local python code for execution within AWS batch job """ - is_sync = False _AWS_BATCH_TASK_TYPE = "aws-batch" def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwargs): @@ -45,7 +44,6 @@ def __init__(self, task_config: AWSBatchConfig, task_function: Callable, **kwarg task_config=task_config, task_type=self._AWS_BATCH_TASK_TYPE, task_function=task_function, - is_sync_plugin=self.is_sync, **kwargs, ) self._task_config = task_config diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index d20f04b993..bcc707da5a 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -30,7 +30,6 @@ class BigQueryTask(AsyncAgentExecutorMixin, SQLTask[BigQueryConfig]): # This task is executed using the BigQuery handler in the backend. 
# https://github.com/flyteorg/flyteplugins/blob/43623826fb189fa64dc4cb53e7025b517d911f22/go/tasks/plugins/webapi/bigquery/plugin.go#L34 - is_sync = False _TASK_TYPE = "bigquery_query_job_task" def __init__( @@ -64,7 +63,6 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - is_sync_plugin=self.is_sync, **kwargs, ) self._output_structured_dataset_type = output_structured_dataset_type diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 018068bbd6..88bf0100ac 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,4 +1,3 @@ -import asyncio from typing import Any, Dict import openai diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index f6861e46db..9ac9980a88 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -77,7 +77,6 @@ def __init__( inputs=inputs, outputs=outputs, task_type=self._TASK_TYPE, - is_sync_plugin=self.is_sync, **kwargs, ) self._output_schema_type = output_schema_type diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 4dbf030bdf..6c692fb726 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -104,7 +104,6 @@ class PysparkFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[Spark]): Actual Plugin that transforms the local python code for execution within a spark context """ - is_sync = False _SPARK_TASK_TYPE = "spark" def __init__( @@ -132,7 +131,6 @@ def __init__( task_type=self._SPARK_TASK_TYPE, task_function=task_function, container_image=container_image, - is_sync_plugin=self.is_sync, **kwargs, ) diff 
--git a/tests/flytekit/unit/core/test_task_metadata.py b/tests/flytekit/unit/core/test_task_metadata.py index d4edae8752..a158a3ac31 100644 --- a/tests/flytekit/unit/core/test_task_metadata.py +++ b/tests/flytekit/unit/core/test_task_metadata.py @@ -1,7 +1,6 @@ import datetime import pytest -from flyteidl.core.tasks_pb2 import PluginMetadata from flytekit import __version__ from flytekit.core.base_task import TaskMetadata @@ -58,4 +57,3 @@ def test_to_task_metadata_model(): assert model.deprecated_error_message == "TEST DEPRECATED ERROR MESSAGE" assert model.cache_serializable is True assert model.pod_template_name == "TEST POD TEMPLATE NAME" - diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 6236c9bfe7..75719d619d 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -28,8 +28,8 @@ from flytekit.extend.backend.agent_service import AsyncAgentService, SyncAgentService from flytekit.extend.backend.base_agent import ( AgentBase, - AsyncAgentExecutorMixin, AgentRegistry, + AsyncAgentExecutorMixin, convert_to_flyte_state, get_agent_secret, is_terminal_state, From ea91dc7b8c4bf76f3868e7ceb8a733ee108ff2de Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 17:56:16 +0800 Subject: [PATCH 60/64] fix pickle error Signed-off-by: Future Outlier --- flytekit/core/external_api_task.py | 9 +++------ flytekit/extend/backend/task_executor.py | 5 ++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index 9bf3a99a67..baab291fdf 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -3,7 +3,6 @@ from abc import abstractmethod from typing import Any, Dict, Optional, TypeVar -import jsonpickle from flyteidl.admin.agent_pb2 import CreateTaskResponse from typing_extensions import get_type_hints @@ -15,7 +14,7 @@ T = TypeVar("T") TASK_MODULE = 
"task_module" TASK_NAME = "task_name" -TASK_CONFIG_PKL = "task_config_pkl" +TASK_CONFIG = "task_config" TASK_TYPE = "api_task" @@ -29,7 +28,7 @@ class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): def __init__( self, name: str, - config: Optional[T] = None, + config: Optional[Dict[str, Any]] = None, task_type: str = TASK_TYPE, return_type: Optional[Any] = None, **kwargs, @@ -64,9 +63,7 @@ def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { TASK_MODULE: type(self).__module__, TASK_NAME: type(self).__name__, + TASK_CONFIG: self._config, } - if self._config is not None: - cfg[TASK_CONFIG_PKL] = jsonpickle.encode(self._config) - return cfg diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index 4e67e67a4c..d8867aa015 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -2,11 +2,10 @@ import typing import grpc -import jsonpickle from flyteidl.admin.agent_pb2 import CreateTaskResponse from flytekit import FlyteContextManager -from flytekit.core.external_api_task import TASK_CONFIG_PKL, TASK_MODULE, TASK_NAME, TASK_TYPE +from flytekit.core.external_api_task import TASK_CONFIG, TASK_MODULE, TASK_NAME, TASK_TYPE from flytekit.core.type_engine import TypeEngine from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry from flytekit.models.literals import LiteralMap @@ -47,7 +46,7 @@ async def async_create( task_module = importlib.import_module(name=meta[TASK_MODULE]) task_def = getattr(task_module, meta[TASK_NAME]) - config = jsonpickle.decode(meta[TASK_CONFIG_PKL]) if meta.get(TASK_CONFIG_PKL) else None + config = meta[TASK_CONFIG] if meta.get(TASK_CONFIG) else None return await task_def(TASK_TYPE, config=config).do(**native_inputs) From 7786acaba48e1f536b9d3f1f62f340f5a1f4a341 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 18:00:58 +0800 Subject: [PATCH 61/64] remove new line Signed-off-by: Future 
Outlier --- flytekit/extend/backend/agent_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index aecc000a82..7b17e6e2a9 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -96,7 +96,6 @@ async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerCon tmp = TaskTemplate.from_flyte_idl(request.template) inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None agent = AgentRegistry.get_agent(tmp.type) - print("@@@ we are using agent server") logger.info(f"{tmp.type} agent start creating the job") if agent.asynchronous: return await agent.async_create( From ca34c6113088b204f3be2e811d07e40451016d53 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Tue, 19 Dec 2023 23:45:49 +0800 Subject: [PATCH 62/64] rename task config Signed-off-by: Future Outlier --- flytekit/core/external_api_task.py | 8 +++++--- .../flytekitplugins/chatgpt/task.py | 6 +++--- tests/flytekit/unit/core/test_external_api_task.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/flytekit/core/external_api_task.py b/flytekit/core/external_api_task.py index baab291fdf..3d8942f346 100644 --- a/flytekit/core/external_api_task.py +++ b/flytekit/core/external_api_task.py @@ -28,7 +28,7 @@ class ExternalApiTask(AsyncAgentExecutorMixin, PythonTask): def __init__( self, name: str, - config: Optional[Dict[str, Any]] = None, + config: Optional[T] = None, task_type: str = TASK_TYPE, return_type: Optional[Any] = None, **kwargs, @@ -50,7 +50,7 @@ def __init__( **kwargs, ) - self._config = config + self._task_config = config @abstractmethod async def do(self, **kwargs) -> CreateTaskResponse: @@ -63,7 +63,9 @@ def get_custom(self, settings: SerializationSettings = None) -> Dict[str, Any]: cfg = { TASK_MODULE: type(self).__module__, TASK_NAME: type(self).__name__, - TASK_CONFIG: self._config, } + if 
self._task_config is not None: + cfg[TASK_CONFIG] = self._task_config + return cfg diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 88bf0100ac..7b5c128cf8 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,5 +1,5 @@ from typing import Any, Dict - +import asyncio import openai from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource @@ -48,8 +48,8 @@ async def do( self._chatgpt_conf["messages"] = [{"role": "user", "content": message}] - # completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) - # message = completion.choices[0].message.content + completion = await asyncio.wait_for(openai.ChatCompletion.acreate(**self._chatgpt_conf), TIMEOUT_SECONDS) + message = completion.choices[0].message.content ctx = FlyteContextManager.current_context() outputs = LiteralMap( diff --git a/tests/flytekit/unit/core/test_external_api_task.py b/tests/flytekit/unit/core/test_external_api_task.py index dde494c822..cbe0ffa181 100644 --- a/tests/flytekit/unit/core/test_external_api_task.py +++ b/tests/flytekit/unit/core/test_external_api_task.py @@ -3,7 +3,7 @@ import pytest -from flytekit.core.external_api_task import TASK_CONFIG_PKL, TASK_MODULE, TASK_NAME, ExternalApiTask +from flytekit.core.external_api_task import TASK_CONFIG, TASK_MODULE, TASK_NAME, ExternalApiTask from flytekit.core.interface import Interface, transform_interface_to_typed_interface @@ -37,4 +37,4 @@ def test_get_custom(): expected_config = json.loads('{"key": "value"}') assert custom[TASK_MODULE] == MockExternalApiTask.__module__ assert custom[TASK_NAME] == MockExternalApiTask.__name__ - assert json.loads(custom[TASK_CONFIG_PKL]) == expected_config + assert json.loads(custom[TASK_CONFIG]) == expected_config From 
1d9535d157f541982029817a0799fa7319454918 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 20 Dec 2023 10:06:28 +0800 Subject: [PATCH 63/64] use SyncAgentBase Signed-off-by: Future Outlier --- flytekit/__init__.py | 2 +- flytekit/extend/backend/task_executor.py | 25 ++++++++++++++++--- .../flytekitplugins/chatgpt/task.py | 3 ++- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/flytekit/__init__.py b/flytekit/__init__.py index bcb1d2c54e..5f26b27664 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -246,7 +246,7 @@ StructuredDatasetType, ) -from flytekit.extend.backend.task_executor import TaskExecutor # isort:skip. This is for circular import avoidance. +from flytekit.extend.backend.task_executor import SyncAgentBase # isort:skip. This is for circular import avoidance. def current_context() -> ExecutionParameters: diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index d8867aa015..e902d787c8 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -1,5 +1,7 @@ import importlib import typing +from dataclasses import dataclass +from typing import final import grpc from flyteidl.admin.agent_pb2 import CreateTaskResponse @@ -14,9 +16,15 @@ T = typing.TypeVar("T") -class TaskExecutor(AgentBase): +@dataclass +class IOContext: + inputs: LiteralMap + output_prefix: str + + +class SyncAgentBase(AgentBase): """ - TaskExecutor is an agent responsible for executing external API tasks. + SyncAgentBase is an agent responsible for synchronous tasks, which are fast and quick. This class is meant to be subclassed when implementing plugins that require an external API to perform the task execution.
It provides a routing mechanism @@ -26,12 +34,23 @@ class TaskExecutor(AgentBase): def __init__(self): super().__init__(task_type=TASK_TYPE, asynchronous=True) + @final async def async_create( self, context: grpc.ServicerContext, output_prefix: str, task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + print("@@@ output_prefix:", output_prefix) + return await self.do(context, output_prefix, task_template, inputs) + + async def do( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, ) -> CreateTaskResponse: python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() @@ -50,4 +69,4 @@ async def async_create( return await task_def(TASK_TYPE, config=config).do(**native_inputs) -AgentRegistry.register(TaskExecutor()) +AgentRegistry.register(SyncAgentBase()) diff --git a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py index 7b5c128cf8..5b606bb824 100644 --- a/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py +++ b/plugins/flytekit-openai-chatgpt/flytekitplugins/chatgpt/task.py @@ -1,5 +1,6 @@ -from typing import Any, Dict import asyncio +from typing import Any, Dict + import openai from flyteidl.admin.agent_pb2 import SUCCEEDED, CreateTaskResponse, Resource From 68fd52763159130d64e193e7a6ac4a29a3602322 Mon Sep 17 00:00:00 2001 From: Future Outlier Date: Wed, 20 Dec 2023 10:11:33 +0800 Subject: [PATCH 64/64] remove print Signed-off-by: Future Outlier --- flytekit/extend/backend/task_executor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/extend/backend/task_executor.py b/flytekit/extend/backend/task_executor.py index e902d787c8..22b6675cca 100644 --- a/flytekit/extend/backend/task_executor.py +++ b/flytekit/extend/backend/task_executor.py @@ -42,7 +42,6 @@ async 
def async_create( task_template: TaskTemplate, inputs: typing.Optional[LiteralMap] = None, ) -> CreateTaskResponse: - print("@@@ output_prefix:", output_prefix) return await self.do(context, output_prefix, task_template, inputs) async def do(