From e3ca0973c08bbc9abc605311d02252afe93f7fd0 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 4 Jun 2024 11:41:15 +0800 Subject: [PATCH 1/7] change interface order of output_prefix in do() method and fix chatgpt agent error Signed-off-by: Future-Outlier --- flytekit/extend/backend/base_agent.py | 8 ++++--- .../flytekitplugins/openai/chatgpt/agent.py | 1 + .../tests/chatgpt/test_chatgpt.py | 22 +++++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 33a03e282b..2ab37e4a08 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -119,7 +119,7 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: str, **kwargs) -> Resource: + def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs) -> Resource: """ This is the method that the agent will run. """ @@ -247,7 +247,9 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: agent = AgentRegistry.get_agent(task_template.type, task_template.task_type_version) - resource = asyncio.run(self._do(agent, task_template, output_prefix, kwargs)) + resource = asyncio.run( + self._do(agent=agent, template=task_template, output_prefix=output_prefix, inputs=kwargs) + ) if resource.phase != TaskExecution.SUCCEEDED: raise FlyteUserException(f"Failed to run the task {self.name} with error: {resource.message}") @@ -259,8 +261,8 @@ async def _do( self: PythonTask, agent: SyncAgentBase, template: TaskTemplate, - output_prefix: str, inputs: Dict[str, Any] = None, + output_prefix: Optional[str] = None, ) -> Resource: try: ctx = FlyteContext.current_context() diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index afd3af1321..1bf0f6a485 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -27,6 +27,7 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, + output_prefix: Optional[str] = None, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) diff --git a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py index 6298bdf52c..12de3da23b 100644 --- a/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py +++ b/plugins/flytekit-openai/tests/chatgpt/test_chatgpt.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from unittest import mock from flytekitplugins.openai import ChatGPTTask @@ -7,6 +8,14 @@ from flytekit.models.types import SimpleType +async def mock_acreate(*args, **kwargs) -> str: + mock_response = mock.MagicMock() + mock_choice = mock.MagicMock() + mock_choice.message.content = "mocked_message" + mock_response.choices = [mock_choice] + return mock_response + + def test_chatgpt_task(): chatgpt_task = ChatGPTTask( name="chatgpt", @@ -40,3 +49,16 @@ def test_chatgpt_task(): assert chatgpt_task_spec.template.interface.inputs["message"].type.simple == SimpleType.STRING assert chatgpt_task_spec.template.interface.outputs["o0"].type.simple == SimpleType.STRING + + with mock.patch("openai.resources.chat.completions.AsyncCompletions.create", new=mock_acreate): + chatgpt_task = ChatGPTTask( + name="chatgpt", + openai_organization="TEST ORGANIZATION ID", + chatgpt_config={ + "model": "gpt-3.5-turbo", + "temperature": 0.7, + }, + ) + + response = chatgpt_task(message="hi") + assert response == "mocked_message" From 5bab3d344a2e707d972b2d8cbcab540524586f4b Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 4 Jun 2024 11:50:31 +0800 Subject: [PATCH 2/7] Use kwargs instead of output_prefix Signed-off-by: Future-Outlier Co-authored-by: pingsutw --- plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py index 1bf0f6a485..e4f24baa5a 100644 --- a/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py +++ b/plugins/flytekit-openai/flytekitplugins/openai/chatgpt/agent.py @@ -27,7 +27,7 @@ async def do( self, task_template: TaskTemplate, inputs: Optional[LiteralMap] = None, - output_prefix: Optional[str] = None, + **kwargs, ) -> Resource: ctx = FlyteContextManager.current_context() input_python_value = TypeEngine.literal_map_to_kwargs(ctx, inputs, {"message": str}) From 7b33296bf72237a59ecb97ecc1b0f720cb9862c1 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 4 Jun 2024 11:53:42 +0800 Subject: [PATCH 3/7] lint Signed-off-by: Future-Outlier --- flytekit/extend/backend/base_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 2ab37e4a08..fda71020a6 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -119,7 +119,9 @@ class SyncAgentBase(AgentBase): name = "Base Sync Agent" @abstractmethod - def do(self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs) -> Resource: + def do( + self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs + ) -> Resource: """ This is the method that the agent will run. """ From ff3c3a67da65c4d30f0082118fbbaa74f093f2bd Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 4 Jun 2024 12:55:52 +0800 Subject: [PATCH 4/7] Fix CI Signed-off-by: Future-Outlier --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 126f05050a..546d034530 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "flyteidl>=1.11.0b1", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", - "googleapis-common-protos>=1.57", + "googleapis-common-protos>=1.57,!=1.63.1", "grpcio", "grpcio-status", "importlib-metadata", From 4d77eef499638ba65a928cd00c349b78b22913de Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Tue, 4 Jun 2024 14:56:38 +0800 Subject: [PATCH 5/7] fix hugging face plugin Signed-off-by: Future-Outlier --- plugins/flytekit-huggingface/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py index 9c1debaba0..0b1da4bd15 100644 --- a/plugins/flytekit-huggingface/setup.py +++ b/plugins/flytekit-huggingface/setup.py @@ -6,7 +6,7 @@ plugin_requires = [ "flytekit>=1.3.0b2,<2.0.0", - "datasets>=2.4.0", + "datasets>=2.4.0,<2.19.2", ] __version__ = "0.0.0+develop" From 40406d44210e3e1434e26b823484be93f64132ed Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 06:41:37 +0800 Subject: [PATCH 6/7] update by pingsu's advice Signed-off-by: Future-Outlier --- flytekit/extend/backend/base_agent.py | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index fda71020a6..e8ec18806e 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -120,7 +120,7 @@ class SyncAgentBase(AgentBase): @abstractmethod def do( - self, task_template: TaskTemplate, inputs: Optional[LiteralMap], output_prefix: Optional[str], **kwargs + self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs ) -> Resource: """ This is the method that the agent will run. @@ -263,8 +263,8 @@ async def _do( self: PythonTask, agent: SyncAgentBase, template: TaskTemplate, + output_prefix: str, inputs: Dict[str, Any] = None, - output_prefix: Optional[str] = None, ) -> Resource: try: ctx = FlyteContext.current_context() diff --git a/pyproject.toml b/pyproject.toml index 546d034530..126f05050a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ dependencies = [ "flyteidl>=1.11.0b1", "fsspec>=2023.3.0", "gcsfs>=2023.3.0", - "googleapis-common-protos>=1.57,!=1.63.1", + "googleapis-common-protos>=1.57", "grpcio", "grpcio-status", "importlib-metadata", From ec1f1853b84132936433b362c10dd80b8ff670f2 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Wed, 5 Jun 2024 06:42:44 +0800 Subject: [PATCH 7/7] revert hugging face dataset's change Signed-off-by: Future-Outlier --- plugins/flytekit-huggingface/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-huggingface/setup.py b/plugins/flytekit-huggingface/setup.py index 0b1da4bd15..9c1debaba0 100644 --- a/plugins/flytekit-huggingface/setup.py +++ b/plugins/flytekit-huggingface/setup.py @@ -6,7 +6,7 @@ plugin_requires = [ "flytekit>=1.3.0b2,<2.0.0", - "datasets>=2.4.0,<2.19.2", + "datasets>=2.4.0", ] __version__ = "0.0.0+develop"