From 3aba660417ece9437be2db886f0f5ece84231a82 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 17 Oct 2023 12:28:29 +0200 Subject: [PATCH 01/51] Add context length info. Refactor BuiltinTask and models to facilitate this. --- spacy_llm/models/hf/base.py | 7 ++++ spacy_llm/models/hf/dolly.py | 4 ++ spacy_llm/models/hf/falcon.py | 4 ++ spacy_llm/models/hf/llama2.py | 4 ++ spacy_llm/models/hf/mistral.py | 4 ++ spacy_llm/models/hf/openllama.py | 4 ++ spacy_llm/models/hf/stablelm.py | 4 ++ spacy_llm/models/rest/anthropic/model.py | 34 ++++++++--------- spacy_llm/models/rest/azure/model.py | 34 +++++++++++------ spacy_llm/models/rest/azure/registry.py | 11 +++++- spacy_llm/models/rest/base.py | 16 +++++++- spacy_llm/models/rest/cohere/model.py | 9 +++++ spacy_llm/models/rest/noop/model.py | 9 +++-- spacy_llm/models/rest/openai/model.py | 48 ++++++++++++------------ spacy_llm/models/rest/palm/model.py | 11 ++++-- spacy_llm/tasks/builtin_task.py | 3 +- 16 files changed, 142 insertions(+), 64 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 71fdc074..72ce7446 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -59,6 +59,13 @@ def get_model_names(cls) -> Tuple[str, ...]: """ return tuple(str(arg) for arg in cls.MODEL_NAMES.__args__) # type: ignore[attr-defined] + @property + @abc.abstractmethod + def context_length(self) -> int: + """Returns context length in number of tokens for this model. + RETURNS (int): Max. number of tokens in allowed in prompt for the current model. + """ + @property @abc.abstractmethod def hf_account(self) -> str: diff --git a/spacy_llm/models/hf/dolly.py b/spacy_llm/models/hf/dolly.py index 849f34bd..a1d658fa 100644 --- a/spacy_llm/models/hf/dolly.py +++ b/spacy_llm/models/hf/dolly.py @@ -46,6 +46,10 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run, ) + @property + def context_length(self) -> int: + return 2048 + @registry.llm_models("spacy.Dolly.v1") def dolly_hf( diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py index 76d4e9e2..64c33ea5 100644 --- a/spacy_llm/models/hf/falcon.py +++ b/spacy_llm/models/hf/falcon.py @@ -63,6 +63,10 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run, ) + @property + def context_length(self) -> int: + return 2048 + @registry.llm_models("spacy.Falcon.v1") def falcon_hf( diff --git a/spacy_llm/models/hf/llama2.py b/spacy_llm/models/hf/llama2.py index f03d00ee..8fc8eef6 100644 --- a/spacy_llm/models/hf/llama2.py +++ b/spacy_llm/models/hf/llama2.py @@ -49,6 +49,10 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[ove def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: return HuggingFace.compile_default_configs() + @property + def context_length(self) -> int: + return 4096 + @registry.llm_models("spacy.Llama2.v1") def llama2_hf( diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index 6fe78c78..cd8bbd9b 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -82,6 +82,10 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run, ) + @property + def context_length(self) -> int: + return 8000 + @registry.llm_models("spacy.Mistral.v1") def mistral_hf( diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index 4cf2f4cf..f7e1ff83 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -76,6 +76,10 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: {**default_cfg_run, "max_new_tokens": 32}, ) + @property + def context_length(self) -> int: + return 2048 + @registry.llm_models("spacy.OpenLLaMA.v1") def openllama_hf( diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py index 4711d69f..028e81e7 100644 --- a/spacy_llm/models/hf/stablelm.py +++ b/spacy_llm/models/hf/stablelm.py @@ -107,6 +107,10 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: }, ) + @property + def context_length(self) -> int: + return 4096 + @registry.llm_models("spacy.StableLM.v1") def stablelm_hf( diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index 602ba14b..efc7106b 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -108,25 +108,25 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: assert len(api_responses) == len(prompts) return api_responses - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ( + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { # claude-2 - "claude-2", - "claude-2-100k", + "claude-2": 100000, + "claude-2-100k": 100000, # claude-1 - "claude-1", - "claude-1-100k", + "claude-1": 100000, + "claude-1-100k": 100000, # claude-instant-1 - "claude-instant-1", - "claude-instant-1-100k", + "claude-instant-1": 100000, + "claude-instant-1-100k": 100000, # claude-instant-1.1 - "claude-instant-1.1", - "claude-instant-1.1-100k", + "claude-instant-1.1": 100000, + "claude-instant-1.1-100k": 100000, # claude-1.3 - "claude-1.3", - "claude-1.3-100k", + "claude-1.3": 100000, + "claude-1.3-100k": 100000, # others - "claude-1.0", - "claude-1.2", - ) + "claude-1.0": 100000, + "claude-1.2": 100000, + } diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index a30173ff..96617f68 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -18,6 +18,7 @@ class ModelType(str, Enum): class AzureOpenAI(REST): def __init__( self, + deployment_name: str, name: str, endpoint: str, config: Dict[Any, Any], @@ -30,6 +31,7 @@ def __init__( ): self._model_type = model_type self._api_version = api_version + self._deployment_name = deployment_name super().__init__( name=name, endpoint=endpoint, @@ -48,7 +50,7 @@ def endpoint(self) -> str: return ( self._endpoint + ("" if self._endpoint.endswith("/") else "/") - + f"openai/deployments/{self._name}/{self._model_type.value}" + + f"openai/deployments/{self._deployment_name}/{self._model_type.value}" ) @property @@ -102,7 +104,6 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: ) from ex responses = r.json() - # todo check if this is the same if "error" in responses: if self._strict: raise ValueError(f"API call failed: {responses}.") @@ -147,11 +148,22 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: return api_responses - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - # We treat the deployment name as "model name", hence it can be arbitrary. - return ("",) - - def _check_model(self) -> None: - # We treat the deployment name as "model name", hence it can be arbitrary. - pass + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + # gpt-4 + "gpt-4": 8192, + "gpt-4-32k": 32768, + # gpt-3.5 + "gpt-3.5-turbo": 4097, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-instruct": 4097, + # text-davinci + "text-davinci-002": 4097, + "text-davinci-003": 4097, + # others + "code-davinci-002": 8001, + "text-curie-001": 2049, + "text-babbage-001": 2049, + "text-ada-001": 2049, + } diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index 9d88e466..3493e1b4 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -10,6 +10,7 @@ @registry.llm_models("spacy.Azure.v1") def azure_openai( + deployment_name: str, name: str, base_url: str, model_type: ModelType, @@ -22,9 +23,14 @@ def azure_openai( ) -> Callable[[Iterable[str]], Iterable[str]]: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. + Docs on OpenAI models supported by Azure: + https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#model-summary-table-and-region-availability. + config (Dict[Any, Any]): LLM config passed on to the model's initialization. - name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the model used by - that deployment, as deployment names in Azure OpenAI can be arbitrary. + deployment_name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the + model used by that deployment, as deployment names in Azure OpenAI can be arbitrary. + name (str): Name of the model used by this deployment. This is required to infer the context length that can be + assumed for prompting. endpoint (str): The URL for your Azure OpenAI endpoint. This is usually something like "https://{prefix}.openai.azure.com/". model_type (ModelType): Whether the deployed model is a text completetion model (e. g. @@ -43,6 +49,7 @@ def azure_openai( DOCS: https://spacy.io/api/large-language-models#models """ return AzureOpenAI( + deployment_name=deployment_name, name=name, endpoint=base_url, config=config, diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index b7dccca3..e89d3928 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -79,11 +79,25 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: """ @classmethod - @abc.abstractmethod def get_model_names(cls) -> Tuple[str, ...]: """Names of supported models. RETURNS (Tuple[str]): Names of supported models. """ + return tuple(cls._get_context_lengths().keys()) + + @staticmethod + @abc.abstractmethod + def _get_context_lengths() -> Dict[str, int]: + """Get context lengths per model name. + RETURNS (Dict[str, int]): Dict with model name -> context length. + """ + + @property + def context_length(self) -> int: + """Returns context length in number of tokens for this model. + RETURNS (int): Max. number of tokens in allowed in prompt for the current model. + """ + return self._get_context_lengths()[self._name] @property @abc.abstractmethod diff --git a/spacy_llm/models/rest/cohere/model.py b/spacy_llm/models/rest/cohere/model.py index 293ed92b..2973945c 100644 --- a/spacy_llm/models/rest/cohere/model.py +++ b/spacy_llm/models/rest/cohere/model.py @@ -115,3 +115,12 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: @classmethod def get_model_names(cls) -> Tuple[str, ...]: return "command", "command-light", "command-light-nightly", "command-nightly" + + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "command": 4096, + "command-light": 4096, + "command-light-nightly": 4096, + "command-nightly": 4096, + } diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index 0e3e0398..cdb46170 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -1,5 +1,6 @@ +import sys import time -from typing import Dict, Iterable, Tuple +from typing import Dict, Iterable from ..base import REST @@ -34,6 +35,6 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: time.sleep(NoOpModel._CALL_TIMEOUT) return [_NOOP_RESPONSE] * len(list(prompts)) - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ("NoOp",) + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return {"NoOp": sys.maxsize} diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index a712e082..032ca462 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -141,30 +141,30 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: return api_responses - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ( + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { # gpt-4 - "gpt-4", - "gpt-4-0314", - "gpt-4-32k", - "gpt-4-32k-0314", + "gpt-4": 8192, + "gpt-4-0314": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0314": 32768, # gpt-3.5 - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-0613-16k", - "gpt-3.5-turbo-instruct", + "gpt-3.5-turbo": 4097, + "gpt-3.5-turbo-16k": 16385, + "gpt-3.5-turbo-0613": 4097, + "gpt-3.5-turbo-0613-16k": 16385, + "gpt-3.5-turbo-instruct": 4097, # text-davinci - "text-davinci-002", - "text-davinci-003", + "text-davinci-002": 4097, + "text-davinci-003": 4097, # others - "code-davinci-002", - "text-curie-001", - "text-babbage-001", - "text-ada-001", - "davinci", - "curie", - "babbage", - "ada", - ) + "code-davinci-002": 8001, + "text-curie-001": 2049, + "text-babbage-001": 2049, + "text-ada-001": 2049, + "davinci": 2049, + "curie": 2049, + "babbage": 2049, + "ada": 2049, + } diff --git a/spacy_llm/models/rest/palm/model.py b/spacy_llm/models/rest/palm/model.py index 1e9b10b1..d67ec3a5 100644 --- a/spacy_llm/models/rest/palm/model.py +++ b/spacy_llm/models/rest/palm/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -108,6 +108,9 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: return api_responses - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return "text-bison-001", "chat-bison-001" + @staticmethod + def _get_context_lengths() -> Dict[str, int]: + return { + "text-bison-001": 8192, + "chat-bison-001": 8192, + } diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index fa565a97..152f518b 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -4,7 +4,7 @@ import jinja2 import srsly -from spacy import Language, util, Errors +from spacy import Errors, Language, util from spacy.tokens import Doc from spacy.training import Example @@ -52,6 +52,7 @@ def generate_prompts(self, docs: Iterable[Doc], **kwargs) -> Iterable[Any]: """ environment = jinja2.Environment() _template = environment.from_string(self._template) + for doc in docs: prompt = _template.render( text=doc.text, prompt_examples=self._prompt_examples, **kwargs From 42133727a5bffb8b09bf50da94f2207f3f3568d7 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 17 Oct 2023 15:50:54 +0200 Subject: [PATCH 02/51] Add token count estimator plumbing. --- spacy_llm/tasks/builtin_task.py | 9 +++++++-- spacy_llm/tasks/lemma/registry.py | 7 ++++++- spacy_llm/tasks/lemma/task.py | 5 ++++- spacy_llm/tasks/ner/registry.py | 9 ++++++++- spacy_llm/tasks/ner/task.py | 5 ++++- spacy_llm/tasks/rel/registry.py | 7 ++++++- spacy_llm/tasks/rel/task.py | 5 ++++- spacy_llm/tasks/sentiment/registry.py | 7 ++++++- spacy_llm/tasks/sentiment/task.py | 5 ++++- spacy_llm/tasks/span/task.py | 4 +++- spacy_llm/tasks/spancat/registry.py | 9 ++++++++- spacy_llm/tasks/spancat/task.py | 5 ++++- spacy_llm/tasks/summarization/registry.py | 7 ++++++- spacy_llm/tasks/summarization/task.py | 5 ++++- spacy_llm/tasks/textcat/registry.py | 9 ++++++++- spacy_llm/tasks/textcat/task.py | 5 ++++- spacy_llm/tasks/util/tokenization.py | 15 +++++++++++++++ spacy_llm/tests/models/test_rest.py | 3 ++- spacy_llm/ty.py | 1 + 19 files changed, 104 insertions(+), 18 deletions(-) create mode 100644 spacy_llm/tasks/util/tokenization.py diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 3c690798..0299e7df 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -10,7 +10,7 @@ from ..compat import Self from ..registry import lowercase_normalizer -from ..ty import FewshotExample, TaskResponseParser +from ..ty import FewshotExample, NTokenEstimator, TaskResponseParser class BuiltinTask(abc.ABC): @@ -34,17 +34,20 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, ): """Initializes task. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. """ self._parse_responses = parse_responses self._prompt_examples = prompt_examples or [] self._template = template self._prompt_example_type = prompt_example_type + self._n_token_estimator = n_token_estimator def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: """Generate prompts from docs. @@ -162,7 +165,6 @@ def from_bytes( exclude (Tuple[str]): Names of properties to exclude from deserialization. RETURNS (BuiltinTask): Modified BuiltinTask instance. """ - deserialize = { "cfg": lambda b: self.set_cfg(srsly.json_loads(b)), "prompt_examples": lambda b: self._set_prompt_examples( @@ -245,6 +247,7 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, labels: List[str], label_definitions: Optional[Dict[str, str]], normalizer: Optional[Callable[[str], str]], @@ -255,6 +258,7 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. labels (List[str]): List of labels to pass to the template. Leave empty to (optionally) populate it at initialization time. label_definitions (Optional[Dict[str, str]]): Map of label -> description @@ -268,6 +272,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, ) self._normalizer = normalizer if normalizer else lowercase_normalizer() self._label_dict = { diff --git a/spacy_llm/tasks/lemma/registry.py b/spacy_llm/tasks/lemma/registry.py index e317e280..4695b7c6 100644 --- a/spacy_llm/tasks/lemma/registry.py +++ b/spacy_llm/tasks/lemma/registry.py @@ -1,7 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer +from ...ty import TaskResponseParser +from ..util.tokenization import make_default_n_token_estimator from .parser import parse_responses_v1 from .task import DEFAULT_LEMMA_TEMPLATE_V1, LemmaTask from .util import LemmaExample, score @@ -23,6 +25,7 @@ def make_lemma_task( parse_responses: Optional[TaskResponseParser[LemmaTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, scorer: Optional[Scorer] = None, ): """Lemma.v1 task factory. @@ -32,6 +35,7 @@ def make_lemma_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples @@ -45,5 +49,6 @@ def make_lemma_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=lemma_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), scorer=scorer or score, ) diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index c3bb2083..78c1d426 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -5,7 +5,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +19,7 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], prompt_examples: Optional[List[FewshotExample[Self]]], template: str, + n_token_estimator: NTokenEstimator, scorer: Scorer, ): """Default lemmatization task. @@ -27,6 +28,7 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. scorer (Scorer): Scorer function. """ super().__init__( @@ -34,6 +36,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, ) self._scorer = scorer diff --git a/spacy_llm/tasks/ner/registry.py b/spacy_llm/tasks/ner/registry.py index 55b8e2ce..3c350471 100644 --- a/spacy_llm/tasks/ner/registry.py +++ b/spacy_llm/tasks/ner/registry.py @@ -2,11 +2,13 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency, check_label_consistency_cot +from ..util.tokenization import make_default_n_token_estimator from .task import DEFAULT_NER_TEMPLATE_V1, DEFAULT_NER_TEMPLATE_V2 from .task import DEFAULT_NER_TEMPLATE_V3, NERTask, SpanTask from .util import NERCoTExample, NERExample, score @@ -51,6 +53,7 @@ def make_ner_task( labels=labels_list, template=DEFAULT_NER_TEMPLATE_V1, prompt_examples=span_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -111,6 +114,7 @@ def make_ner_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -129,6 +133,7 @@ def make_ner_task_v3( template: str = DEFAULT_NER_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -150,6 +155,7 @@ def make_ner_task_v3( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -169,6 +175,7 @@ def make_ner_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/ner/task.py b/spacy_llm/tasks/ner/task.py index 7cff6523..703ca1ca 100644 --- a/spacy_llm/tasks/ner/task.py +++ b/spacy_llm/tasks/ner/task.py @@ -6,7 +6,7 @@ from spacy.util import filter_spans from ...compat import Literal, Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -25,6 +25,7 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -40,6 +41,7 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. label_definitions (Optional[Dict[str, str]]): Map of label -> description of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than @@ -59,6 +61,7 @@ def __init__( template=template, parse_responses=parse_responses, prompt_example_type=prompt_example_type, + n_token_estimator=n_token_estimator, label_definitions=label_definitions, prompt_examples=prompt_examples, normalizer=normalizer, diff --git a/spacy_llm/tasks/rel/registry.py b/spacy_llm/tasks/rel/registry.py index 2a3121fb..d076af2a 100644 --- a/spacy_llm/tasks/rel/registry.py +++ b/spacy_llm/tasks/rel/registry.py @@ -1,8 +1,10 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator +from ...ty import TaskResponseParser from ...util import split_labels +from ..util.tokenization import make_default_n_token_estimator from .examples import RELExample from .parser import parse_responses_v1 from .task import DEFAULT_REL_TEMPLATE, RELTask @@ -16,6 +18,7 @@ def make_rel_task( prompt_example_type: Optional[Type[FewshotExample]] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, normalizer: Optional[Callable[[str], str]] = None, verbose: bool = False, ) -> "RELTask": @@ -35,6 +38,7 @@ def make_rel_task( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -50,6 +54,7 @@ def make_rel_task( template=template, label_definitions=label_definitions, prompt_examples=rel_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), normalizer=normalizer, verbose=verbose, ) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index 83accb81..c917cda9 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -5,7 +5,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template from .util import EntityItem, RelationItem @@ -22,6 +22,7 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, normalizer: Optional[Callable[[str], str]], verbose: bool, ): @@ -37,6 +38,7 @@ def __init__( It is usually easier to provide these definitions rather than full examples, although both can be provided. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -45,6 +47,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/sentiment/registry.py b/spacy_llm/tasks/sentiment/registry.py index 6dd51606..fd1a6094 100644 --- a/spacy_llm/tasks/sentiment/registry.py +++ b/spacy_llm/tasks/sentiment/registry.py @@ -1,7 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator +from ...ty import TaskResponseParser +from ..util.tokenization import make_default_n_token_estimator from .parser import parse_responses_v1 from .task import DEFAULT_SENTIMENT_TEMPLATE_V1, SentimentTask from .util import SentimentExample @@ -13,6 +15,7 @@ def make_sentiment_task( parse_responses: Optional[TaskResponseParser[SentimentTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, field: str = "sentiment", ): """Sentiment.v1 task factory. @@ -23,6 +26,7 @@ def make_sentiment_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. field (str): The name of the doc extension in which to store the summary. """ raw_examples = examples() if callable(examples) else examples @@ -36,5 +40,6 @@ def make_sentiment_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=sentiment_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), field=field, ) diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index ab34b4dd..2a1d877c 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -4,7 +4,7 @@ from spacy.tokens import Doc from spacy.training import Example -from ...ty import FewshotExample, Self, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Self, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +19,7 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, ): """Sentiment analysis task. @@ -27,12 +28,14 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. field (str): The name of the doc extension in which to store the sentiment score. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. """ super().__init__( parse_responses=parse_responses, prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, ) self._field = field self._check_doc_extension() diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index f19f4d20..28ca3243 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -5,7 +5,7 @@ from spacy.tokens import Doc, Span from ...compat import Literal, Protocol, Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from . import SpanExample from .examples import SpanCoTExample @@ -33,6 +33,7 @@ def __init__( prompt_examples: Optional[ Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] ], + n_token_estimator: NTokenEstimator, description: Optional[str], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821 @@ -46,6 +47,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/spancat/registry.py b/spacy_llm/tasks/spancat/registry.py index 33cf11dd..fcea4806 100644 --- a/spacy_llm/tasks/spancat/registry.py +++ b/spacy_llm/tasks/spancat/registry.py @@ -2,12 +2,14 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency as check_labels from ..span.util import check_label_consistency_cot as check_labels_cot +from ..util.tokenization import make_default_n_token_estimator from .task import DEFAULT_SPANCAT_TEMPLATE_V1, DEFAULT_SPANCAT_TEMPLATE_V2 from .task import DEFAULT_SPANCAT_TEMPLATE_V3, SpanCatTask from .util import SpanCatCoTExample, SpanCatExample, score @@ -55,6 +57,7 @@ def make_spancat_task( prompt_example_type=example_type, template=DEFAULT_SPANCAT_TEMPLATE_V1, prompt_examples=span_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -119,6 +122,7 @@ def make_spancat_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -139,6 +143,7 @@ def make_spancat_task_v3( description: Optional[str] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -161,6 +166,7 @@ def make_spancat_task_v3( full examples, although both can be provided. examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -181,6 +187,7 @@ def make_spancat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/spancat/task.py b/spacy_llm/tasks/spancat/task.py index 76439964..7df5c6a8 100644 --- a/spacy_llm/tasks/spancat/task.py +++ b/spacy_llm/tasks/spancat/task.py @@ -5,7 +5,7 @@ from spacy.training import Example from ...compat import Literal, Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -25,6 +25,7 @@ def __init__( label_definitions: Optional[Dict[str, str]], spans_key: str, prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -46,6 +47,7 @@ def __init__( full examples, although both can be provided. spans_key (str): Key of the `Doc.spans` dict to save under. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -62,6 +64,7 @@ def __init__( template=template, label_definitions=label_definitions, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/summarization/registry.py b/spacy_llm/tasks/summarization/registry.py index 216d99bf..310e84ce 100644 --- a/spacy_llm/tasks/summarization/registry.py +++ b/spacy_llm/tasks/summarization/registry.py @@ -1,7 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator +from ...ty import TaskResponseParser +from ..util.tokenization import make_default_n_token_estimator from .parser import parse_responses_v1 from .task import DEFAULT_SUMMARIZATION_TEMPLATE_V1, SummarizationTask from .util import SummarizationExample @@ -13,6 +15,7 @@ def make_summarization_task( parse_responses: Optional[TaskResponseParser[SummarizationTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, max_n_words: Optional[int] = None, field: str = "summary", ): @@ -24,6 +27,7 @@ def make_summarization_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. max_n_words (int): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. """ @@ -38,6 +42,7 @@ def make_summarization_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=span_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), max_n_words=max_n_words, field=field, ) diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index cc749ab3..435a925d 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -6,7 +6,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -19,6 +19,7 @@ def __init__( parse_responses: TaskResponseParser[Self], prompt_example_type: Type[FewshotExample[Self]], template: str, + n_token_estimator: NTokenEstimator, max_n_words: Optional[int], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], @@ -28,6 +29,7 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. max_n_words (Optional[int]): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. @@ -37,6 +39,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, ) self._max_n_words = max_n_words self._field = field diff --git a/spacy_llm/tasks/textcat/registry.py b/spacy_llm/tasks/textcat/registry.py index 7f97709c..a9214320 100644 --- a/spacy_llm/tasks/textcat/registry.py +++ b/spacy_llm/tasks/textcat/registry.py @@ -1,8 +1,10 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer +from ...ty import TaskResponseParser from ...util import split_labels +from ..util.tokenization import make_default_n_token_estimator from .parser import parse_responses_v1_v2_v3 from .task import DEFAULT_TEXTCAT_TEMPLATE_V1, DEFAULT_TEXTCAT_TEMPLATE_V2 from .task import DEFAULT_TEXTCAT_TEMPLATE_V3, TextCatTask @@ -62,6 +64,7 @@ def make_textcat_task( labels=labels_list, template=DEFAULT_TEXTCAT_TEMPLATE_V1, prompt_examples=textcat_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -128,6 +131,7 @@ def make_textcat_task_v2( labels=labels_list, template=template, prompt_examples=textcat_examples, + n_token_estimator=make_default_n_token_estimator(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -145,6 +149,7 @@ def make_textcat_task_v3( template: str = DEFAULT_TEXTCAT_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, + n_token_estimator: Optional[NTokenEstimator] = None, normalizer: Optional[Callable[[str], str]] = None, exclusive_classes: bool = False, allow_none: bool = True, @@ -177,6 +182,7 @@ def make_textcat_task_v3( These descriptions are added to the prompt to help instruct the LLM on what to extract. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -199,6 +205,7 @@ def make_textcat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=textcat_examples, + n_token_estimator=n_token_estimator or make_default_n_token_estimator(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 63d7ec7e..8439cb6d 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -6,7 +6,7 @@ from wasabi import msg from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template @@ -24,6 +24,7 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], + n_token_estimator: NTokenEstimator, normalizer: Optional[Callable[[str], str]], exclusive_classes: bool, allow_none: bool, @@ -53,6 +54,7 @@ def __init__( label_definitions (Optional[Dict[str, str]]): Optional dict mapping a label to a description of that label. These descriptions are added to the prompt to help instruct the LLM on what to extract. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -65,6 +67,7 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + n_token_estimator=n_token_estimator, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/util/tokenization.py b/spacy_llm/tasks/util/tokenization.py new file mode 100644 index 00000000..4e47fb89 --- /dev/null +++ b/spacy_llm/tasks/util/tokenization.py @@ -0,0 +1,15 @@ +from ...registry import registry +from ...ty import NTokenEstimator + + +@registry.llm_misc("spacy.NTokenEstimator.v1") +def make_default_n_token_estimator() -> NTokenEstimator: + """Generates Callable estimating the number of tokens in a given string. + # todo improve default tokenization (allow language code to do tokenization with pretrained spacy model) + RETURNS (NTokenEstimator): Callable estimating the number of tokens in a given string. + """ + + def count_tokens_by_spaces(value: str) -> int: + return len(value.split()) + + return count_tokens_by_spaces diff --git a/spacy_llm/tests/models/test_rest.py b/spacy_llm/tests/models/test_rest.py index dc0210b7..035cba5c 100644 --- a/spacy_llm/tests/models/test_rest.py +++ b/spacy_llm/tests/models/test_rest.py @@ -120,7 +120,8 @@ def test_azure_openai(deployment_name: str): "@llm_models": "spacy.Azure.v1", "base_url": "https://explosion.openai.azure.com/", "model_type": "completions", - "name": deployment_name, + "deployment_name": deployment_name, + "name": deployment_name.replace("35", "3.5"), }, "task": {"@llm_tasks": "spacy.NoOp.v1"}, "save_io": True, diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 7052c197..0b3aee90 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -21,6 +21,7 @@ ExamplesConfigType = Union[ Iterable[Dict[str, Any]], Callable[[], Iterable[Dict[str, Any]]], None ] +NTokenEstimator = Callable[[str], int] @runtime_checkable From f440ca41bbbd258a80a03779a8c9229cc32c4c6d Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 17 Oct 2023 17:13:20 +0200 Subject: [PATCH 03/51] Add plumbing for mapper and reducer. --- spacy_llm/tasks/builtin_task.py | 13 ++++++++- spacy_llm/tasks/lemma/registry.py | 19 ++++++++++--- spacy_llm/tasks/lemma/task.py | 9 +++++- spacy_llm/tasks/lemma/util.py | 10 +++++++ spacy_llm/tasks/ner/registry.py | 27 ++++++++++++++---- spacy_llm/tasks/ner/task.py | 9 +++++- spacy_llm/tasks/ner/util.py | 10 +++++++ spacy_llm/tasks/rel/registry.py | 20 ++++++++++--- spacy_llm/tasks/rel/task.py | 9 +++++- spacy_llm/tasks/rel/util.py | 13 +++++++++ spacy_llm/tasks/sentiment/registry.py | 21 ++++++++++---- spacy_llm/tasks/sentiment/task.py | 9 +++++- spacy_llm/tasks/sentiment/util.py | 12 +++++++- spacy_llm/tasks/span/task.py | 7 ++++- spacy_llm/tasks/spancat/registry.py | 27 ++++++++++++++---- spacy_llm/tasks/spancat/task.py | 9 +++++- spacy_llm/tasks/spancat/util.py | 10 +++++++ spacy_llm/tasks/summarization/registry.py | 21 ++++++++++---- spacy_llm/tasks/summarization/task.py | 9 +++++- spacy_llm/tasks/summarization/util.py | 12 +++++++- spacy_llm/tasks/textcat/registry.py | 27 ++++++++++++++---- spacy_llm/tasks/textcat/task.py | 9 +++++- spacy_llm/tasks/textcat/util.py | 10 +++++++ spacy_llm/tasks/util/sharding.py | 34 +++++++++++++++++++++++ spacy_llm/tasks/util/tokenization.py | 15 ---------- spacy_llm/ty.py | 7 +++++ 26 files changed, 316 insertions(+), 62 deletions(-) create mode 100644 spacy_llm/tasks/util/sharding.py delete mode 100644 spacy_llm/tasks/util/tokenization.py diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 0299e7df..97bf1467 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -10,7 +10,8 @@ from ..compat import Self from ..registry import lowercase_normalizer -from ..ty import FewshotExample, NTokenEstimator, TaskResponseParser +from ..ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer +from ..ty import TaskResponseParser class BuiltinTask(abc.ABC): @@ -35,6 +36,8 @@ def __init__( template: str, prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, ): """Initializes task. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. @@ -42,6 +45,8 @@ def __init__( template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. """ self._parse_responses = parse_responses self._prompt_examples = prompt_examples or [] @@ -248,6 +253,8 @@ def __init__( template: str, prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, labels: List[str], label_definitions: Optional[Dict[str, str]], normalizer: Optional[Callable[[str], str]], @@ -259,6 +266,8 @@ def __init__( template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. labels (List[str]): List of labels to pass to the template. Leave empty to (optionally) populate it at initialization time. label_definitions (Optional[Dict[str, str]]): Map of label -> description @@ -273,6 +282,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._normalizer = normalizer if normalizer else lowercase_normalizer() self._label_dict = { diff --git a/spacy_llm/tasks/lemma/registry.py b/spacy_llm/tasks/lemma/registry.py index 4695b7c6..fbaf488a 100644 --- a/spacy_llm/tasks/lemma/registry.py +++ b/spacy_llm/tasks/lemma/registry.py @@ -2,11 +2,11 @@ from ...registry import registry from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import TaskResponseParser -from ..util.tokenization import make_default_n_token_estimator +from ...ty import ShardMapper, ShardReducer, TaskResponseParser +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_LEMMA_TEMPLATE_V1, LemmaTask -from .util import LemmaExample, score +from .util import LemmaExample, reduce_shards_to_doc, score @registry.llm_misc("spacy.LemmaParser.v1") @@ -19,6 +19,11 @@ def make_lemma_scorer() -> Scorer: return score +@registry.llm_misc("spacy.LemmaShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc + + @registry.llm_tasks("spacy.Lemma.v1") def make_lemma_task( template: str = DEFAULT_LEMMA_TEMPLATE_V1, @@ -26,6 +31,8 @@ def make_lemma_task( prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, scorer: Optional[Scorer] = None, ): """Lemma.v1 task factory. @@ -36,6 +43,8 @@ def make_lemma_task( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples @@ -49,6 +58,8 @@ def make_lemma_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=lemma_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), scorer=scorer or score, ) diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index 78c1d426..d7ff08a1 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -5,7 +5,8 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,6 +21,8 @@ def __init__( prompt_examples: Optional[List[FewshotExample[Self]]], template: str, n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, scorer: Scorer, ): """Default lemmatization task. @@ -29,6 +32,8 @@ def __init__( prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. """ super().__init__( @@ -37,6 +42,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._scorer = scorer diff --git a/spacy_llm/tasks/lemma/util.py b/spacy_llm/tasks/lemma/util.py index fde27498..c04c2696 100644 --- a/spacy_llm/tasks/lemma/util.py +++ b/spacy_llm/tasks/lemma/util.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Iterable, List, Optional from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -24,3 +25,12 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: RETURNS (Dict[str, Any]): Dict with metric name -> score. """ return Scorer.score_token_attr(examples, "lemma") + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for LemmaTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/ner/registry.py b/spacy_llm/tasks/ner/registry.py index 3c350471..dbe372ee 100644 --- a/spacy_llm/tasks/ner/registry.py +++ b/spacy_llm/tasks/ner/registry.py @@ -3,15 +3,20 @@ from ...compat import Literal from ...registry import registry from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import TaskResponseParser +from ...ty import ShardMapper, ShardReducer, TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency, check_label_consistency_cot -from ..util.tokenization import make_default_n_token_estimator +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .task import DEFAULT_NER_TEMPLATE_V1, DEFAULT_NER_TEMPLATE_V2 from .task import DEFAULT_NER_TEMPLATE_V3, NERTask, SpanTask -from .util import NERCoTExample, NERExample, score +from .util import NERCoTExample, NERExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.NERShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.NER.v1") @@ -53,7 +58,9 @@ def make_ner_task( labels=labels_list, template=DEFAULT_NER_TEMPLATE_V1, prompt_examples=span_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -114,7 +121,9 @@ def make_ner_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -134,6 +143,8 @@ def make_ner_task_v3( label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -156,6 +167,8 @@ def make_ner_task_v3( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -175,7 +188,9 @@ def make_ner_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/ner/task.py b/spacy_llm/tasks/ner/task.py index 703ca1ca..ec57bae6 100644 --- a/spacy_llm/tasks/ner/task.py +++ b/spacy_llm/tasks/ner/task.py @@ -6,7 +6,8 @@ from spacy.util import filter_spans from ...compat import Literal, Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -26,6 +27,8 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -42,6 +45,8 @@ def __init__( parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. label_definitions (Optional[Dict[str, str]]): Map of label -> description of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than @@ -62,6 +67,8 @@ def __init__( parse_responses=parse_responses, prompt_example_type=prompt_example_type, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, label_definitions=label_definitions, prompt_examples=prompt_examples, normalizer=normalizer, diff --git a/spacy_llm/tasks/ner/util.py b/spacy_llm/tasks/ner/util.py index d02b9a83..4bc8d09d 100644 --- a/spacy_llm/tasks/ner/util.py +++ b/spacy_llm/tasks/ner/util.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, Optional from spacy.scorer import get_ner_prf +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -35,3 +36,12 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: RETURNS (Dict[str, Any]): Dict with metric name -> score. """ return get_ner_prf(examples) + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for NERTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/rel/registry.py b/spacy_llm/tasks/rel/registry.py index d076af2a..ef2dafe5 100644 --- a/spacy_llm/tasks/rel/registry.py +++ b/spacy_llm/tasks/rel/registry.py @@ -1,13 +1,19 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator -from ...ty import TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper +from ...ty import ShardReducer, TaskResponseParser from ...util import split_labels -from ..util.tokenization import make_default_n_token_estimator +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .examples import RELExample from .parser import parse_responses_v1 from .task import DEFAULT_REL_TEMPLATE, RELTask +from .util import reduce_shards_to_doc + + +@registry.llm_misc("spacy.RELShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.REL.v1") @@ -19,6 +25,8 @@ def make_rel_task( label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, verbose: bool = False, ) -> "RELTask": @@ -39,6 +47,8 @@ def make_rel_task( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -54,7 +64,9 @@ def make_rel_task( template=template, label_definitions=label_definitions, prompt_examples=rel_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, verbose=verbose, ) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index c917cda9..87393a6e 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -5,7 +5,8 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template from .util import EntityItem, RelationItem @@ -23,6 +24,8 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], verbose: bool, ): @@ -39,6 +42,8 @@ def __init__( full examples, although both can be provided. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -48,6 +53,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index 7426d8b8..956d4233 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -1,3 +1,7 @@ +from typing import Iterable + +from spacy.tokens import Doc + from ...compat import BaseModel, validator @@ -17,3 +21,12 @@ class EntityItem(BaseModel): start_char: int end_char: int label: str + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for RELTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/sentiment/registry.py b/spacy_llm/tasks/sentiment/registry.py index fd1a6094..36b496e4 100644 --- a/spacy_llm/tasks/sentiment/registry.py +++ b/spacy_llm/tasks/sentiment/registry.py @@ -1,12 +1,17 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator -from ...ty import TaskResponseParser -from ..util.tokenization import make_default_n_token_estimator +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper +from ...ty import ShardReducer, TaskResponseParser +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SENTIMENT_TEMPLATE_V1, SentimentTask -from .util import SentimentExample +from .util import SentimentExample, reduce_shards_to_doc + + +@registry.llm_misc("spacy.SentimentShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.Sentiment.v1") @@ -16,6 +21,8 @@ def make_sentiment_task( prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, field: str = "sentiment", ): """Sentiment.v1 task factory. @@ -27,6 +34,8 @@ def make_sentiment_task( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. field (str): The name of the doc extension in which to store the summary. """ raw_examples = examples() if callable(examples) else examples @@ -40,6 +49,8 @@ def make_sentiment_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=sentiment_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), field=field, ) diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 2a1d877c..b75e6cc1 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -4,7 +4,8 @@ from spacy.tokens import Doc from spacy.training import Example -from ...ty import FewshotExample, NTokenEstimator, Self, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Self, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,6 +21,8 @@ def __init__( field: str, prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, ): """Sentiment analysis task. @@ -29,6 +32,8 @@ def __init__( field (str): The name of the doc extension in which to store the sentiment score. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. """ super().__init__( parse_responses=parse_responses, @@ -36,6 +41,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._field = field self._check_doc_extension() diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 72adbedb..45e309ee 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Iterable, Optional +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -17,3 +18,12 @@ def generate(cls, example: Example, task: SentimentTask) -> Optional[Self]: text=example.reference.text, score=getattr(example.reference._, task.field), ) + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SentimentTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index 28ca3243..c15ee884 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -5,7 +5,8 @@ from spacy.tokens import Doc, Span from ...compat import Literal, Protocol, Self -from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from . import SpanExample from .examples import SpanCoTExample @@ -34,6 +35,8 @@ def __init__( Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] ], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, description: Optional[str], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821 @@ -48,6 +51,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/spancat/registry.py b/spacy_llm/tasks/spancat/registry.py index fcea4806..0d3c648f 100644 --- a/spacy_llm/tasks/spancat/registry.py +++ b/spacy_llm/tasks/spancat/registry.py @@ -3,16 +3,21 @@ from ...compat import Literal from ...registry import registry from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import TaskResponseParser +from ...ty import ShardMapper, ShardReducer, TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency as check_labels from ..span.util import check_label_consistency_cot as check_labels_cot -from ..util.tokenization import make_default_n_token_estimator +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .task import DEFAULT_SPANCAT_TEMPLATE_V1, DEFAULT_SPANCAT_TEMPLATE_V2 from .task import DEFAULT_SPANCAT_TEMPLATE_V3, SpanCatTask -from .util import SpanCatCoTExample, SpanCatExample, score +from .util import SpanCatCoTExample, SpanCatExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.SpanCatShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.SpanCat.v1") @@ -57,7 +62,9 @@ def make_spancat_task( prompt_example_type=example_type, template=DEFAULT_SPANCAT_TEMPLATE_V1, prompt_examples=span_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -122,7 +129,9 @@ def make_spancat_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, @@ -144,6 +153,8 @@ def make_spancat_task_v3( label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, alignment_mode: Literal["strict", "contract", "expand"] = "contract", case_sensitive_matching: bool = False, @@ -167,6 +178,8 @@ def make_spancat_task_v3( examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -187,7 +200,9 @@ def make_spancat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/spancat/task.py b/spacy_llm/tasks/spancat/task.py index 7df5c6a8..08714a02 100644 --- a/spacy_llm/tasks/spancat/task.py +++ b/spacy_llm/tasks/spancat/task.py @@ -5,7 +5,8 @@ from spacy.training import Example from ...compat import Literal, Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -26,6 +27,8 @@ def __init__( spans_key: str, prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -48,6 +51,8 @@ def __init__( spans_key (str): Key of the `Doc.spans` dict to save under. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. @@ -65,6 +70,8 @@ def __init__( label_definitions=label_definitions, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, normalizer=normalizer, alignment_mode=alignment_mode, case_sensitive_matching=case_sensitive_matching, diff --git a/spacy_llm/tasks/spancat/util.py b/spacy_llm/tasks/spancat/util.py index 6ffcd54c..c83e5fe3 100644 --- a/spacy_llm/tasks/spancat/util.py +++ b/spacy_llm/tasks/spancat/util.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Iterable, Optional from spacy.pipeline.spancat import spancat_score +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -41,3 +42,12 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: spans_key=kwargs["spans_key"], allow_overlap=True, ) + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SpanCatTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/summarization/registry.py b/spacy_llm/tasks/summarization/registry.py index 310e84ce..a46fbb21 100644 --- a/spacy_llm/tasks/summarization/registry.py +++ b/spacy_llm/tasks/summarization/registry.py @@ -1,12 +1,17 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator -from ...ty import TaskResponseParser -from ..util.tokenization import make_default_n_token_estimator +from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper +from ...ty import ShardReducer, TaskResponseParser +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SUMMARIZATION_TEMPLATE_V1, SummarizationTask -from .util import SummarizationExample +from .util import SummarizationExample, reduce_shards_to_doc + + +@registry.llm_misc("spacy.SummarizationShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.Summarization.v1") @@ -16,6 +21,8 @@ def make_summarization_task( prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, max_n_words: Optional[int] = None, field: str = "summary", ): @@ -28,6 +35,8 @@ def make_summarization_task( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. max_n_words (int): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. """ @@ -42,7 +51,9 @@ def make_summarization_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), max_n_words=max_n_words, field=field, ) diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 435a925d..0d233c35 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -6,7 +6,8 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,6 +21,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, max_n_words: Optional[int], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], @@ -30,6 +33,8 @@ def __init__( parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. max_n_words (Optional[int]): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. @@ -40,6 +45,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._max_n_words = max_n_words self._field = field diff --git a/spacy_llm/tasks/summarization/util.py b/spacy_llm/tasks/summarization/util.py index 12fd1aa9..ff68ddf6 100644 --- a/spacy_llm/tasks/summarization/util.py +++ b/spacy_llm/tasks/summarization/util.py @@ -1,5 +1,6 @@ -from typing import Optional +from typing import Iterable, Optional +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -17,3 +18,12 @@ def generate(cls, example: Example, task: SummarizationTask) -> Optional[Self]: text=example.reference.text, summary=getattr(example.reference._, task.field), ) + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SummarizationTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/textcat/registry.py b/spacy_llm/tasks/textcat/registry.py index a9214320..4f9921af 100644 --- a/spacy_llm/tasks/textcat/registry.py +++ b/spacy_llm/tasks/textcat/registry.py @@ -2,13 +2,18 @@ from ...registry import registry from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import TaskResponseParser +from ...ty import ShardMapper, ShardReducer, TaskResponseParser from ...util import split_labels -from ..util.tokenization import make_default_n_token_estimator +from ..util.sharding import make_n_token_estimator, make_shard_mapper from .parser import parse_responses_v1_v2_v3 from .task import DEFAULT_TEXTCAT_TEMPLATE_V1, DEFAULT_TEXTCAT_TEMPLATE_V2 from .task import DEFAULT_TEXTCAT_TEMPLATE_V3, TextCatTask -from .util import TextCatExample, score +from .util import TextCatExample, reduce_shards_to_doc, score + + +@registry.llm_misc("spacy.TextCatShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc @registry.llm_tasks("spacy.TextCat.v1") @@ -64,7 +69,9 @@ def make_textcat_task( labels=labels_list, template=DEFAULT_TEXTCAT_TEMPLATE_V1, prompt_examples=textcat_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -131,7 +138,9 @@ def make_textcat_task_v2( labels=labels_list, template=template, prompt_examples=textcat_examples, - n_token_estimator=make_default_n_token_estimator(), + n_token_estimator=make_n_token_estimator(), + shard_mapper=make_shard_mapper(), + shard_reducer=make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, @@ -150,6 +159,8 @@ def make_textcat_task_v3( label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, n_token_estimator: Optional[NTokenEstimator] = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, exclusive_classes: bool = False, allow_none: bool = True, @@ -183,6 +194,8 @@ def make_textcat_task_v3( examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -205,7 +218,9 @@ def make_textcat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=textcat_examples, - n_token_estimator=n_token_estimator or make_default_n_token_estimator(), + n_token_estimator=n_token_estimator or make_n_token_estimator(), + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, exclusive_classes=exclusive_classes, allow_none=allow_none, diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 8439cb6d..40c81692 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -6,7 +6,8 @@ from wasabi import msg from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, TaskResponseParser +from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template @@ -25,6 +26,8 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], n_token_estimator: NTokenEstimator, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], exclusive_classes: bool, allow_none: bool, @@ -55,6 +58,8 @@ def __init__( These descriptions are added to the prompt to help instruct the LLM on what to extract. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -68,6 +73,8 @@ def __init__( template=template, prompt_examples=prompt_examples, n_token_estimator=n_token_estimator, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, labels=labels, label_definitions=label_definitions, normalizer=normalizer, diff --git a/spacy_llm/tasks/textcat/util.py b/spacy_llm/tasks/textcat/util.py index 992c9bb2..7c4de62d 100644 --- a/spacy_llm/tasks/textcat/util.py +++ b/spacy_llm/tasks/textcat/util.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Iterable, Optional from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -46,3 +47,12 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: labels=kwargs["labels"], multi_label=kwargs["multi_label"], ) + + +def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for TextCatTask. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # todo this is yet a dummy implementation that will only return the first doc shard. + return list(shards)[0] diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py new file mode 100644 index 00000000..f709cb26 --- /dev/null +++ b/spacy_llm/tasks/util/sharding.py @@ -0,0 +1,34 @@ +from typing import Callable + +from spacy.tokens import Doc + +from ...registry import registry +from ...ty import NTokenEstimator, ShardMapper + + +@registry.llm_misc("spacy.NTokenEstimator.v1") +def make_n_token_estimator() -> NTokenEstimator: + """Generates Callable estimating the number of tokens in a given string. + # todo improve default tokenization (allow language code to do tokenization with pretrained spacy model) + RETURNS (NTokenEstimator): Callable estimating the number of tokens in a given string. + """ + + def count_tokens_by_spaces(value: str) -> int: + return len(value.split()) + + return count_tokens_by_spaces + + +@registry.llm_misc("spacy.ShardMapper.v1") +def make_shard_mapper() -> ShardMapper: + """Generates Callable mapping doc to doc shards fitting within context length. + RETURNS (ShardMapper): Callable mapping doc to doc shards fitting within context length. + """ + + def map_doc_to_shards( + doc: Doc, context_length: int, render_template: Callable[[str], str] + ): + # todo this is yet a dummy implementation that will fail for texts with len(text) > context length. + return [doc] + + return map_doc_to_shards diff --git a/spacy_llm/tasks/util/tokenization.py b/spacy_llm/tasks/util/tokenization.py deleted file mode 100644 index 4e47fb89..00000000 --- a/spacy_llm/tasks/util/tokenization.py +++ /dev/null @@ -1,15 +0,0 @@ -from ...registry import registry -from ...ty import NTokenEstimator - - -@registry.llm_misc("spacy.NTokenEstimator.v1") -def make_default_n_token_estimator() -> NTokenEstimator: - """Generates Callable estimating the number of tokens in a given string. - # todo improve default tokenization (allow language code to do tokenization with pretrained spacy model) - RETURNS (NTokenEstimator): Callable estimating the number of tokens in a given string. - """ - - def count_tokens_by_spaces(value: str) -> int: - return len(value.split()) - - return count_tokens_by_spaces diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 0b3aee90..73013f88 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -22,6 +22,13 @@ Iterable[Dict[str, Any]], Callable[[], Iterable[Dict[str, Any]]], None ] NTokenEstimator = Callable[[str], int] +ShardMapper = Callable[ + # Requires doc, context length and callable for rendering template from doc shard text. + [Doc, int, Callable[[str], str]], + # Returns each shard as a doc. + Iterable[Doc], +] +ShardReducer = Callable[[Iterable[Doc]], Doc] @runtime_checkable From e47f762e0aa72652b14e838936e6e5cda2311b35 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 18 Oct 2023 16:40:03 +0200 Subject: [PATCH 04/51] Add ShardMapper prototype. --- spacy_llm/tasks/builtin_task.py | 37 ++++++++------ spacy_llm/tasks/lemma/registry.py | 8 ++- spacy_llm/tasks/lemma/task.py | 6 +-- spacy_llm/tasks/ner/registry.py | 11 ++-- spacy_llm/tasks/ner/task.py | 6 +-- spacy_llm/tasks/rel/registry.py | 9 ++-- spacy_llm/tasks/rel/task.py | 34 ++++++------- spacy_llm/tasks/sentiment/registry.py | 9 ++-- spacy_llm/tasks/sentiment/task.py | 9 ++-- spacy_llm/tasks/span/task.py | 8 +-- spacy_llm/tasks/spancat/registry.py | 11 ++-- spacy_llm/tasks/spancat/task.py | 6 +-- spacy_llm/tasks/summarization/registry.py | 9 ++-- spacy_llm/tasks/summarization/task.py | 12 +---- spacy_llm/tasks/textcat/registry.py | 11 ++-- spacy_llm/tasks/textcat/task.py | 9 +--- spacy_llm/tasks/util/sharding.py | 62 +++++++++++++++++++++-- spacy_llm/ty.py | 2 +- 18 files changed, 128 insertions(+), 131 deletions(-) diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 97bf1467..bdd773bc 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -10,8 +10,7 @@ from ..compat import Self from ..registry import lowercase_normalizer -from ..ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer -from ..ty import TaskResponseParser +from ..ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser class BuiltinTask(abc.ABC): @@ -35,7 +34,6 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, ): @@ -44,7 +42,6 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. """ @@ -52,7 +49,8 @@ def __init__( self._prompt_examples = prompt_examples or [] self._template = template self._prompt_example_type = prompt_example_type - self._n_token_estimator = n_token_estimator + self._shard_mapper = shard_mapper + self._shard_reducer = shard_reducer def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: """Generate prompts from docs. @@ -61,17 +59,32 @@ def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: """ environment = jinja2.Environment() _template = environment.from_string(self._template) + + def render_template(shard: Doc) -> str: + """Renders template for a given doc (shard). + shard (Doc): Doc shard. Note that if the prompt is small enough to fit within the model's context window, + there will only be one shard, which is identical to the original doc. + RETURNS (str): Rendered template. + """ + return _template.render( + text=doc.text, + prompt_examples=self._prompt_examples, + **self._get_prompt_data(shard), + ) + for doc in self._preprocess_docs_for_prompt(docs): + # todo make prompt data a doc-dependent function (worry about EL after this works for available tasks) prompt = _template.render( text=doc.text, prompt_examples=self._prompt_examples, - **self._prompt_data, + **self._get_prompt_data(doc), ) yield prompt - @property - def _prompt_data(self) -> Dict[str, Any]: - """Returns data injected into prompt template. No-op if not overridden by inheriting task class. + def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data + returned by this might be static (i. e. the same for all doc shards) or dynamic (contingent on the doc shard). + shard (Doc): Doc (shard) for which prompt data should be fetched. RETURNS (Dict[str, Any]): Data injected into prompt template. """ return {} @@ -121,7 +134,6 @@ def get_cfg(self) -> Dict[str, Any]: def set_cfg(self, cfg: Dict[str, Any]) -> None: """Deserialize the task's configuration attributes. - cfg (Dict[str, Any]): dictionary containing configuration attributes. """ for key, value in cfg.items(): @@ -134,7 +146,6 @@ def _get_prompt_examples(self) -> List[Dict[str, Any]]: def _set_prompt_examples(self, examples: List[Dict[str, Any]]) -> None: """Set prompt examples. - examples (List[Dict[str, Any]]): prompt examples. """ self._prompt_examples = [ @@ -191,7 +202,6 @@ def to_disk( path (Path): A path (currently unused). exclude (Tuple): Names of properties to exclude from serialization. """ - serialize = { "cfg": lambda p: srsly.write_json(p, self.get_cfg()), "prompt_examples": lambda p: srsly.write_msgpack( @@ -252,7 +262,6 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, labels: List[str], @@ -265,7 +274,6 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. labels (List[str]): List of labels to pass to the template. @@ -281,7 +289,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, ) diff --git a/spacy_llm/tasks/lemma/registry.py b/spacy_llm/tasks/lemma/registry.py index fbaf488a..d4d555d3 100644 --- a/spacy_llm/tasks/lemma/registry.py +++ b/spacy_llm/tasks/lemma/registry.py @@ -1,9 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import ShardMapper, ShardReducer, TaskResponseParser -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_LEMMA_TEMPLATE_V1, LemmaTask from .util import LemmaExample, reduce_shards_to_doc, score @@ -30,7 +30,6 @@ def make_lemma_task( parse_responses: Optional[TaskResponseParser[LemmaTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, scorer: Optional[Scorer] = None, @@ -58,7 +57,6 @@ def make_lemma_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=lemma_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), scorer=scorer or score, diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index d7ff08a1..dad38ce5 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -5,8 +5,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,7 +19,6 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], prompt_examples: Optional[List[FewshotExample[Self]]], template: str, - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, scorer: Scorer, @@ -31,7 +29,6 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. @@ -41,7 +38,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, ) diff --git a/spacy_llm/tasks/ner/registry.py b/spacy_llm/tasks/ner/registry.py index dbe372ee..d4908904 100644 --- a/spacy_llm/tasks/ner/registry.py +++ b/spacy_llm/tasks/ner/registry.py @@ -2,13 +2,13 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import ShardMapper, ShardReducer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency, check_label_consistency_cot -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ..util.sharding import make_shard_mapper from .task import DEFAULT_NER_TEMPLATE_V1, DEFAULT_NER_TEMPLATE_V2 from .task import DEFAULT_NER_TEMPLATE_V3, NERTask, SpanTask from .util import NERCoTExample, NERExample, reduce_shards_to_doc, score @@ -58,7 +58,6 @@ def make_ner_task( labels=labels_list, template=DEFAULT_NER_TEMPLATE_V1, prompt_examples=span_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -121,7 +120,6 @@ def make_ner_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -142,7 +140,6 @@ def make_ner_task_v3( template: str = DEFAULT_NER_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, @@ -166,7 +163,6 @@ def make_ner_task_v3( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. @@ -188,7 +184,6 @@ def make_ner_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, diff --git a/spacy_llm/tasks/ner/task.py b/spacy_llm/tasks/ner/task.py index ec57bae6..1577c536 100644 --- a/spacy_llm/tasks/ner/task.py +++ b/spacy_llm/tasks/ner/task.py @@ -6,8 +6,7 @@ from spacy.util import filter_spans from ...compat import Literal, Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -26,7 +25,6 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], @@ -44,7 +42,6 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. label_definitions (Optional[Dict[str, str]]): Map of label -> description @@ -66,7 +63,6 @@ def __init__( template=template, parse_responses=parse_responses, prompt_example_type=prompt_example_type, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, label_definitions=label_definitions, diff --git a/spacy_llm/tasks/rel/registry.py b/spacy_llm/tasks/rel/registry.py index ef2dafe5..f7142255 100644 --- a/spacy_llm/tasks/rel/registry.py +++ b/spacy_llm/tasks/rel/registry.py @@ -1,10 +1,10 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper -from ...ty import ShardReducer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ..util.sharding import make_shard_mapper from .examples import RELExample from .parser import parse_responses_v1 from .task import DEFAULT_REL_TEMPLATE, RELTask @@ -24,7 +24,6 @@ def make_rel_task( prompt_example_type: Optional[Type[FewshotExample]] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, @@ -46,7 +45,6 @@ def make_rel_task( full examples, although both can be provided. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. @@ -64,7 +62,6 @@ def make_rel_task( template=template, label_definitions=label_definitions, prompt_examples=rel_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index 87393a6e..13fa5a9e 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -5,8 +5,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template from .util import EntityItem, RelationItem @@ -23,12 +22,22 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], verbose: bool, ): + super().__init__( + parse_responses=parse_responses, + prompt_example_type=prompt_example_type, + template=template, + prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, + labels=labels, + label_definitions=label_definitions, + normalizer=normalizer, + ) """Default REL task. Populates a `Doc._.rel` custom attribute. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. @@ -40,33 +49,20 @@ def __init__( of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than full examples, although both can be provided. - prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in + prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ - super().__init__( - parse_responses=parse_responses, - prompt_example_type=prompt_example_type, - template=template, - prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, - shard_mapper=shard_mapper, - shard_reducer=shard_reducer, - labels=labels, - label_definitions=label_definitions, - normalizer=normalizer, - ) self._verbose = verbose self._field = "rel" def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs] - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/sentiment/registry.py b/spacy_llm/tasks/sentiment/registry.py index 36b496e4..026d3369 100644 --- a/spacy_llm/tasks/sentiment/registry.py +++ b/spacy_llm/tasks/sentiment/registry.py @@ -1,9 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper -from ...ty import ShardReducer, TaskResponseParser -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SENTIMENT_TEMPLATE_V1, SentimentTask from .util import SentimentExample, reduce_shards_to_doc @@ -20,7 +20,6 @@ def make_sentiment_task( parse_responses: Optional[TaskResponseParser[SentimentTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, field: str = "sentiment", @@ -33,7 +32,6 @@ def make_sentiment_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. field (str): The name of the doc extension in which to store the summary. @@ -49,7 +47,6 @@ def make_sentiment_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=sentiment_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), field=field, diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index b75e6cc1..9ab1633d 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -4,8 +4,7 @@ from spacy.tokens import Doc from spacy.training import Example -from ...ty import FewshotExample, NTokenEstimator, Self, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, Self, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,7 +19,6 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, ): @@ -30,8 +28,8 @@ def __init__( parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. field (str): The name of the doc extension in which to store the sentiment score. - prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in + prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. """ @@ -40,7 +38,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, ) diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index c15ee884..c42ee373 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -5,8 +5,7 @@ from spacy.tokens import Doc, Span from ...compat import Literal, Protocol, Self -from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from . import SpanExample from .examples import SpanCoTExample @@ -34,7 +33,6 @@ def __init__( prompt_examples: Optional[ Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] ], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, description: Optional[str], @@ -50,7 +48,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, labels=labels, @@ -73,8 +70,7 @@ def __init__( if self._prompt_examples: self._prompt_examples = list(self._check_label_consistency(self)) - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return { "description": self._description, "labels": list(self._label_dict.values()), diff --git a/spacy_llm/tasks/spancat/registry.py b/spacy_llm/tasks/spancat/registry.py index 0d3c648f..f5aa7180 100644 --- a/spacy_llm/tasks/spancat/registry.py +++ b/spacy_llm/tasks/spancat/registry.py @@ -2,14 +2,14 @@ from ...compat import Literal from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import ShardMapper, ShardReducer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels from ..span import parse_responses as parse_span_responses from ..span import parse_responses_cot as parse_span_responses_cot from ..span.util import check_label_consistency as check_labels from ..span.util import check_label_consistency_cot as check_labels_cot -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ..util.sharding import make_shard_mapper from .task import DEFAULT_SPANCAT_TEMPLATE_V1, DEFAULT_SPANCAT_TEMPLATE_V2 from .task import DEFAULT_SPANCAT_TEMPLATE_V3, SpanCatTask from .util import SpanCatCoTExample, SpanCatExample, reduce_shards_to_doc, score @@ -62,7 +62,6 @@ def make_spancat_task( prompt_example_type=example_type, template=DEFAULT_SPANCAT_TEMPLATE_V1, prompt_examples=span_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -129,7 +128,6 @@ def make_spancat_task_v2( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -152,7 +150,6 @@ def make_spancat_task_v3( description: Optional[str] = None, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, @@ -177,7 +174,6 @@ def make_spancat_task_v3( full examples, although both can be provided. examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. @@ -200,7 +196,6 @@ def make_spancat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, diff --git a/spacy_llm/tasks/spancat/task.py b/spacy_llm/tasks/spancat/task.py index 08714a02..a7e9695d 100644 --- a/spacy_llm/tasks/spancat/task.py +++ b/spacy_llm/tasks/spancat/task.py @@ -5,8 +5,7 @@ from spacy.training import Example from ...compat import Literal, Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..span import SpanTask from ..span.task import SpanTaskLabelCheck from ..templates import read_template @@ -26,7 +25,6 @@ def __init__( label_definitions: Optional[Dict[str, str]], spans_key: str, prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], @@ -50,7 +48,6 @@ def __init__( full examples, although both can be provided. spans_key (str): Key of the `Doc.spans` dict to save under. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. @@ -69,7 +66,6 @@ def __init__( template=template, label_definitions=label_definitions, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, normalizer=normalizer, diff --git a/spacy_llm/tasks/summarization/registry.py b/spacy_llm/tasks/summarization/registry.py index a46fbb21..083a7363 100644 --- a/spacy_llm/tasks/summarization/registry.py +++ b/spacy_llm/tasks/summarization/registry.py @@ -1,9 +1,9 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, ShardMapper -from ...ty import ShardReducer, TaskResponseParser -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1 from .task import DEFAULT_SUMMARIZATION_TEMPLATE_V1, SummarizationTask from .util import SummarizationExample, reduce_shards_to_doc @@ -20,7 +20,6 @@ def make_summarization_task( parse_responses: Optional[TaskResponseParser[SummarizationTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, max_n_words: Optional[int] = None, @@ -34,7 +33,6 @@ def make_summarization_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. max_n_words (int): Max. number of words to use in summary. @@ -51,7 +49,6 @@ def make_summarization_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=span_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), max_n_words=max_n_words, diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 0d233c35..6b730fc9 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -6,8 +6,7 @@ from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template @@ -20,7 +19,6 @@ def __init__( parse_responses: TaskResponseParser[Self], prompt_example_type: Type[FewshotExample[Self]], template: str, - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, max_n_words: Optional[int], @@ -32,7 +30,6 @@ def __init__( template (str): Prompt template passed to the model. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. max_n_words (Optional[int]): Max. number of words to use in summary. @@ -44,7 +41,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, ) @@ -88,11 +84,7 @@ def _check_prompt_example_summary_len(self) -> None: f"LLM will likely produce responses that are too long." ) - @property - def _prompt_data(self) -> Dict[str, Any]: - """Returns data injected into prompt template. No-op if not overridden by inheriting task class. - RETURNS (Dict[str, Any]): Data injected into prompt template. - """ + def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: if self._check_example_summaries: self._check_prompt_example_summary_len() self._check_example_summaries = False diff --git a/spacy_llm/tasks/textcat/registry.py b/spacy_llm/tasks/textcat/registry.py index 4f9921af..67885025 100644 --- a/spacy_llm/tasks/textcat/registry.py +++ b/spacy_llm/tasks/textcat/registry.py @@ -1,10 +1,10 @@ from typing import Callable, Dict, List, Optional, Type, Union from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, NTokenEstimator, Scorer -from ...ty import ShardMapper, ShardReducer, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, ShardMapper, ShardReducer +from ...ty import TaskResponseParser from ...util import split_labels -from ..util.sharding import make_n_token_estimator, make_shard_mapper +from ..util.sharding import make_shard_mapper from .parser import parse_responses_v1_v2_v3 from .task import DEFAULT_TEXTCAT_TEMPLATE_V1, DEFAULT_TEXTCAT_TEMPLATE_V2 from .task import DEFAULT_TEXTCAT_TEMPLATE_V3, TextCatTask @@ -69,7 +69,6 @@ def make_textcat_task( labels=labels_list, template=DEFAULT_TEXTCAT_TEMPLATE_V1, prompt_examples=textcat_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -138,7 +137,6 @@ def make_textcat_task_v2( labels=labels_list, template=template, prompt_examples=textcat_examples, - n_token_estimator=make_n_token_estimator(), shard_mapper=make_shard_mapper(), shard_reducer=make_shard_reducer(), normalizer=normalizer, @@ -158,7 +156,6 @@ def make_textcat_task_v3( template: str = DEFAULT_TEXTCAT_TEMPLATE_V3, label_definitions: Optional[Dict[str, str]] = None, examples: ExamplesConfigType = None, - n_token_estimator: Optional[NTokenEstimator] = None, shard_mapper: Optional[ShardMapper] = None, shard_reducer: Optional[ShardReducer] = None, normalizer: Optional[Callable[[str], str]] = None, @@ -193,7 +190,6 @@ def make_textcat_task_v3( These descriptions are added to the prompt to help instruct the LLM on what to extract. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. - n_token_estimator (Optional[NTokenEstimator]): Estimates number of tokens in a string. shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. @@ -218,7 +214,6 @@ def make_textcat_task_v3( template=template, label_definitions=label_definitions, prompt_examples=textcat_examples, - n_token_estimator=n_token_estimator or make_n_token_estimator(), shard_mapper=shard_mapper or make_shard_mapper(), shard_reducer=shard_reducer or make_shard_reducer(), normalizer=normalizer, diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 40c81692..a4af9b78 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -6,8 +6,7 @@ from wasabi import msg from ...compat import Self -from ...ty import FewshotExample, NTokenEstimator, Scorer, ShardMapper, ShardReducer -from ...ty import TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template @@ -25,7 +24,6 @@ def __init__( template: str, label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], - n_token_estimator: NTokenEstimator, shard_mapper: ShardMapper, shard_reducer: ShardReducer, normalizer: Optional[Callable[[str], str]], @@ -57,7 +55,6 @@ def __init__( label_definitions (Optional[Dict[str, str]]): Optional dict mapping a label to a description of that label. These descriptions are added to the prompt to help instruct the LLM on what to extract. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. - n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. @@ -72,7 +69,6 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, - n_token_estimator=n_token_estimator, shard_mapper=shard_mapper, shard_reducer=shard_reducer, labels=labels, @@ -93,8 +89,7 @@ def __init__( ) self._exclusive_classes = True - @property - def _prompt_data(self) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index f709cb26..7709b970 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, List, Optional from spacy.tokens import Doc @@ -20,15 +20,67 @@ def count_tokens_by_spaces(value: str) -> int: @registry.llm_misc("spacy.ShardMapper.v1") -def make_shard_mapper() -> ShardMapper: +def make_shard_mapper( + n_token_estimator: NTokenEstimator = make_n_token_estimator(), + buffer_frac: float = 1.1, +) -> ShardMapper: """Generates Callable mapping doc to doc shards fitting within context length. + n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. + buffer_frac (float): Buffer to consider in assessment of whether prompt fits into context. E. g. if value is 1.1, + prompt length * 1.1 will be compared with the context length. + # todo sharding would be better with sentences instead of tokens, but this requires some form of sentence + # splitting we can't rely one...maybe checking for sentences and/or as optional arg? RETURNS (ShardMapper): Callable mapping doc to doc shards fitting within context length. """ def map_doc_to_shards( - doc: Doc, context_length: int, render_template: Callable[[str], str] + doc: Doc, context_length: int, render_template: Callable[[Doc], str] ): - # todo this is yet a dummy implementation that will fail for texts with len(text) > context length. - return [doc] + prompt = render_template(doc) + + # If prompt with complete doc too long: split in shards. + if n_token_estimator(prompt) * buffer_frac > context_length: + shards: List[Doc] = [] + # Prompt length unfortunately can't be exacted computed prior to rendering the prompt, as external + # information not present in the doc (e. g. entity description for EL prompts) may be injected. + # For this reason we follow a greedy binary search heuristic, if the fully rendered prompt is too long: + # 1. Get total number of tokens/sentences (depending on the reducer's configuration) + # 2. Splice off doc up to the first half of tokens/sentences + # 3. Render prompt and check whether it fits into context + # 4. If yes: repeat with second doc half. + # 5. If not: repeat from 2., but with split off shard instead of doc. + remaining_doc: Optional[Doc] = doc.copy() + fraction = 0.5 + start_idx = 0 + + while remaining_doc is not None: + fits_in_context = False + shard: Optional[Doc] = None + end_idx = -1 + + while fits_in_context is False: + end_idx = start_idx + int(len(remaining_doc) * fraction) + shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) + fits_in_context = ( + n_token_estimator(render_template(shard)) * buffer_frac + <= context_length + ) + fraction /= 2 + + assert shard is not None + shards.append(shard) + fraction = 1 + start_idx = end_idx + # Set remaining_doc to None if shard contains all of it, i. e. entire original doc has been processed. + remaining_doc = ( + doc[end_idx:].as_doc(copy_user_data=True) + if shard.text != remaining_doc.text + else None + ) + + return shards + + else: + return [doc] return map_doc_to_shards diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 73013f88..e75a3d18 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -24,7 +24,7 @@ NTokenEstimator = Callable[[str], int] ShardMapper = Callable[ # Requires doc, context length and callable for rendering template from doc shard text. - [Doc, int, Callable[[str], str]], + [Doc, int, Callable[[Doc], str]], # Returns each shard as a doc. Iterable[Doc], ] From 89a55106f5ae27c3113ef6a6da9453446644bfe0 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 19 Oct 2023 12:53:34 +0200 Subject: [PATCH 05/51] Integrating mapping into prompt generation workflow. --- spacy_llm/pipeline/llm.py | 12 +++++++++--- spacy_llm/tasks/builtin_task.py | 22 +++++++++++++++------- spacy_llm/tasks/noop.py | 6 ++++-- spacy_llm/tasks/util/sharding.py | 13 +++++++------ spacy_llm/tests/models/test_rest.py | 6 ++++-- spacy_llm/tests/pipeline/test_llm.py | 6 ++++-- spacy_llm/tests/test_cache.py | 6 ++++-- spacy_llm/ty.py | 17 +++++++++++++++-- 8 files changed, 62 insertions(+), 26 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 1c3365d4..2b84912f 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -15,8 +15,9 @@ from .. import registry # noqa: F401 from ..compat import TypedDict -from ..ty import Cache, LabeledTask, LLMTask, PromptExecutorType, ScorableTask -from ..ty import Serializable, validate_type_consistency +from ..ty import Cache, LabeledTask, LLMTask, ModelWithContextLength +from ..ty import PromptExecutorType, ScorableTask, Serializable +from ..ty import validate_type_consistency logger = logging.getLogger("spacy_llm") logger.addHandler(logging.NullHandler()) @@ -204,8 +205,13 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: modified_docs: Iterator[Doc] = iter(()) if len(noncached_doc_batch) > 0: n_iters = 3 if self._save_io else 2 + context_length: Optional[int] = None + if isinstance(self._model, ModelWithContextLength): + context_length = self._model.context_length + prompts_iters = tee( - self._task.generate_prompts(noncached_doc_batch), n_iters + self._task.generate_prompts(noncached_doc_batch, context_length), + n_iters, ) responses_iters = tee(self._model(prompts_iters[0]), n_iters) for prompt, response, doc in zip( diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index bdd773bc..a1c848cb 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -52,9 +52,13 @@ def __init__( self._shard_mapper = shard_mapper self._shard_reducer = shard_reducer - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[Any]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. + ontext_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, + if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. RETURNS (Iterable[Any]): Iterable with one prompt per doc. """ environment = jinja2.Environment() @@ -73,13 +77,17 @@ def render_template(shard: Doc) -> str: ) for doc in self._preprocess_docs_for_prompt(docs): - # todo make prompt data a doc-dependent function (worry about EL after this works for available tasks) - prompt = _template.render( - text=doc.text, - prompt_examples=self._prompt_examples, - **self._get_prompt_data(doc), + # If no context length provided (e. g. because models don't provide it): don't shard. + shards = ( + self._shard_mapper(doc, context_length, render_template) + if context_length is not None + else [doc] ) - yield prompt + prompts = [ + render_template(shard) + for shard in (shards if isinstance(shards, list) else [shards]) + ] + yield prompts if len(prompts) > 1 else prompts[0] def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data diff --git a/spacy_llm/tasks/noop.py b/spacy_llm/tasks/noop.py index 044ca5ac..dc31ce40 100644 --- a/spacy_llm/tasks/noop.py +++ b/spacy_llm/tasks/noop.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, Optional from spacy.tokens import Doc @@ -13,7 +13,9 @@ def make_noop_task(): class NoopTask: - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[str]: for _ in docs: yield _NOOP_PROMPT diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index 7709b970..78d60d1e 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Optional +from typing import Callable, Iterable, List, Optional, Union from spacy.tokens import Doc @@ -21,7 +21,7 @@ def count_tokens_by_spaces(value: str) -> int: @registry.llm_misc("spacy.ShardMapper.v1") def make_shard_mapper( - n_token_estimator: NTokenEstimator = make_n_token_estimator(), + n_token_estimator: Optional[NTokenEstimator] = None, buffer_frac: float = 1.1, ) -> ShardMapper: """Generates Callable mapping doc to doc shards fitting within context length. @@ -32,14 +32,15 @@ def make_shard_mapper( # splitting we can't rely one...maybe checking for sentences and/or as optional arg? RETURNS (ShardMapper): Callable mapping doc to doc shards fitting within context length. """ + n_tok_est: NTokenEstimator = n_token_estimator or make_n_token_estimator() def map_doc_to_shards( doc: Doc, context_length: int, render_template: Callable[[Doc], str] - ): + ) -> Union[Iterable[Doc], Doc]: prompt = render_template(doc) # If prompt with complete doc too long: split in shards. - if n_token_estimator(prompt) * buffer_frac > context_length: + if n_tok_est(prompt) * buffer_frac > context_length: shards: List[Doc] = [] # Prompt length unfortunately can't be exacted computed prior to rendering the prompt, as external # information not present in the doc (e. g. entity description for EL prompts) may be injected. @@ -62,7 +63,7 @@ def map_doc_to_shards( end_idx = start_idx + int(len(remaining_doc) * fraction) shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) fits_in_context = ( - n_token_estimator(render_template(shard)) * buffer_frac + n_tok_est(render_template(shard)) * buffer_frac <= context_length ) fraction /= 2 @@ -81,6 +82,6 @@ def map_doc_to_shards( return shards else: - return [doc] + return doc return map_doc_to_shards diff --git a/spacy_llm/tests/models/test_rest.py b/spacy_llm/tests/models/test_rest.py index 035cba5c..08044e3b 100644 --- a/spacy_llm/tests/models/test_rest.py +++ b/spacy_llm/tests/models/test_rest.py @@ -1,7 +1,7 @@ # mypy: ignore-errors import copy import re -from typing import Iterable +from typing import Iterable, Optional import pytest import spacy @@ -22,7 +22,9 @@ class _CountTask: _PROMPT_TEMPLATE = "Count the number of characters in this string: '{text}'." - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[str]: for doc in docs: yield _CountTask._PROMPT_TEMPLATE.format(text=doc.text) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 6ef297b1..72fb9bcc 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -2,7 +2,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, Optional import pytest import spacy @@ -156,7 +156,9 @@ class NoopTask_Incorrect: def __init__(self): pass - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[int]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[int]: return [0] * len(list(docs)) def parse_responses( diff --git a/spacy_llm/tests/test_cache.py b/spacy_llm/tests/test_cache.py index 1522c82c..0cc703d5 100644 --- a/spacy_llm/tests/test_cache.py +++ b/spacy_llm/tests/test_cache.py @@ -3,7 +3,7 @@ import re import time from pathlib import Path -from typing import Dict, Iterable +from typing import Dict, Iterable, Optional import pytest import spacy @@ -211,7 +211,9 @@ def test_prompt_template_handling(): @registry.llm_tasks("NoPromptTemplate.v1") class NoopTask_NoPromptTemplate: - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[str]: return [""] * len(list(docs)) def parse_responses( diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index e75a3d18..9de8093b 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -26,7 +26,7 @@ # Requires doc, context length and callable for rendering template from doc shard text. [Doc, int, Callable[[Doc], str]], # Returns each shard as a doc. - Iterable[Doc], + Union[Iterable[Doc], Doc], ] ShardReducer = Callable[[Iterable[Doc]], Doc] @@ -93,9 +93,13 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: @runtime_checkable class LLMTask(Protocol): - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]: + def generate_prompts( + self, docs: Iterable[Doc], context_length: Optional[int] = None + ) -> Iterable[_PromptType]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. + context_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, + if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc. """ @@ -191,6 +195,15 @@ def __getitem__(self, doc: Doc) -> Optional[Doc]: """ +@runtime_checkable +class ModelWithContextLength(Protocol): + @property + def context_length(self) -> int: + """Provides context length for the corresponding model. + RETURNS (int): Context length for the corresponding model. + """ + + def _do_args_match(out_arg: Iterable, in_arg: Iterable) -> bool: """Compares argument type of Iterables for compatibility. in_arg (Iterable): Input argument. From 086dec949ae451012f2afc5115b9587e3f5e8dd6 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 20 Oct 2023 14:19:16 +0200 Subject: [PATCH 06/51] Update response parsing and component to support sharding (WIP). --- spacy_llm/models/hf/base.py | 6 +- spacy_llm/models/hf/dolly.py | 14 +- spacy_llm/models/hf/falcon.py | 13 +- spacy_llm/models/hf/llama2.py | 13 +- spacy_llm/models/hf/mistral.py | 55 ++++---- spacy_llm/models/hf/openllama.py | 44 ++++--- spacy_llm/models/hf/stablelm.py | 61 +++++---- spacy_llm/models/langchain/model.py | 40 +++--- spacy_llm/models/rest/anthropic/model.py | 104 ++++++++------- spacy_llm/models/rest/anthropic/registry.py | 14 +- spacy_llm/models/rest/azure/model.py | 132 ++++++++++--------- spacy_llm/models/rest/azure/registry.py | 2 +- spacy_llm/models/rest/base.py | 6 +- spacy_llm/models/rest/cohere/model.py | 132 ++++++++++--------- spacy_llm/models/rest/cohere/registry.py | 2 +- spacy_llm/models/rest/noop/model.py | 2 +- spacy_llm/models/rest/noop/registry.py | 2 +- spacy_llm/models/rest/openai/model.py | 121 +++++++++-------- spacy_llm/models/rest/openai/registry.py | 44 +++---- spacy_llm/models/rest/palm/model.py | 123 +++++++++-------- spacy_llm/models/rest/palm/registry.py | 2 +- spacy_llm/pipeline/llm.py | 28 +++- spacy_llm/tasks/builtin_task.py | 18 ++- spacy_llm/tasks/lemma/parser.py | 30 +++-- spacy_llm/tasks/lemma/task.py | 31 +++-- spacy_llm/tasks/noop.py | 14 +- spacy_llm/tasks/rel/parser.py | 46 ++++--- spacy_llm/tasks/rel/task.py | 14 +- spacy_llm/tasks/sentiment/parser.py | 30 +++-- spacy_llm/tasks/sentiment/task.py | 22 ++-- spacy_llm/tasks/span/parser.py | 75 ++++++----- spacy_llm/tasks/span/task.py | 14 +- spacy_llm/tasks/summarization/parser.py | 22 ++-- spacy_llm/tasks/summarization/task.py | 15 ++- spacy_llm/tasks/textcat/parser.py | 77 ++++++----- spacy_llm/tasks/textcat/task.py | 15 ++- spacy_llm/tasks/util/sharding.py | 2 +- spacy_llm/tests/conftest.py | 4 +- spacy_llm/tests/models/test_rest.py | 8 +- spacy_llm/tests/pipeline/test_llm.py | 7 +- spacy_llm/tests/test_cache.py | 9 +- spacy_llm/tests/test_combinations.py | 2 +- spacy_llm/ty.py | 60 ++++++--- usage_examples/tests/test_readme_examples.py | 10 +- 44 files changed, 847 insertions(+), 638 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 72ce7446..8eaf39ee 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -39,10 +39,10 @@ def __init__( self._model = self.init_model() @abc.abstractmethod - def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]: + def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: """Executes prompts on specified API. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): API responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute per doc. + RETURNS (Iterable[Iterable[Any]]): API responses per doc. """ def _check_model(self) -> None: diff --git a/spacy_llm/models/hf/dolly.py b/spacy_llm/models/hf/dolly.py index a1d658fa..fc15dc67 100644 --- a/spacy_llm/models/hf/dolly.py +++ b/spacy_llm/models/hf/dolly.py @@ -18,14 +18,18 @@ def init_model(self) -> Any: model=self._name, return_full_text=False, **self._config_init ) - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] """Queries Dolly HF model. pipeline (transformers.pipeline): Transformers pipeline to query. - prompts (Iterable[str]): Prompts to query Dolly model with. - RETURNS (Iterable[str]): Prompt responses. + prompts (Iterable[Iterable[str]]): Prompts per doc to query Dolly model with. + RETURNS (Iterable[Iterable[str]]): Prompt responses per doc. """ return [ - self._model(pr, **self._config_run)[0]["generated_text"] for pr in prompts + [ + self._model(pr, **self._config_run)[0]["generated_text"] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @property @@ -56,7 +60,7 @@ def dolly_hf( name: Dolly.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Dolly instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Dolly model. Has to be one of Dolly.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py index 64c33ea5..5e15828a 100644 --- a/spacy_llm/models/hf/falcon.py +++ b/spacy_llm/models/hf/falcon.py @@ -46,10 +46,15 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "tiiuae" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] return [ - self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"] - for pr in prompts + [ + self._model(pr, generation_config=self._hf_config_run)[0][ + "generated_text" + ] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @staticmethod @@ -73,7 +78,7 @@ def falcon_hf( name: Falcon.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Falcon instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/hf/llama2.py b/spacy_llm/models/hf/llama2.py index 8fc8eef6..ab5e1063 100644 --- a/spacy_llm/models/hf/llama2.py +++ b/spacy_llm/models/hf/llama2.py @@ -39,10 +39,15 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "meta-llama" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] return [ - self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"] - for pr in prompts + [ + self._model(pr, generation_config=self._hf_config_run)[0][ + "generated_text" + ] + for pr in prompts_for_doc + ] + for prompts_for_doc in prompts ] @staticmethod @@ -59,7 +64,7 @@ def llama2_hf( name: Llama2.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Llama 2 instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Llama 2 model. Has to be one of Llama2.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index cd8bbd9b..082b915e 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from confection import SimpleFrozenDict @@ -48,31 +48,40 @@ def init_model(self) -> Any: def hf_account(self) -> str: return "mistralai" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) assert hasattr(self._model, "generate") assert hasattr(self._tokenizer, "batch_decode") - prompts = list(prompts) - - tokenized_input_ids = [ - self._tokenizer( - prompt if not self._is_instruct else f"[INST] {prompt} [/INST]", - return_tensors="pt", - ).input_ids - for prompt in prompts - ] - if self._device: - tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids] - - return [ - self._tokenizer.decode( - self._model.generate( - input_ids=tok_ii, generation_config=self._hf_config_run - )[:, tok_ii.shape[1] :][0], - skip_special_tokens=True, + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + prompts_for_doc = list(prompts_for_doc) + + tokenized_input_ids = [ + self._tokenizer( + prompt if not self._is_instruct else f"[INST] {prompt} [/INST]", + return_tensors="pt", + ).input_ids + for prompt in prompts_for_doc + ] + if self._device: + tokenized_input_ids = [ + tp.to(self._device) for tp in tokenized_input_ids + ] + + responses.append( + [ + self._tokenizer.decode( + self._model.generate( + input_ids=tok_ii, generation_config=self._hf_config_run + )[:, tok_ii.shape[1] :][0], + skip_special_tokens=True, + ) + for tok_ii in tokenized_input_ids + ] ) - for tok_ii in tokenized_input_ids - ] + + return responses @staticmethod def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -92,7 +101,7 @@ def mistral_hf( name: Mistral.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates Mistral instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index f7e1ff83..3f43813e 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from confection import SimpleFrozenDict @@ -43,23 +43,33 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": return model - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) - tokenized_input_ids = [ - self._tokenizer(prompt, return_tensors="pt").input_ids for prompt in prompts - ] - if self._device: - tokenized_input_ids = [tii.to(self._device) for tii in tokenized_input_ids] - - assert hasattr(self._model, "generate") - return [ - self._tokenizer.decode( - self._model.generate(input_ids=tii, **self._config_run)[ - :, tii.shape[1] : - ][0], + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + tokenized_input_ids = [ + self._tokenizer(prompt, return_tensors="pt").input_ids + for prompt in prompts_for_doc + ] + if self._device: + tokenized_input_ids = [ + tii.to(self._device) for tii in tokenized_input_ids + ] + + assert hasattr(self._model, "generate") + responses.append( + [ + self._tokenizer.decode( + self._model.generate(input_ids=tii, **self._config_run)[ + :, tii.shape[1] : + ][0], + ) + for tii in tokenized_input_ids + ] ) - for tii in tokenized_input_ids - ] + + return responses @property def hf_account(self) -> str: @@ -86,7 +96,7 @@ def openllama_hf( name: OpenLLaMA.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates OpenLLaMA instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the OpenLLaMA model. Has to be one of OpenLLaMA.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py index 028e81e7..c352dc00 100644 --- a/spacy_llm/models/hf/stablelm.py +++ b/spacy_llm/models/hf/stablelm.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple from confection import SimpleFrozenDict @@ -66,33 +66,42 @@ def init_model(self) -> "transformers.AutoModelForCausalLM": def hf_account(self) -> str: return "stabilityai" - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # type: ignore[override] + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] assert callable(self._tokenizer) - tokenized_input_ids = [ - self._tokenizer(prompt, return_tensors="pt").input_ids - for prompt in ( - # Add prompt formatting for tuned model. - prompts - if not self._is_tuned - else [ - f"{StableLM._SYSTEM_PROMPT}<|USER|>{prompt}<|ASSISTANT|>" - for prompt in prompts + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + tokenized_input_ids = [ + self._tokenizer(prompt, return_tensors="pt").input_ids + for prompt in ( + # Add prompt formatting for tuned model. + prompts_for_doc + if not self._is_tuned + else [ + f"{StableLM._SYSTEM_PROMPT}<|USER|>{prompt}<|ASSISTANT|>" + for prompt in prompts_for_doc + ] + ) + ] + if self._device: + tokenized_input_ids = [ + tp.to(self._device) for tp in tokenized_input_ids + ] + + assert hasattr(self._model, "generate") + responses.append( + [ + self._tokenizer.decode( + self._model.generate(input_ids=tii, **self._config_run)[ + :, tii.shape[1] : + ][0], + skip_special_tokens=True, + ) + for tii in tokenized_input_ids ] ) - ] - if self._device: - tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids] - - assert hasattr(self._model, "generate") - return [ - self._tokenizer.decode( - self._model.generate(input_ids=tii, **self._config_run)[ - :, tii.shape[1] : - ][0], - skip_special_tokens=True, - ) - for tii in tokenized_input_ids - ] + + return responses @staticmethod def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -117,7 +126,7 @@ def stablelm_hf( name: StableLM.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Generates StableLM instance that can execute a set of prompts and return the raw responses. name (Literal): Name of the StableLM model. Has to be one of StableLM.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 6c491fd3..ff162b36 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -19,15 +19,15 @@ def __init__( api: str, config: Dict[Any, Any], query: Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[Any]], - Iterable[Any], + ["langchain.base_language.BaseLanguageModel", Iterable[Iterable[Any]]], + Iterable[Iterable[Any]], ], ): """Initializes model instance for integration APIs. name (str): Name of LangChain model to instantiate. api (str): Name of class/API. config (Dict[Any, Any]): Config passed on to LangChain model. - query (Callable[[Any, Iterable[_PromptType]], Iterable[_ResponseType]]): Callable executing LLM prompts when + query (Callable[[Any, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing LLM prompts when supplied with the `integration` object. """ self._langchain_model = LangChain.get_type_to_cls_dict()[api]( @@ -45,23 +45,25 @@ def get_type_to_cls_dict() -> Dict[ """ return langchain.llms.type_to_cls_dict - def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]: + def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: """Executes prompts on specified API. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): API responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute. + RETURNS (Iterable[Iterable[Any]]): API responses. """ return self.query(self._langchain_model, prompts) @staticmethod def query_langchain( - model: "langchain.base_language.BaseLanguageModel", prompts: Iterable[Any] - ) -> Iterable[Any]: + model: "langchain.base_language.BaseLanguageModel", + prompts: Iterable[Iterable[Any]], + ) -> Iterable[Iterable[Any]]: """Query LangChain model naively. model (langchain.base_language.BaseLanguageModel): LangChain model. - prompts (Iterable[Any]): Prompts to execute. - RETURNS (Iterable[Any]): LLM responses. + prompts (Iterable[Iterable[Any]]): Prompts to execute. + RETURNS (Iterable[Iterable[Any]]): LLM responses. """ - return [model(pr) for pr in prompts] + assert callable(model) + return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts] @staticmethod def _check_installation() -> None: @@ -78,13 +80,16 @@ def langchain_model( name: str, query: Optional[ Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[str]], - Iterable[str], + [ + "langchain.base_language.BaseLanguageModel", + Iterable[Iterable[str]], + ], + Iterable[Iterable[str]], ] ] = None, config: Dict[Any, Any] = SimpleFrozenDict(), langchain_class_id: str = class_id, - ) -> Optional[Callable[[Iterable[Any]], Iterable[Any]]]: + ) -> Optional[Callable[[Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]]: try: return LangChain( name=name, @@ -124,11 +129,12 @@ def register_models() -> None: @registry.llm_queries("spacy.CallLangChain.v1") def query_langchain() -> ( Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[Any]], Iterable[Any] + ["langchain.base_language.BaseLanguageModel", Iterable[Iterable[Any]]], + Iterable[Iterable[Any]], ] ): """Returns query Callable for LangChain. - RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on - the specified LangChain model. + RETURNS (Callable[["langchain.base_language.BaseLanguageModel", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): + Callable executing simple prompts on the specified LangChain model. """ return LangChain.query_langchain diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index efc7106b..06013cb4 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -50,63 +50,71 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "model": self._name, "anthropic-version": self._config.get("anthropic-version", "2023-06-01"), "Content-Type": "application/json", } + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config, "model": self._name}, + timeout=self._max_request_time, + ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + error = res_content.get("error", {}) + error_msg = f"Request to Anthropic API failed: {error}" + if error["type"] == "not_found_error": + error_msg += f". Ensure that the selected model ({self._name}) is supported by the API." + raise ValueError(error_msg) from ex + response = r.json() + + # c.f. https://console.anthropic.com/docs/api/errors + if "error" in response: + if self._strict: + raise ValueError(f"API call failed: {response}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + return response + + # Anthropic API currently doesn't accept batch prompts, so we're making + # a request for each iteration. This approach can be prone to rate limit + # errors. In practice, you can adjust _max_request_time so that the + # timeout is larger. + responses = [ + _request( + {"prompt": f"{SystemPrompt.HUMAN} {prompt}{SystemPrompt.ASST}"} + ) + for prompt in prompts_for_doc + ] - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self._endpoint, - headers=headers, - json={**json_data, **self._config, "model": self._name}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - error = res_content.get("error", {}) - error_msg = f"Request to Anthropic API failed: {error}" - if error["type"] == "not_found_error": - error_msg += f". Ensure that the selected model ({self._name}) is supported by the API." - raise ValueError(error_msg) from ex - response = r.json() - - # c.f. https://console.anthropic.com/docs/api/errors - if "error" in response: - if self._strict: - raise ValueError(f"API call failed: {response}.") + for response in responses: + if "completion" in response: + api_responses.append(response["completion"]) else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} - return response - - # Anthropic API currently doesn't accept batch prompts, so we're making - # a request for each iteration. This approach can be prone to rate limit - # errors. In practice, you can adjust _max_request_time so that the - # timeout is larger. - responses = [ - _request({"prompt": f"{SystemPrompt.HUMAN} {prompt}{SystemPrompt.ASST}"}) - for prompt in prompts - ] - - for response in responses: - if "completion" in response: - api_responses.append(response["completion"]) - else: - api_responses.append(srsly.json_dumps(response)) + api_responses.append(srsly.json_dumps(response)) + + assert len(api_responses) == len(prompts_for_doc) + all_api_responses.append(api_responses) - assert len(api_responses) == len(prompts) - return api_responses + return all_api_responses @staticmethod def _get_context_lengths() -> Dict[str, int]: diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py index 504da15a..89c9157a 100644 --- a/spacy_llm/models/rest/anthropic/registry.py +++ b/spacy_llm/models/rest/anthropic/registry.py @@ -15,7 +15,7 @@ def anthropic_claude_2( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-2' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-2", "claude-2-100k"]): Model to use. @@ -49,7 +49,7 @@ def anthropic_claude_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1", "claude-1-100k"]): Model to use. @@ -85,7 +85,7 @@ def anthropic_claude_instant_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-instant-1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-instant-1", "claude-instant-1-100k"]): Model to use. @@ -121,7 +121,7 @@ def anthropic_claude_instant_1_1( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-instant-1.1' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-instant-1.1", "claude-instant-1.1-100k"]): Model to use. @@ -155,7 +155,7 @@ def anthropic_claude_1_0( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.0' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.0"]): Model to use. @@ -189,7 +189,7 @@ def anthropic_claude_1_2( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.2' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.2"]): Model to use. @@ -223,7 +223,7 @@ def anthropic_claude_1_3( max_tries: int = Anthropic.DEFAULT_MAX_TRIES, interval: float = Anthropic.DEFAULT_INTERVAL, max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Anthropic instance for 'claude-1.3' model using REST to prompt API. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. name (Literal["claude-1.3", "claude-1.3-100k"]): Model variant to use. diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index 96617f68..c7a5860d 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -76,77 +76,87 @@ def _verify_auth(self) -> None: except ValueError as err: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self.endpoint, - headers=headers, - json={**json_data, **self._config}, - timeout=self._max_request_time, - params={"api-version": self._api_version}, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - raise ValueError( - f"Request to Azure OpenAI API failed: " - f"{res_content.get('error', {}).get('message', str(res_content))}" - ) from ex - responses = r.json() - - if "error" in responses: - if self._strict: - raise ValueError(f"API call failed: {responses}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(responses)] * len(prompts)} - - return responses - - # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests. - # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai - - if self._model_type == ModelType.CHAT: - # Note: this is yet (2023-10-05) untested, as Azure doesn't seem to allow the deployment of a chat model - # yet. - for prompt in prompts: - responses = _request( - {"messages": [{"role": "user", "content": prompt}]} + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self.endpoint, + headers=headers, + json={**json_data, **self._config}, + timeout=self._max_request_time, + params={"api-version": self._api_version}, ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + raise ValueError( + f"Request to Azure OpenAI API failed: " + f"{res_content.get('error', {}).get('message', str(res_content))}" + ) from ex + responses = r.json() + if "error" in responses: - return responses["error"] - - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append( - response.get("message", {}).get( - "content", srsly.json_dumps(response) + if self._strict: + raise ValueError(f"API call failed: {responses}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(responses)] + * len(prompts_for_doc) + } + + return responses + + # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests. + # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai + + if self._model_type == ModelType.CHAT: + # Note: this is yet (2023-10-05) untested, as Azure doesn't seem to allow the deployment of a chat model + # yet. + for prompt in prompts_for_doc: + responses = _request( + {"messages": [{"role": "user", "content": prompt}]} + ) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("message", {}).get( + "content", srsly.json_dumps(response) + ) ) - ) - elif self._model_type == ModelType.COMPLETION: - for prompt in prompts: - responses = _request({"prompt": prompt}) - if "error" in responses: - return responses["error"] + elif self._model_type == ModelType.COMPLETION: + for prompt in prompts_for_doc: + responses = _request({"prompt": prompt}) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("text", srsly.json_dumps(response)) + ) - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append(response.get("text", srsly.json_dumps(response))) + all_api_responses.append(api_responses) - return api_responses + return all_api_responses @staticmethod def _get_context_lengths() -> Dict[str, int]: diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index 3493e1b4..d69117f8 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -20,7 +20,7 @@ def azure_openai( interval: float = AzureOpenAI.DEFAULT_INTERVAL, max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME, api_version: str = "2023-05-15", -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. Docs on OpenAI models supported by Azure: diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index e89d3928..16f49bd0 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -72,10 +72,10 @@ def _check_model(self) -> None: ) @abc.abstractmethod - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: """Executes prompts on specified API. - prompts (Iterable[str]): Prompts to execute. - RETURNS (Iterable[str]): API responses. + prompts (Iterable[Iterable[str]]): Prompts to execute. + RETURNS (Iterable[Iterable[str]]): API responses. """ @classmethod diff --git a/spacy_llm/models/rest/cohere/model.py b/spacy_llm/models/rest/cohere/model.py index 2973945c..202d55c4 100644 --- a/spacy_llm/models/rest/cohere/model.py +++ b/spacy_llm/models/rest/cohere/model.py @@ -39,78 +39,86 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", "Accept": "application/json", } + all_api_responses: List[List[str]] = [] - api_responses: List[str] = [] - prompts = list(prompts) + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self._endpoint, - headers=headers, - json={**json_data, **self._config}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - error_message = res_content.get("message", {}) - # Catch 'blocked output' and 'blocked input' errors from Cohere - # This usually happens when it detects violations in their Usage guidelines. - # Unfortunately Cohere returns this as an HTTPError, so it cannot be caught in the response. - if "blocked" in error_message: - # Only raise an error when strict. If strict is False, do - # nothing and parse the response as usual. - if self._strict: + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config}, + timeout=self._max_request_time, + ) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + error_message = res_content.get("message", {}) + # Catch 'blocked output' and 'blocked input' errors from Cohere + # This usually happens when it detects violations in their Usage guidelines. + # Unfortunately Cohere returns this as an HTTPError, so it cannot be caught in the response. + if "blocked" in error_message: + # Only raise an error when strict. If strict is False, do + # nothing and parse the response as usual. + if self._strict: + raise ValueError( + f"Cohere API returned a blocking error. {error_message}. " + "If you wish to ignore and continue, you can pass 'False' to the 'strict' argument of " + "this model. However, note that this will affect how spacy-llm parses the response." + ) from ex + else: + # Catching other types of HTTPErrors (e.g., "429: too many requests") raise ValueError( - f"Cohere API returned a blocking error. {error_message}. " - "If you wish to ignore and continue, you can pass 'False' to the 'strict' argument of this model. " - "However, note that this will affect how spacy-llm parses the response." + f"Request to Cohere API failed: {error_message}" ) from ex - else: - # Catching other types of HTTPErrors (e.g., "429: too many requests") - raise ValueError( - f"Request to Cohere API failed: {error_message}" - ) from ex - response = r.json() - - # Cohere returns a 'message' key when there is an error - # in the response. - if "message" in response: - if self._strict: - raise ValueError(f"API call failed: {response}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} - return response - - # Cohere API currently doesn't accept batch prompts, so we're making - # a request for each iteration. This approach can be prone to rate limit - # errors. In practice, you can adjust _max_request_time so that the - # timeout is larger. - responses = [_request({"prompt": prompt}) for prompt in prompts] - for response in responses: - if "generations" in response: - for result in response["generations"]: - if "text" in result: - # Although you can set the number of completions in Cohere - # to be greater than 1, we only need to return a single value. - # In this case, we will just return the very first output. - api_responses.append(result["text"]) - break + response = r.json() + + # Cohere returns a 'message' key when there is an error + # in the response. + if "message" in response: + if self._strict: + raise ValueError(f"API call failed: {response}.") else: - api_responses.append(srsly.json_dumps(response)) - else: - api_responses.append(srsly.json_dumps(response)) - return api_responses + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + + return response + + # Cohere API currently doesn't accept batch prompts, so we're making + # a request for each iteration. This approach can be prone to rate limit + # errors. In practice, you can adjust _max_request_time so that the + # timeout is larger. + responses = [_request({"prompt": prompt}) for prompt in prompts_for_doc] + for response in responses: + if "generations" in response: + for result in response["generations"]: + if "text" in result: + # Although you can set the number of completions in Cohere + # to be greater than 1, we only need to return a single value. + # In this case, we will just return the very first output. + api_responses.append(result["text"]) + break + else: + api_responses.append(srsly.json_dumps(response)) + else: + api_responses.append(srsly.json_dumps(response)) + + all_api_responses.append(api_responses) + + return all_api_responses @classmethod def get_model_names(cls) -> Tuple[str, ...]: diff --git a/spacy_llm/models/rest/cohere/registry.py b/spacy_llm/models/rest/cohere/registry.py index 3279bf4f..06adeedc 100644 --- a/spacy_llm/models/rest/cohere/registry.py +++ b/spacy_llm/models/rest/cohere/registry.py @@ -17,7 +17,7 @@ def cohere_command( max_tries: int = Cohere.DEFAULT_MAX_TRIES, interval: float = Cohere.DEFAULT_INTERVAL, max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Cohere instance for 'command' model using REST to prompt API. name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index cdb46170..dfe9be25 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -30,7 +30,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: pass - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # Assume time penalty for API calls. time.sleep(NoOpModel._CALL_TIMEOUT) return [_NOOP_RESPONSE] * len(list(prompts)) diff --git a/spacy_llm/models/rest/noop/registry.py b/spacy_llm/models/rest/noop/registry.py index bd393776..4050906b 100644 --- a/spacy_llm/models/rest/noop/registry.py +++ b/spacy_llm/models/rest/noop/registry.py @@ -5,7 +5,7 @@ @registry.llm_models("spacy.NoOp.v1") -def noop() -> Callable[[Iterable[str]], Iterable[str]]: +def noop() -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns NoOpModel. RETURNS (Callable[[Iterable[str]], Iterable[str]]]): NoOp model instance for test purposes. """ diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index 032ca462..b8bbdae3 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -74,72 +74,81 @@ def _verify_auth(self) -> None: f"The specified model '{self._name}' is not available. Choices are: {sorted(set(models))}" ) - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { **self._credentials, "Content-Type": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) - - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=self._endpoint, - headers=headers, - json={**json_data, **self._config, "model": self._name}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - raise ValueError( - f"Request to OpenAI API failed: {res_content.get('error', {}).get('message', str(res_content))}" - ) from ex - responses = r.json() - - if "error" in responses: - if self._strict: - raise ValueError(f"API call failed: {responses}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(responses)] * len(prompts)} - - return responses - - if self._endpoint == Endpoints.CHAT: - # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests. - for prompt in prompts: - responses = _request( - {"messages": [{"role": "user", "content": prompt}]} + all_api_responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=self._endpoint, + headers=headers, + json={**json_data, **self._config, "model": self._name}, + timeout=self._max_request_time, ) - if "error" in responses: - return responses["error"] + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + raise ValueError( + f"Request to OpenAI API failed: {res_content.get('error', {}).get('message', str(res_content))}" + ) from ex + responses = r.json() - # Process responses. - assert len(responses["choices"]) == 1 - response = responses["choices"][0] - api_responses.append( - response.get("message", {}).get( - "content", srsly.json_dumps(response) + if "error" in responses: + if self._strict: + raise ValueError(f"API call failed: {responses}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(responses)] + * len(prompts_for_doc) + } + + return responses + + if self._endpoint == Endpoints.CHAT: + # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual + # requests. + for prompt in prompts_for_doc: + responses = _request( + {"messages": [{"role": "user", "content": prompt}]} ) - ) + if "error" in responses: + return responses["error"] + + # Process responses. + assert len(responses["choices"]) == 1 + response = responses["choices"][0] + api_responses.append( + response.get("message", {}).get( + "content", srsly.json_dumps(response) + ) + ) + + elif self._endpoint == Endpoints.NON_CHAT: + responses = _request({"prompt": prompts_for_doc}) + if "error" in responses: + return responses["error"] + assert len(responses["choices"]) == len(prompts_for_doc) - elif self._endpoint == Endpoints.NON_CHAT: - responses = _request({"prompt": prompts}) - if "error" in responses: - return responses["error"] - assert len(responses["choices"]) == len(prompts) + for response in responses["choices"]: + if "text" in response: + api_responses.append(response["text"]) + else: + api_responses.append(srsly.json_dumps(response)) - for response in responses["choices"]: - if "text" in response: - api_responses.append(response["text"]) - else: - api_responses.append(srsly.json_dumps(response)) + all_api_responses.append(api_responses) - return api_responses + return all_api_responses @staticmethod def _get_context_lengths() -> Dict[str, int]: diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py index 1a4c4fd7..8a43e95c 100644 --- a/spacy_llm/models/rest/openai/registry.py +++ b/spacy_llm/models/rest/openai/registry.py @@ -31,7 +31,7 @@ def openai_gpt_4_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -61,7 +61,7 @@ def openai_gpt_4( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -95,7 +95,7 @@ def openai_gpt_3_5_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -133,7 +133,7 @@ def openai_gpt_3_5( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -169,7 +169,7 @@ def openai_text_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -199,7 +199,7 @@ def openai_text_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -229,7 +229,7 @@ def openai_code_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'code-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -257,7 +257,7 @@ def openai_code_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'code-davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -287,7 +287,7 @@ def openai_text_curie_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -315,7 +315,7 @@ def openai_text_curie( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -345,7 +345,7 @@ def openai_text_babbage_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -373,7 +373,7 @@ def openai_text_babbage( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -403,7 +403,7 @@ def openai_text_ada_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -431,7 +431,7 @@ def openai_text_ada( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'text-ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -461,7 +461,7 @@ def openai_davinci_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -489,7 +489,7 @@ def openai_davinci( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'davinci' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -519,7 +519,7 @@ def openai_curie_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -547,7 +547,7 @@ def openai_curie( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'curie' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -577,7 +577,7 @@ def openai_babbage_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -605,7 +605,7 @@ def openai_babbage( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'babbage' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -635,7 +635,7 @@ def openai_ada_v2( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. @@ -663,7 +663,7 @@ def openai_ada( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns OpenAI instance for 'ada' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. diff --git a/spacy_llm/models/rest/palm/model.py b/spacy_llm/models/rest/palm/model.py index d67ec3a5..547d972d 100644 --- a/spacy_llm/models/rest/palm/model.py +++ b/spacy_llm/models/rest/palm/model.py @@ -41,72 +41,83 @@ def _verify_auth(self) -> None: else: raise err - def __call__(self, prompts: Iterable[str]) -> Iterable[str]: + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: headers = { "Content-Type": "application/json", "Accept": "application/json", } - api_responses: List[str] = [] - prompts = list(prompts) url = self._endpoint.format( model=self._name, api_key=self._credentials["api_key"] ) + all_api_responses: List[List[str]] = [] - def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: - r = self.retry( - call_method=requests.post, - url=url, - headers=headers, - json={**json_data, **self._config}, - timeout=self._max_request_time, - ) - try: - r.raise_for_status() - except HTTPError as ex: - res_content = srsly.json_loads(r.content.decode("utf-8")) - # Include specific error message in exception. - error_message = res_content.get("error", {}).get("message", {}) - # Catching other types of HTTPErrors (e.g., "429: too many requests") - raise ValueError(f"Request to PaLM API failed: {error_message}") from ex - response = r.json() - - # PaLM returns a 'filter' key when a message was filtered due to safety concerns. - if "filters" in response: - if self._strict: - raise ValueError(f"API call failed: {response}.") - else: - assert isinstance(prompts, Sized) - return {"error": [srsly.json_dumps(response)] * len(prompts)} - return response - - # PaLM API currently doesn't accept batch prompts, so we're making - # a request for each iteration. This approach can be prone to rate limit - # errors. In practice, you can adjust _max_request_time so that the - # timeout is larger. - uses_chat = "chat" in self._name - responses = [ - _request( - { - "prompt": {"text": prompt} - if not uses_chat - else {"messages": [{"content": prompt}]} - } - ) - for prompt in prompts - ] - for response in responses: - if "candidates" in response: - # Although you can set the number of candidates in PaLM to be greater than 1, we only need to return a - # single value. In this case, we will just return the very first output. - api_responses.append( - response["candidates"][0].get( - "content" if uses_chat else "output", srsly.json_dumps(response) - ) + for prompts_for_doc in prompts: + api_responses: List[str] = [] + prompts_for_doc = list(prompts_for_doc) + + def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: + r = self.retry( + call_method=requests.post, + url=url, + headers=headers, + json={**json_data, **self._config}, + timeout=self._max_request_time, ) - else: - api_responses.append(srsly.json_dumps(response)) + try: + r.raise_for_status() + except HTTPError as ex: + res_content = srsly.json_loads(r.content.decode("utf-8")) + # Include specific error message in exception. + error_message = res_content.get("error", {}).get("message", {}) + # Catching other types of HTTPErrors (e.g., "429: too many requests") + raise ValueError( + f"Request to PaLM API failed: {error_message}" + ) from ex + response = r.json() + + # PaLM returns a 'filter' key when a message was filtered due to safety concerns. + if "filters" in response: + if self._strict: + raise ValueError(f"API call failed: {response}.") + else: + assert isinstance(prompts_for_doc, Sized) + return { + "error": [srsly.json_dumps(response)] * len(prompts_for_doc) + } + + return response + + # PaLM API currently doesn't accept batch prompts, so we're making + # a request for each iteration. This approach can be prone to rate limit + # errors. In practice, you can adjust _max_request_time so that the + # timeout is larger. + uses_chat = "chat" in self._name + responses = [ + _request( + { + "prompt": {"text": prompt} + if not uses_chat + else {"messages": [{"content": prompt}]} + } + ) + for prompt in prompts_for_doc + ] + for response in responses: + if "candidates" in response: + # Although you can set the number of candidates in PaLM to be greater than 1, we only need to return a + # single value. In this case, we will just return the very first output. + api_responses.append( + response["candidates"][0].get( + "content" if uses_chat else "output", + srsly.json_dumps(response), + ) + ) + else: + api_responses.append(srsly.json_dumps(response)) + + all_api_responses.append(api_responses) - return api_responses + return all_api_responses @staticmethod def _get_context_lengths() -> Dict[str, int]: diff --git a/spacy_llm/models/rest/palm/registry.py b/spacy_llm/models/rest/palm/registry.py index 9a56576a..ed2b396a 100644 --- a/spacy_llm/models/rest/palm/registry.py +++ b/spacy_llm/models/rest/palm/registry.py @@ -15,7 +15,7 @@ def palm_bison( max_tries: int = PaLM.DEFAULT_MAX_TRIES, interval: float = PaLM.DEFAULT_INTERVAL, max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME, -) -> Callable[[Iterable[str]], Iterable[str]]: +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Google instance for PaLM Bison model using REST to prompt API. name (Literal["chat-bison-001", "text-bison-001"]): Model to use. config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 2b84912f..a0dd827e 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -209,19 +209,33 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: if isinstance(self._model, ModelWithContextLength): context_length = self._model.context_length + # todo obtain doc shards. should this be a separate step or be obtained from generate_prompts()? + # probably better to go with the latter, as sharding is linked to prompt generation. + # this means amending typing and return values everywhere with returned doc shards (rename + # generate_prompts()?). + # after that: + # - fix tee() handling of returned iterators (tee separately) + # - pass doc shards instead of noncached_doc_batch prompts_iters = tee( self._task.generate_prompts(noncached_doc_batch, context_length), - n_iters, + n_iters + 1, ) - responses_iters = tee(self._model(prompts_iters[0]), n_iters) - for prompt, response, doc in zip( + + responses_iters = tee( + self._model((elem[0] for elem in prompts_iters[0])), n_iters + ) + for prompts_and_shards, response, doc in zip( prompts_iters[1], responses_iters[1], noncached_doc_batch ): - logger.debug("Generated prompt for doc: %s\n%s", doc.text, prompt) + logger.debug( + "Generated prompt for doc: %s\n%s", doc.text, prompts_and_shards[0] + ) logger.debug("LLM response for doc: %s\n%s", doc.text, response) modified_docs = iter( - self._task.parse_responses(noncached_doc_batch, responses_iters[0]) + self._task.parse_responses( + (elem[1] for elem in prompts_iters[3]), responses_iters[0] + ) ) final_docs = [] @@ -240,8 +254,8 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: "llm_io", defaultdict(dict) ) llm_io = doc.user_data["llm_io"][self._name] - llm_io["prompt"] = str(next(prompts_iters[2])) - llm_io["response"] = str(next(responses_iters[2])) + llm_io["prompt"] = str(next(prompts_iters[2])[0]) + llm_io["response"] = str(next(responses_iters[2])[0]) self._cache.add(doc) final_docs.append(doc) diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index a1c848cb..de5e185b 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -54,12 +54,14 @@ def __init__( def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[Any]: + ) -> Iterable[Tuple[Iterable[Any], Iterable[Doc]]]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. ontext_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. - RETURNS (Iterable[Any]): Iterable with one prompt per doc. + RETURNS (Iterable[Tuple[Iterable[Any], Iterable[Doc]]]): Iterable with one to n prompts per doc (multiple + prompts in case of multiple shards) and the corresponding shards. The relationship between shard and prompt + is 1:1. """ environment = jinja2.Environment() _template = environment.from_string(self._template) @@ -83,11 +85,7 @@ def render_template(shard: Doc) -> str: if context_length is not None else [doc] ) - prompts = [ - render_template(shard) - for shard in (shards if isinstance(shards, list) else [shards]) - ] - yield prompts if len(prompts) > 1 else prompts[0] + yield [render_template(shard) for shard in shards], shards def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data @@ -106,12 +104,12 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: @abc.abstractmethod def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[Any] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[Any]] ) -> Iterable[Doc]: """ Parses LLM responses. - docs (Iterable[Doc]): Docs to map responses into. - responses ([Iterable[Any]]): LLM responses. + shards (Iterable[Iterable[Doc]]): Doc shards to map responses into. + responses ([Iterable[Iterable[Any]]]): LLM responses per doc. RETURNS (Iterable[Doc]]): Updated docs. """ diff --git a/spacy_llm/tasks/lemma/parser.py b/spacy_llm/tasks/lemma/parser.py index 9505f9d1..086a1eff 100644 --- a/spacy_llm/tasks/lemma/parser.py +++ b/spacy_llm/tasks/lemma/parser.py @@ -6,19 +6,25 @@ def parse_responses_v1( - task: LemmaTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[List[str]]]: + task: LemmaTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[List[str]]]]: """Parses LLM responses for spacy.Lemma.v1. task (LemmaTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. RETURNS (Iterable[List[str]]): Lists of 2-lists (token: lemmatized token) per doc/response. """ - for prompt_response in responses: - yield [ - [pr_part.strip() for pr_part in pr.split(":")] - for pr in prompt_response.replace("Lemmatized text:", "") - .replace("'''", "") - .strip() - .split("\n") - ] + for responses_for_doc in responses: + results_for_doc: List[List[List[str]]] = [] + for response in responses_for_doc: + results_for_doc.append( + [ + [pr_part.strip() for pr_part in pr.split(":")] + for pr in response.replace("Lemmatized text:", "") + .replace("'''", "") + .strip() + .split("\n") + ] + ) + + yield results_for_doc diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index dad38ce5..5a67c69a 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -44,21 +44,28 @@ def __init__( self._scorer = scorer def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, lemmas in zip(docs, self._parse_responses(self, docs, responses)): - tokens = [token for token in doc] - # If numbers of tokens recognized by spaCy and returned by LLM don't match, we don't attempt a partial - # match. - if len(tokens) != len(lemmas): - yield doc + for shards_for_doc, lemmas_for_doc in zip( + shards, self._parse_responses(self, shards, responses) + ): + updated_shards_for_doc: List[Doc] = [] - # Assign lemmas. - for token, lemma_info in zip(tokens, lemmas): - if len(lemma_info) > 0: - token.lemma_ = lemma_info[1] + for shard, lemmas in zip(shards_for_doc, lemmas_for_doc): + tokens = [token for token in shard] + # If numbers of tokens recognized by spaCy and returned by LLM don't match, we don't attempt a partial + # match. + if len(tokens) != len(lemmas): + updated_shards_for_doc.append(shard) - yield doc + # Assign lemmas. + for token, lemma_info in zip(tokens, lemmas): + if len(lemma_info) > 0: + token.lemma_ = lemma_info[1] + + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(updated_shards_for_doc) def initialize( self, diff --git a/spacy_llm/tasks/noop.py b/spacy_llm/tasks/noop.py index dc31ce40..9d11ee64 100644 --- a/spacy_llm/tasks/noop.py +++ b/spacy_llm/tasks/noop.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional +from typing import Iterable, Optional, Tuple from spacy.tokens import Doc @@ -15,15 +15,15 @@ def make_noop_task(): class NoopTask: def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[str]: - for _ in docs: - yield _NOOP_PROMPT + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: + for doc in docs: + yield [_NOOP_PROMPT], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - # Not doing anything - return docs + # Grab the first shard per doc + return [list(shards_for_doc)[0] for shards_for_doc in shards] @property def prompt_template(self) -> str: diff --git a/spacy_llm/tasks/rel/parser.py b/spacy_llm/tasks/rel/parser.py index 890a6aac..3f79b31e 100644 --- a/spacy_llm/tasks/rel/parser.py +++ b/spacy_llm/tasks/rel/parser.py @@ -9,28 +9,32 @@ def parse_responses_v1( - task: RELTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[RelationItem]]: + task: RELTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[RelationItem]]]: """Parses LLM responses for spacy.REL.v1. task (RELTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[RelationItem]]): List of RelationItem instances per doc/response. + docs (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[RelationItem]]]): List of RelationItem instances per shard/response. """ - for response, doc in zip(responses, docs): - relations: List[RelationItem] = [] - for line in response.strip().split("\n"): - try: - rel_item = RelationItem.parse_raw(line) - if 0 <= rel_item.dep < len(doc.ents) and 0 <= rel_item.dest < len( - doc.ents - ): - relations.append(rel_item) - except ValidationError: - msg.warn( - "Validation issue", - line, - show=task.verbose, - ) + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[RelationItem]] = [] + for response, doc in zip(responses_for_doc, shards_for_doc): + relations: List[RelationItem] = [] + for line in response.strip().split("\n"): + try: + rel_item = RelationItem.parse_raw(line) + if 0 <= rel_item.dep < len(doc.ents) and 0 <= rel_item.dest < len( + doc.ents + ): + relations.append(rel_item) + except ValidationError: + msg.warn( + "Validation issue", + line, + show=task.verbose, + ) - yield relations + results_for_doc.append(relations) + + yield results_for_doc diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index 13fa5a9e..a109599e 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -92,13 +92,19 @@ def _preannotate(doc: Union[Doc, FewshotExample]) -> str: return text def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_extension(self._field) - for doc, rel_items in zip(docs, self._parse_responses(self, docs, responses)): - doc._.rel = rel_items - yield doc + for shards_for_doc, rel_items_for_doc in zip( + shards, self._parse_responses(self, shards, responses) + ): + updated_shards_for_doc: List[Doc] = [] + for shard, rel_items in zip(shards_for_doc, rel_items_for_doc): + shard._.rel = rel_items + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(updated_shards_for_doc) def initialize( self, diff --git a/spacy_llm/tasks/sentiment/parser.py b/spacy_llm/tasks/sentiment/parser.py index 8365dab0..5e4ba679 100644 --- a/spacy_llm/tasks/sentiment/parser.py +++ b/spacy_llm/tasks/sentiment/parser.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional +from typing import Iterable, List, Optional from spacy.tokens import Doc @@ -6,16 +6,24 @@ def parse_responses_v1( - task: SentimentTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[Optional[float]]: + task: SentimentTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[Optional[float]]]: """Parses LLM responses for spacy.Sentiment.v1. task (SentimentTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[Optional[float]]): Sentiment score per doc/response. None on parsing error. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[Optional[float]]]): Sentiment score per shard/response. None on parsing error. """ - for prompt_response in responses: - try: - yield float("".join(prompt_response.replace("Answer:", "").strip().split())) - except ValueError: - yield None + for responses_for_doc in responses: + results_for_doc: List[Optional[float]] = [] + for response in responses_for_doc: + try: + results_for_doc.append( + float("".join(response.replace("Answer:", "").strip().split())) + ) + except ValueError: + results_for_doc.append(None) + + yield results_for_doc diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 9ab1633d..9c42e1da 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -68,19 +68,23 @@ def initialize( ) def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_doc_extension() + shards: List[Doc] = [] - for doc, sentiment_score in zip( - docs, self._parse_responses(self, docs, responses) - ): - try: - setattr(doc._, self._field, sentiment_score) - except ValueError: - setattr(doc._, self._field, None) + for responses_for_doc in responses: + for shard, sentiment_score in zip( + docs, self._parse_responses(self, docs, responses_for_doc) + ): + try: + setattr(shard._, self._field, sentiment_score) + except ValueError: + setattr(shard._, self._field, None) - yield doc + shards.append(shard) + + yield self._shard_reducer(shards) @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/span/parser.py b/spacy_llm/tasks/span/parser.py index 467dcbbc..fd0c389e 100644 --- a/spacy_llm/tasks/span/parser.py +++ b/spacy_llm/tasks/span/parser.py @@ -35,35 +35,40 @@ def _format_response( def parse_responses( - task: SpanTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: SpanTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[Span]]]: """Parses LLM responses for Span tasks. task (SpanTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[Span]): Parsed spans per doc/response. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[Span]]]): Parsed spans per shard/response. """ - for doc, prompt_response in zip(docs, responses): - spans = [] - for label, phrases in _format_response( - prompt_response, task._normalizer, task._label_dict - ): - # For each phrase, find the substrings in the text - # and create a Span - offsets = find_substrings( - doc.text, - phrases, - case_sensitive=task._case_sensitive_matching, - single_match=task._single_match, - ) - for start, end in offsets: - span = doc.char_span( - start, end, alignment_mode=task._alignment_mode, label=label + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[Span]] = [] + + for shard, response in zip(shards_for_doc, responses_for_doc): + spans = [] + for label, phrases in _format_response( + response, task._normalizer, task._label_dict + ): + # For each phrase, find the substrings in the text + # and create a Span + offsets = find_substrings( + shard.text, + phrases, + case_sensitive=task._case_sensitive_matching, + single_match=task._single_match, ) - if span is not None: - spans.append(span) + for start, end in offsets: + span = shard.char_span( + start, end, alignment_mode=task._alignment_mode, label=label + ) + if span is not None: + spans.append(span) - yield spans + results_for_doc.append(spans) + + yield results_for_doc def _extract_span_reasons_cot(task: SpanTask, response: str) -> List[SpanReason]: @@ -152,19 +157,23 @@ def _find_spans_cot( def parse_responses_cot( - task: SpanTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: SpanTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] +) -> Iterable[Iterable[List[Span]]]: """Since we provide entities in a numbered list, we expect the LLM to output entities in the order they occur in the text. This parse function now incrementally finds substrings in the text and tracks the last found span's start character to ensure we don't overwrite previously found spans. task (SpanTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[Span]]): Spans to assign per doc. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[List[Span]]]): Spans to assign per shard. """ - for doc, llm_response in zip(docs, responses): - span_reasons = _extract_span_reasons_cot(task, llm_response) - spans = _find_spans_cot(task, doc, span_reasons) - yield spans + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[List[Span]] = [] + + for shard, response in zip(shards_for_doc, responses_for_doc): + span_reasons = _extract_span_reasons_cot(task, response) + results_for_doc.append(_find_spans_cot(task, shard, span_reasons)) + + yield results_for_doc diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index c42ee373..5794a1db 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -100,11 +100,17 @@ def assign_spans( raise NotImplementedError() def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, spans in zip(docs, self._parse_responses(self, docs, responses)): - self.assign_spans(doc, spans) - yield doc + + for shards_for_doc, spans_for_doc in zip( + shards, self._parse_responses(self, shards, responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, spans in zip(shards_for_doc, spans_for_doc): + self.assign_spans(shard, spans) + + yield self._shard_reducer(shards_for_doc) @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/summarization/parser.py b/spacy_llm/tasks/summarization/parser.py index 0af52f6e..5f9f34cb 100644 --- a/spacy_llm/tasks/summarization/parser.py +++ b/spacy_llm/tasks/summarization/parser.py @@ -1,4 +1,4 @@ -from typing import Iterable +from typing import Iterable, List from spacy.tokens import Doc @@ -6,13 +6,19 @@ def parse_responses_v1( - task: SummarizationTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[str]: + task: SummarizationTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[str]]: """Parses LLM responses for spacy.Summarization.v1. task (SummarizationTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[str]): Summary per doc/response. + docs (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[str]]): Summary per shard/response. """ - for prompt_response in responses: - yield prompt_response.replace("'''", "").strip() + for responses_for_doc in responses: + results_for_doc: List[str] = [] + for response in responses_for_doc: + results_for_doc.append(response.replace("'''", "").strip()) + + yield responses_for_doc diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 6b730fc9..69ed68c9 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -92,11 +92,18 @@ def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return {"max_n_words": self._max_n_words} def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, summary in zip(docs, self._parse_responses(self, docs, responses)): - setattr(doc._, self._field, summary) - yield doc + shards: List[Doc] = [] + + for responses_for_doc in responses: + for shard, summary in zip( + docs, self._parse_responses(self, docs, responses_for_doc) + ): + setattr(shard._, self._field, summary) + shards.append(shard) + + yield self._shard_reducer(shards) @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/textcat/parser.py b/spacy_llm/tasks/textcat/parser.py index 24228f7f..ee8f9ddc 100644 --- a/spacy_llm/tasks/textcat/parser.py +++ b/spacy_llm/tasks/textcat/parser.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable +from typing import Dict, Iterable, List from spacy.tokens import Doc from wasabi import msg @@ -7,40 +7,47 @@ def parse_responses_v1_v2_v3( - task: TextCatTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[Dict[str, float]]: + task: TextCatTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[Dict[str, float]]]: """Parses LLM responses for spacy.TextCat.v1, v2 and v3 task (LemmaTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Dict[str, float]): TextCat scores per class. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[Iterable[Dict[str, float]]]): TextCat scores per shard and class. """ - for response in responses: - categories: Dict[str, float] - response = response.strip() - if task.use_binary: - # Binary classification: We only have one label - label: str = list(task.label_dict.values())[0] - score = 1.0 if response.upper() == "POS" else 0.0 - categories = {label: score} - else: - # Multilabel classification - categories = {label: 0.0 for label in task.label_dict.values()} - - pred_labels = response.split(",") - if task.exclusive_classes and len(pred_labels) > 1: - # Don't use anything but raise a debug message - # Don't raise an error. Let user abort if they want to. - msg.text( - f"LLM returned multiple labels for this exclusive task: {pred_labels}.", - " Will store an empty label instead.", - show=task.verbose, - ) - pred_labels = [] - - for pred in pred_labels: - if task.normalizer(pred.strip()) in task.label_dict: - category = task.label_dict[task.normalizer(pred.strip())] - categories[category] = 1.0 - - yield categories + for response_for_doc in responses: + results_for_doc: List[Dict[str, float]] = [] + + for response in response_for_doc: + categories: Dict[str, float] + response = response.strip() + if task.use_binary: + # Binary classification: We only have one label + label: str = list(task.label_dict.values())[0] + score = 1.0 if response.upper() == "POS" else 0.0 + categories = {label: score} + else: + # Multilabel classification + categories = {label: 0.0 for label in task.label_dict.values()} + + pred_labels = response.split(",") + if task.exclusive_classes and len(pred_labels) > 1: + # Don't use anything but raise a debug message + # Don't raise an error. Let user abort if they want to. + msg.text( + f"LLM returned multiple labels for this exclusive task: {pred_labels}.", + " Will store an empty label instead.", + show=task.verbose, + ) + pred_labels = [] + + for pred in pred_labels: + if task.normalizer(pred.strip()) in task.label_dict: + category = task.label_dict[task.normalizer(pred.strip())] + categories[category] = 1.0 + + results_for_doc.append(categories) + + yield results_for_doc diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index a4af9b78..074856fe 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -98,11 +98,18 @@ def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: } def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for doc, cats in zip(docs, self._parse_responses(self, docs, responses)): - doc.cats = cats - yield doc + for shards_for_doc, cats_for_doc in zip( + shards, self._parse_responses(self, shards, responses) + ): + updated_shards_for_doc: List[Doc] = [] + + for shard, cats in zip(shards_for_doc, cats_for_doc): + shard.cats = cats + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(updated_shards_for_doc) def scorer( self, diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index 78d60d1e..c153f170 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -82,6 +82,6 @@ def map_doc_to_shards( return shards else: - return doc + return [doc] return map_doc_to_shards diff --git a/spacy_llm/tests/conftest.py b/spacy_llm/tests/conftest.py index 2eda3409..7a64a074 100644 --- a/spacy_llm/tests/conftest.py +++ b/spacy_llm/tests/conftest.py @@ -42,7 +42,7 @@ def pytest_collection_modifyitems(config, items): @registry.llm_models("test.NoOpModel.v1") def noop_factory(output: str = ""): - def noop(prompts: Iterable[str]) -> Iterable[str]: - return [output] * len(list(prompts)) + def noop(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + return [[output]] * len(list(prompts)) return noop diff --git a/spacy_llm/tests/models/test_rest.py b/spacy_llm/tests/models/test_rest.py index 08044e3b..488479b7 100644 --- a/spacy_llm/tests/models/test_rest.py +++ b/spacy_llm/tests/models/test_rest.py @@ -1,7 +1,7 @@ # mypy: ignore-errors import copy import re -from typing import Iterable, Optional +from typing import Iterable, Optional, Tuple import pytest import spacy @@ -24,12 +24,12 @@ class _CountTask: def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[str]: + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: for doc in docs: - yield _CountTask._PROMPT_TEMPLATE.format(text=doc.text) + yield _CountTask._PROMPT_TEMPLATE.format(text=doc.text), [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: return docs diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 72fb9bcc..d8204823 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -2,7 +2,7 @@ import sys import warnings from pathlib import Path -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Tuple import pytest import spacy @@ -158,8 +158,9 @@ def __init__(self): def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[int]: - return [0] * len(list(docs)) + ) -> Iterable[Tuple[Iterable[int], Iterable[Doc]]]: + for doc in docs: + yield [0], doc def parse_responses( self, docs: Iterable[Doc], responses: Iterable[float] diff --git a/spacy_llm/tests/test_cache.py b/spacy_llm/tests/test_cache.py index 0cc703d5..ef41a494 100644 --- a/spacy_llm/tests/test_cache.py +++ b/spacy_llm/tests/test_cache.py @@ -3,7 +3,7 @@ import re import time from pathlib import Path -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, Tuple import pytest import spacy @@ -213,11 +213,12 @@ def test_prompt_template_handling(): class NoopTask_NoPromptTemplate: def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[str]: - return [""] * len(list(docs)) + ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: + for doc in docs: + yield [""], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: return docs diff --git a/spacy_llm/tests/test_combinations.py b/spacy_llm/tests/test_combinations.py index 0cd986c5..1075183b 100644 --- a/spacy_llm/tests/test_combinations.py +++ b/spacy_llm/tests/test_combinations.py @@ -20,7 +20,7 @@ ["spacy.NER.v1", "spacy.NER.v3", "spacy.TextCat.v1"], ids=["ner.v1", "ner.v3", "textcat"], ) -@pytest.mark.parametrize("n_process", [1]) # , 2 +@pytest.mark.parametrize("n_process", [1, 2]) def test_combinations(model: str, task: str, n_process: int): """Randomly test combinations of models and tasks.""" ops = get_current_ops() diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 9de8093b..fafa8385 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -17,7 +17,9 @@ _ResponseType = Any _ParsedResponseType = Any -PromptExecutorType = Callable[[Iterable[_PromptType]], Iterable[_ResponseType]] +PromptExecutorType = Callable[ + [Iterable[Iterable[_PromptType]]], Iterable[_ResponseType] +] ExamplesConfigType = Union[ Iterable[Dict[str, Any]], Callable[[], Iterable[Dict[str, Any]]], None ] @@ -26,7 +28,7 @@ # Requires doc, context length and callable for rendering template from doc shard text. [Doc, int, Callable[[Doc], str]], # Returns each shard as a doc. - Union[Iterable[Doc], Doc], + Iterable[Doc], ] ShardReducer = Callable[[Iterable[Doc]], Doc] @@ -95,22 +97,26 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: class LLMTask(Protocol): def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[_PromptType]: + ) -> Iterable[Tuple[Iterable[_PromptType], Iterable[Doc]]]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. context_length (int): Context length for model this task is executed with. Needed for sharding and fusing docs, if the corresponding prompts exceed the context length. If None, context length is assumed to be infinite. - RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc. + RETURNS (Iterable[Tuple[Iterable[_PromptType], Iterable[Doc]]]): Iterable with one to n prompts per doc + (multiple prompts in case of multiple shards) and the corresponding shards. The relationship between shard + and prompt is 1:1. """ def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[_ResponseType] + self, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[_ResponseType]], ) -> Iterable[Doc]: """ Parses LLM responses. - docs (Iterable[Doc]): Docs to map responses into. - responses ([Iterable[_ResponseType]]): LLM responses. - RETURNS (Iterable[Doc]]): Updated docs. + docs (Iterable[Iterable[Doc]]): Doc shards to map responses into. + responses ([Iterable[Iterable[_ResponseType]]]): LLM responses. + RETURNS (Iterable[Doc]]): Updated (and fused) docs. """ @@ -132,8 +138,11 @@ class TaskResponseParser(Protocol[TaskContraT]): """Generic protocol for parsing functions with specific tasks.""" def __call__( - self, task: TaskContraT, docs: Iterable[Doc], responses: Iterable[Any] - ) -> Iterable[Any]: + self, + task: TaskContraT, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[Any]], + ) -> Iterable[Iterable[Any]]: ... @@ -204,25 +213,37 @@ def context_length(self) -> int: """ -def _do_args_match(out_arg: Iterable, in_arg: Iterable) -> bool: +def _do_args_match(out_arg: Iterable, in_arg: Iterable, nesting_level: int) -> bool: """Compares argument type of Iterables for compatibility. in_arg (Iterable): Input argument. out_arg (Iterable): Output argument. + nesting_level (int): Expected level of nesting in types. E. g. Iterable[Iterable[Any]] has a level of 2, + Iterable[Any] of 1. Note that this is assumed for all sub-types in out_arg and in_arg, as this is sufficient for + the current use case of checking the compatibility of task-to-model and model-to-parser communication flow. RETURNS (bool): True if type variables are of the same length and if type variables in out_arg are a subclass of (or the same class as) the type variables in in_arg. """ assert hasattr(out_arg, "__args__") and hasattr(in_arg, "__args__") + + out_types, in_types = out_arg, in_arg + for level in range(nesting_level): + out_types = ( + out_types.__args__[0] if level < (nesting_level - 1) else out_types.__args__ # type: ignore[attr-defined] + ) + in_types = ( + in_types.__args__[0] if level < (nesting_level - 1) else in_types.__args__ # type: ignore[attr-defined] + ) # Replace Any with object to make issubclass() check work. - out_type_vars = [arg if arg != Any else object for arg in out_arg.__args__] - in_type_vars = [arg if arg != Any else object for arg in in_arg.__args__] + out_types = [arg if arg != Any else object for arg in out_types] + in_types = [arg if arg != Any else object for arg in in_types] - if len(out_type_vars) != len(in_type_vars): + if len(out_types) != len(in_types): return False return all( [ issubclass(out_tv, in_tv) or issubclass(in_tv, out_tv) - for out_tv, in_tv in zip(out_type_vars, in_type_vars) + for out_tv, in_tv in zip(out_types, in_types) ] ) @@ -275,12 +296,13 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: ) if not hasattr(task, "generate_prompts"): raise ValueError( - "A task needs to have the following method: generate_prompts(self, docs: Iterable[Doc]) -> Iterable[Any]" + "A task needs to have the following method: generate_prompts(self, docs: Iterable[Doc]) -> " + "Iterable[Iterable[Any]]" ) if not hasattr(task, "parse_responses"): raise ValueError( "A task needs to have the following method: " - "parse_responses(self, docs: Iterable[Doc], responses: Iterable[Any]) -> Iterable[Doc]" + "parse_responses(self, docs: Iterable[Doc], responses: Iterable[Iterable[Any]]) -> Iterable[Doc]" ) type_hints = { @@ -334,14 +356,14 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: raise ValueError(msg) # Ensure that the template returns the same type as expected by the model - if not _do_args_match(template_output, model_input): # type: ignore[arg-type] + if not _do_args_match(template_output, model_input, 2): # type: ignore[arg-type] warnings.warn( f"Type returned from `task.generate_prompts()` (`{template_output}`) doesn't match type expected by " f"`model` (`{model_input}`)." ) # Ensure that the parser expects the same type as returned by the model - if not _do_args_match(model_output, parse_input): # type: ignore[arg-type] + if not _do_args_match(model_output, parse_input, 2): # type: ignore[arg-type] warnings.warn( f"Type returned from `model` (`{model_output}`) doesn't match type expected by " f"`task.parse_responses()` (`{parse_input}`)." diff --git a/usage_examples/tests/test_readme_examples.py b/usage_examples/tests/test_readme_examples.py index ea386750..72a74d2d 100644 --- a/usage_examples/tests/test_readme_examples.py +++ b/usage_examples/tests/test_readme_examples.py @@ -165,12 +165,14 @@ def test_example_5_custom_model(): import random @registry.llm_models("RandomClassification.v1") - def random_textcat(labels: str) -> Callable[[Iterable[str]], Iterable[str]]: + def random_textcat( + labels: str, + ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: labels = labels.split(",") - def _classify(prompts: Iterable[str]) -> Iterable[str]: - for _ in prompts: - yield random.choice(labels) + def _classify(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: + for prompts_for_doc in prompts: + yield [random.choice(labels) for _ in prompts_for_doc] return _classify From 23718fc372e2aa1b5e287fbbc39b03a2ad1cac92 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 27 Oct 2023 16:08:34 +0200 Subject: [PATCH 07/51] Fix shard & prompt flow. --- spacy_llm/pipeline/llm.py | 18 +++++------- spacy_llm/tasks/span/task.py | 2 +- spacy_llm/ty.py | 55 +++++++++++++++++++++++------------- 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index a0dd827e..698c5bb9 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -209,23 +209,19 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: if isinstance(self._model, ModelWithContextLength): context_length = self._model.context_length - # todo obtain doc shards. should this be a separate step or be obtained from generate_prompts()? - # probably better to go with the latter, as sharding is linked to prompt generation. - # this means amending typing and return values everywhere with returned doc shards (rename - # generate_prompts()?). - # after that: + # todo obtain doc shards. after that: # - fix tee() handling of returned iterators (tee separately) # - pass doc shards instead of noncached_doc_batch prompts_iters = tee( self._task.generate_prompts(noncached_doc_batch, context_length), n_iters + 1, ) - responses_iters = tee( self._model((elem[0] for elem in prompts_iters[0])), n_iters ) + for prompts_and_shards, response, doc in zip( - prompts_iters[1], responses_iters[1], noncached_doc_batch + prompts_iters[1], responses_iters[0], noncached_doc_batch ): logger.debug( "Generated prompt for doc: %s\n%s", doc.text, prompts_and_shards[0] @@ -234,11 +230,11 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: modified_docs = iter( self._task.parse_responses( - (elem[1] for elem in prompts_iters[3]), responses_iters[0] + (elem[1] for elem in prompts_iters[2]), responses_iters[1] ) ) - final_docs = [] + final_docs: List[Doc] = [] for i, doc in enumerate(docs): if is_cached[i]: cached_doc = self._cache[doc] @@ -254,8 +250,8 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: "llm_io", defaultdict(dict) ) llm_io = doc.user_data["llm_io"][self._name] - llm_io["prompt"] = str(next(prompts_iters[2])[0]) - llm_io["response"] = str(next(responses_iters[2])[0]) + llm_io["prompt"] = str(next(prompts_iters[-1])[0]) + llm_io["response"] = str(next(responses_iters[-1])) self._cache.add(doc) final_docs.append(doc) diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index 5794a1db..2a0b1069 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -102,7 +102,7 @@ def assign_spans( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - + shards = tuple(shards) for shards_for_doc, spans_for_doc in zip( shards, self._parse_responses(self, shards, responses) ): diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index fafa8385..694b0b7c 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -297,7 +297,7 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: if not hasattr(task, "generate_prompts"): raise ValueError( "A task needs to have the following method: generate_prompts(self, docs: Iterable[Doc]) -> " - "Iterable[Iterable[Any]]" + "Iterable[Tuple[Iterable[Any], Iterable[Doc]]]" ) if not hasattr(task, "parse_responses"): raise ValueError( @@ -311,9 +311,9 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: "model": _extract_model_call_signature(model), } - parse_input: Optional[Type] = None - model_input: Optional[Type] = None - model_output: Optional[Type] = None + parse_in: Optional[Type] = None + model_in: Optional[Type] = None + model_out: Optional[Type] = None # Validate the 'model' object if not (len(type_hints["model"]) == 2 and "return" in type_hints["model"]): @@ -322,9 +322,9 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: ) for k in type_hints["model"]: if k == "return": - model_output = type_hints["model"][k] + model_out = type_hints["model"][k] else: - model_input = type_hints["model"][k] + model_in = type_hints["model"][k] # validate the 'parse' object if not (len(type_hints["parse"]) == 3 and "return" in type_hints["parse"]): @@ -335,36 +335,53 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: # find the 'prompt_responses' var without assuming its name type_k = type_hints["parse"][k] if type_k != typing.Iterable[Doc]: - parse_input = type_hints["parse"][k] + parse_in = type_hints["parse"][k] - template_output = type_hints["template"]["return"] + template_out = type_hints["template"]["return"] # Check that all variables are Iterables. for var, msg in ( - (template_output, "`task.generate_prompts()` needs to return an `Iterable`."), + (template_out, "`task.generate_prompts()` needs to return an `Iterable`."), ( - model_input, + model_in, "The prompts variable in the 'model' needs to be an `Iterable`.", ), - (model_output, "The `model` function needs to return an `Iterable`."), + (model_out, "The `model` function needs to return an `Iterable`."), ( - parse_input, + parse_in, "`responses` in `task.parse_responses()` needs to be an `Iterable`.", ), ): - if not var != Iterable: + if not (hasattr(var, "_name") and var._name == "Iterable"): raise ValueError(msg) + # Ensure that template/prompt generator output is Iterable of 2-Tuple, the second of which fits doc shards type. + template_out_type = template_out.__args__[0] + if not ( + hasattr(template_out_type, "_name") + and template_out_type._name == "Tuple" + and len(template_out_type.__args__) == 2 + ): + warnings.warn( + f"Type in `Iterable` returned from `task.generate_prompts()` (`{template_out_type}`) has to be a 2-tuple " + f"(prompts, doc shards)." + ) + template_out_prompt_type = template_out_type.__args__[0] + # Ensure that the template returns the same type as expected by the model - if not _do_args_match(template_output, model_input, 2): # type: ignore[arg-type] + assert hasattr(model_in, "__args__") + assert model_in is not None + if not _do_args_match( + template_out_prompt_type, model_in.__args__[0], 1 + ): # type: ignore[arg-type] warnings.warn( - f"Type returned from `task.generate_prompts()` (`{template_output}`) doesn't match type expected by " - f"`model` (`{model_input}`)." + f"First type in `Iterable[Tuple[...]] returned from `task.generate_prompts()` " + f"(`{template_out_prompt_type}`) doesn't match type expected by `model` (`{model_in}`)." ) # Ensure that the parser expects the same type as returned by the model - if not _do_args_match(model_output, parse_input, 2): # type: ignore[arg-type] + if not _do_args_match(model_out, parse_in, 2): # type: ignore[arg-type] warnings.warn( - f"Type returned from `model` (`{model_output}`) doesn't match type expected by " - f"`task.parse_responses()` (`{parse_input}`)." + f"Type returned from `model` (`{model_out}`) doesn't match type expected by " + f"`task.parse_responses()` (`{parse_in}`)." ) From 7ce670d59201ad5aedae73d346b6d0a9ebb14836 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 27 Oct 2023 16:08:39 +0200 Subject: [PATCH 08/51] Fix shard & prompt flow. --- spacy_llm/ty.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 694b0b7c..221785f7 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -369,7 +369,6 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: template_out_prompt_type = template_out_type.__args__[0] # Ensure that the template returns the same type as expected by the model - assert hasattr(model_in, "__args__") assert model_in is not None if not _do_args_match( template_out_prompt_type, model_in.__args__[0], 1 From 0d75ea8ee897c2f5667c5a90ecfa753e31999f10 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 27 Oct 2023 16:17:50 +0200 Subject: [PATCH 09/51] Remove todo comments. --- spacy_llm/pipeline/llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 698c5bb9..6a520013 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -209,9 +209,6 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: if isinstance(self._model, ModelWithContextLength): context_length = self._model.context_length - # todo obtain doc shards. after that: - # - fix tee() handling of returned iterators (tee separately) - # - pass doc shards instead of noncached_doc_batch prompts_iters = tee( self._task.generate_prompts(noncached_doc_batch, context_length), n_iters + 1, From 9da7098b633b22a7d4cbc8a678c58ec2b8a81f72 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 27 Oct 2023 18:08:24 +0200 Subject: [PATCH 10/51] Fix Anthropic, Cohere, NoOp model tests. --- spacy_llm/models/rest/noop/model.py | 2 +- spacy_llm/tasks/span/task.py | 5 +++-- spacy_llm/tasks/textcat/task.py | 4 +++- spacy_llm/tests/models/test_anthropic.py | 6 ++++-- spacy_llm/tests/models/test_cohere.py | 12 ++++++++---- 5 files changed, 19 insertions(+), 10 deletions(-) diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index dfe9be25..7a6a2111 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -33,7 +33,7 @@ def _verify_auth(self) -> None: def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # Assume time penalty for API calls. time.sleep(NoOpModel._CALL_TIMEOUT) - return [_NOOP_RESPONSE] * len(list(prompts)) + return [[_NOOP_RESPONSE]] * len(list(prompts)) @staticmethod def _get_context_lengths() -> Dict[str, int]: diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index 2a0b1069..c313bfc7 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -1,4 +1,5 @@ import abc +from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union from typing import cast @@ -102,9 +103,9 @@ def assign_spans( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards = tuple(shards) + shards_teed = tee(shards, 2) for shards_for_doc, spans_for_doc in zip( - shards, self._parse_responses(self, shards, responses) + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): shards_for_doc = list(shards_for_doc) for shard, spans in zip(shards_for_doc, spans_for_doc): diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 074856fe..3370966d 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -1,3 +1,4 @@ +from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language @@ -100,8 +101,9 @@ def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: + shards_teed = tee(shards, 2) for shards_for_doc, cats_for_doc in zip( - shards, self._parse_responses(self, shards, responses) + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): updated_shards_for_doc: List[Doc] = [] diff --git a/spacy_llm/tests/models/test_anthropic.py b/spacy_llm/tests/models/test_anthropic.py index 61f1da77..39fac3eb 100644 --- a/spacy_llm/tests/models/test_anthropic.py +++ b/spacy_llm/tests/models/test_anthropic.py @@ -24,9 +24,11 @@ def test_anthropic_api_response_is_correct(): prompt = "Count the number of characters in this string: hello" num_prompts = 3 - responses = anthropic(prompts=[prompt] * num_prompts) + responses = anthropic(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external diff --git a/spacy_llm/tests/models/test_cohere.py b/spacy_llm/tests/models/test_cohere.py index 4b555aa4..3aad3c23 100644 --- a/spacy_llm/tests/models/test_cohere.py +++ b/spacy_llm/tests/models/test_cohere.py @@ -21,9 +21,11 @@ def test_cohere_api_response_is_correct(): ) prompt = "Count the number of characters in this string: hello" num_prompts = 3 # arbitrary number to check multiple inputs - responses = cohere(prompts=[prompt] * num_prompts) + responses = cohere(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external @@ -48,9 +50,11 @@ def test_cohere_api_response_n_generations(): prompt = "Count the number of characters in this string: hello" num_prompts = 3 - responses = cohere(prompts=[prompt] * num_prompts) + responses = cohere(prompts=[[prompt]] * num_prompts) for response in responses: - assert isinstance(response, str) + assert isinstance(response, list) + assert len(response) == 1 + assert isinstance(response[0], str) @pytest.mark.external From f3684129e84f097b89fa480402ad29310686d075 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 31 Oct 2023 16:22:34 +0100 Subject: [PATCH 11/51] Fix test_llm_pipe(). --- spacy_llm/tests/pipeline/test_llm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index d8204823..4e76a66d 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -64,8 +64,8 @@ def test_llm_pipe(nlp: Language, n_process: int): for doc in docs: llm_io = doc.user_data["llm_io"] - assert llm_io["llm"]["prompt"] == _NOOP_PROMPT - assert llm_io["llm"]["response"] == _NOOP_RESPONSE + assert llm_io["llm"]["prompt"] == str([_NOOP_PROMPT]) + assert llm_io["llm"]["response"] == str([_NOOP_RESPONSE]) @pytest.mark.parametrize("n_process", [1, 2]) @@ -158,14 +158,14 @@ def __init__(self): def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[Tuple[Iterable[int], Iterable[Doc]]]: + ) -> Iterable[Tuple[Iterable[Iterable[int]], Iterable[Doc]]]: for doc in docs: - yield [0], doc + yield [[0]], doc def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[float] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[float]] ) -> Iterable[Doc]: - return docs + return list(shards)[0] nlp = spacy.blank("en") with pytest.warns(UserWarning) as record: From b1f111ddbe75b81bc76d1cb4a13033843edcf747 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 14:16:09 +0100 Subject: [PATCH 12/51] Fix type checking test. --- spacy_llm/tests/pipeline/test_llm.py | 12 ++++++------ spacy_llm/ty.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 4e76a66d..c2a6137d 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -158,9 +158,9 @@ def __init__(self): def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None - ) -> Iterable[Tuple[Iterable[Iterable[int]], Iterable[Doc]]]: + ) -> Iterable[Tuple[Iterable[int], Iterable[Doc]]]: for doc in docs: - yield [[0]], doc + yield [0], [doc] def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[float]] @@ -174,13 +174,13 @@ def parse_responses( assert len(record) == 2 assert ( str(record[0].message) - == "Type returned from `task.generate_prompts()` (`typing.Iterable[int]`) doesn't match type " - "expected by `model` (`typing.Iterable[str]`)." + == "First type in `Iterable[Tuple[...]] returned from `task.generate_prompts()` (`typing.Iterable[int]`) " + "doesn't match type expected by `model` (`typing.Iterable[str]`)." ) assert ( str(record[1].message) - == "Type returned from `model` (`typing.Iterable[str]`) doesn't match type " - "expected by `task.parse_responses()` (`typing.Iterable[float]`)." + == "Type returned from `model` (`typing.Iterable[typing.Iterable[str]]`) doesn't match type expected by " + "`task.parse_responses()` (`typing.Iterable[typing.Iterable[float]]`)." ) # Run with disabled type consistency validation. diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 221785f7..a5c5af62 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -375,7 +375,7 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: ): # type: ignore[arg-type] warnings.warn( f"First type in `Iterable[Tuple[...]] returned from `task.generate_prompts()` " - f"(`{template_out_prompt_type}`) doesn't match type expected by `model` (`{model_in}`)." + f"(`{template_out_prompt_type}`) doesn't match type expected by `model` (`{model_in.__args__[0]}`)." ) # Ensure that the parser expects the same type as returned by the model From 44a278796fe24f8bb977038987127e06ed4ef7ac Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 14:52:25 +0100 Subject: [PATCH 13/51] Fix span parsing tests. --- spacy_llm/tests/tasks/legacy/test_ner.py | 4 ++-- spacy_llm/tests/tasks/legacy/test_spancat.py | 4 ++-- spacy_llm/tests/tasks/test_lemma.py | 6 +++--- spacy_llm/tests/tasks/test_ner.py | 2 +- spacy_llm/tests/tasks/test_sentiment.py | 6 +++--- spacy_llm/tests/tasks/test_spancat.py | 4 ++-- spacy_llm/tests/tasks/test_summarization.py | 6 +++--- spacy_llm/tests/tasks/test_textcat.py | 10 +++++----- 8 files changed, 21 insertions(+), 21 deletions(-) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 6d8ff727..8ea28867 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -547,7 +547,7 @@ def test_jinja_template_rendering_with_examples(examples_path): examples = fewshot_reader(examples_path) llm_ner = make_ner_task_v2(labels=labels, examples=examples) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -611,7 +611,7 @@ def test_jinja_template_rendering_with_label_definitions(): "LOC": "Location definition", }, ) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() diff --git a/spacy_llm/tests/tasks/legacy/test_spancat.py b/spacy_llm/tests/tasks/legacy/test_spancat.py index 124dd94d..581de740 100644 --- a/spacy_llm/tests/tasks/legacy/test_spancat.py +++ b/spacy_llm/tests/tasks/legacy/test_spancat.py @@ -376,7 +376,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_spancat = make_spancat_task_v2(labels=labels, examples=None) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -420,7 +420,7 @@ def test_jinja_template_rendering_with_examples(examples_path): examples = fewshot_reader(examples_path) llm_spancat = make_spancat_task_v2(labels=labels, examples=examples) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() diff --git a/spacy_llm/tests/tasks/test_lemma.py b/spacy_llm/tests/tasks/test_lemma.py index ec30e3ce..87e7ad48 100644 --- a/spacy_llm/tests/tasks/test_lemma.py +++ b/spacy_llm/tests/tasks/test_lemma.py @@ -199,7 +199,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc(text) lemma_task = make_lemma_task(examples=None) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -242,7 +242,7 @@ def test_jinja_template_rendering_with_examples(examples_path): doc = nlp.make_doc(text) lemma_task = make_lemma_task(examples=fewshot_reader(examples_path)) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -334,7 +334,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc(text) lemma_task = make_lemma_task(template=template) - prompt = list(lemma_task.generate_prompts([doc]))[0] + prompt = list(lemma_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == f""" diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index 7d7c8c82..11b17e6d 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -395,7 +395,7 @@ def test_ner_labels( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index 4b2bd63b..799f8bc8 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -173,7 +173,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(examples=None) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -207,7 +207,7 @@ def test_jinja_template_rendering_with_examples(examples_path): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(examples=fewshot_reader(examples_path)) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -255,7 +255,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc(text) sentiment_task = make_sentiment_task(template=template) - prompt = list(sentiment_task.generate_prompts([doc]))[0] + prompt = list(sentiment_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == f""" diff --git a/spacy_llm/tests/tasks/test_spancat.py b/spacy_llm/tests/tasks/test_spancat.py index 97e5bd88..b16ed0d1 100644 --- a/spacy_llm/tests/tasks/test_spancat.py +++ b/spacy_llm/tests/tasks/test_spancat.py @@ -446,7 +446,7 @@ def test_jinja_template_rendering_without_examples(): nlp = spacy.blank("en") doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_spancat = make_spancat_task_v3(labels=labels) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -501,7 +501,7 @@ def test_jinja_template_rendering_with_examples(examples_path: Path): examples = fewshot_reader(examples_path) llm_spancat = make_spancat_task_v3(labels=labels, examples=examples) - prompt = list(llm_spancat.generate_prompts([doc]))[0] + prompt = list(llm_spancat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() diff --git a/spacy_llm/tests/tasks/test_summarization.py b/spacy_llm/tests/tasks/test_summarization.py index d8912b38..c51ddfb9 100644 --- a/spacy_llm/tests/tasks/test_summarization.py +++ b/spacy_llm/tests/tasks/test_summarization.py @@ -253,7 +253,7 @@ def test_jinja_template_rendering_without_examples(example_text): doc = nlp.make_doc(example_text) llm_ner = make_summarization_task(examples=None, max_n_words=10) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -294,7 +294,7 @@ def test_jinja_template_rendering_with_examples(examples_path, example_text): "The provided example 'Life is a quality th...' has a summary of length 28, but `max_n_words` == 20." ), ): - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -343,7 +343,7 @@ def test_external_template_actually_loads(example_text): doc = nlp.make_doc(example_text) llm_ner = make_summarization_task(template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ diff --git a/spacy_llm/tests/tasks/test_textcat.py b/spacy_llm/tests/tasks/test_textcat.py index b6b0d641..6cf4e819 100644 --- a/spacy_llm/tests/tasks/test_textcat.py +++ b/spacy_llm/tests/tasks/test_textcat.py @@ -380,7 +380,7 @@ def test_jinja_template_rendering_with_examples_for_binary(examples_path, binary examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -446,7 +446,7 @@ def test_jinja_template_rendering_with_examples_for_multilabel_exclusive( examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -513,7 +513,7 @@ def test_jinja_template_rendering_with_examples_for_multilabel_nonexclusive( examples=prompt_examples, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -592,7 +592,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Combine 2 cloves of garlic with soy sauce") llm_textcat = make_textcat_task_v3(labels=labels, template=template) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ @@ -680,7 +680,7 @@ def test_jinja_template_rendering_with_label_definitions(multilabel_excl): }, exclusive_classes=exclusive_classes, ) - prompt = list(llm_textcat.generate_prompts([doc]))[0] + prompt = list(llm_textcat.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ From 6d8cdc7a4b262172509133eac41fd67e77445878 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 15:09:00 +0100 Subject: [PATCH 14/51] Fix internal tests. --- spacy_llm/tests/tasks/legacy/test_ner.py | 12 ++++++------ spacy_llm/tests/tasks/legacy/test_spancat.py | 8 ++++---- spacy_llm/tests/tasks/test_ner.py | 16 ++++++++-------- spacy_llm/tests/tasks/test_rel.py | 8 ++++---- spacy_llm/tests/tasks/test_spancat.py | 8 ++++---- spacy_llm/tests/tasks/test_textcat.py | 8 ++++---- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 8ea28867..7b25a577 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -329,7 +329,7 @@ def test_ner_zero_shot_task(text, response, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -388,7 +388,7 @@ def test_ner_labels(response, normalizer, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -437,7 +437,7 @@ def test_ner_alignment(response, alignment_mode, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -488,7 +488,7 @@ def test_ner_matching(response, case_sensitive, single_match, gold_ents): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -504,7 +504,7 @@ def test_jinja_template_rendering_without_examples(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v2(labels=labels, examples=None) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -664,7 +664,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v2(labels=labels, template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() == """ diff --git a/spacy_llm/tests/tasks/legacy/test_spancat.py b/spacy_llm/tests/tasks/legacy/test_spancat.py index 581de740..bd6ead2f 100644 --- a/spacy_llm/tests/tasks/legacy/test_spancat.py +++ b/spacy_llm/tests/tasks/legacy/test_spancat.py @@ -201,7 +201,7 @@ def test_spancat_zero_shot_task(text, response, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -260,7 +260,7 @@ def test_spancat_labels(response, normalizer, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -309,7 +309,7 @@ def test_spancat_alignment(response, alignment_mode, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -360,7 +360,7 @@ def test_spancat_matching(response, case_sensitive, single_match, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index 11b17e6d..d9082d2e 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -451,7 +451,7 @@ def test_ner_alignment( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -502,7 +502,7 @@ def test_ner_matching( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_ner.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_ner.parse_responses([[doc_in]], [[response]]))[0] pred_ents = [(ent.text, ent.label_) for ent in doc_out.ents] assert pred_ents == gold_ents @@ -517,7 +517,7 @@ def test_jinja_template_rendering_without_examples(): nlp = spacy.blank("en") doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v3(labels=labels) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -562,7 +562,7 @@ def test_jinja_template_rendering_with_examples(examples_dir: Path, examples_fil doc = nlp.make_doc("Alice and Bob went to the supermarket") examples = fewshot_reader(examples_dir / examples_file) llm_ner = make_ner_task_v3(examples=examples, labels=labels) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -609,7 +609,7 @@ def test_jinja_template_rendering_with_label_definitions( "LOC": "Location definition", }, ) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert ( prompt.strip() @@ -690,7 +690,7 @@ def test_external_template_actually_loads(): doc = nlp.make_doc("Alice and Bob went to the supermarket") llm_ner = make_ner_task_v3(examples=[], labels=labels, template=template) - prompt = list(llm_ner.generate_prompts([doc]))[0] + prompt = list(llm_ner.generate_prompts([doc]))[0][0][0] assert prompt.strip().startswith("Here's the test template for the tests and stuff") @@ -935,7 +935,7 @@ def test_regression_span_task_response_parse( span_reasons = _extract_span_reasons_cot(ner_task, response) assert len(span_reasons) == len(gold_ents) - docs = list(ner_task.parse_responses([example_doc], [response])) + docs = list(ner_task.parse_responses([[example_doc]], [[response]])) assert len(docs) == 1 doc = docs[0] @@ -964,7 +964,7 @@ def test_regression_span_task_comma( ner_task = make_ner_task_v3(examples=[], labels=["ORG", "LOC"]) span_reasons = _extract_span_reasons_cot(ner_task, response) assert len(span_reasons) == len(gold_ents) - docs = list(ner_task.parse_responses([example_doc], [response])) + docs = list(ner_task.parse_responses([[example_doc]], [[response]])) assert len(docs) == 1 doc = docs[0] pred_ents = [(ent.text, ent.label_) for ent in doc.ents] diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index eb685f2e..48ac0eed 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -250,9 +250,9 @@ def test_incorrect_indexing(): len( list( task._parse_responses( - task, [doc], ['{"dep": 0, "dest": 0, "relation": "LivesIn"}'] + task, [[doc]], [['{"dep": 0, "dest": 0, "relation": "LivesIn"}']] ) - )[0] + )[0][0] ) == 1 ) @@ -260,9 +260,9 @@ def test_incorrect_indexing(): len( list( task._parse_responses( - task, [doc], ['{"dep": 0, "dest": 1, "relation": "LivesIn"}'] + task, [[doc]], [['{"dep": 0, "dest": 1, "relation": "LivesIn"}']] ) - )[0] + )[0][0] ) == 0 ) diff --git a/spacy_llm/tests/tasks/test_spancat.py b/spacy_llm/tests/tasks/test_spancat.py index b16ed0d1..1ba9c11e 100644 --- a/spacy_llm/tests/tasks/test_spancat.py +++ b/spacy_llm/tests/tasks/test_spancat.py @@ -257,7 +257,7 @@ def test_spancat_matching_shot_task(text: str, response: str, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list so we get what's inside - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -330,7 +330,7 @@ def test_spancat_labels( doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -382,7 +382,7 @@ def test_spancat_alignment(response, alignment_mode, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans @@ -431,7 +431,7 @@ def test_spancat_matching(response, case_sensitive, gold_spans): doc_in = nlp.make_doc(text) # Pass to the parser # Note: parser() returns a list - doc_out = list(llm_spancat.parse_responses([doc_in], [response]))[0] + doc_out = list(llm_spancat.parse_responses([[doc_in]], [[response]]))[0] pred_spans = [(span.text, span.label_) for span in doc_out.spans["sc"]] assert pred_spans == gold_spans diff --git a/spacy_llm/tests/tasks/test_textcat.py b/spacy_llm/tests/tasks/test_textcat.py index 6cf4e819..0e9b9828 100644 --- a/spacy_llm/tests/tasks/test_textcat.py +++ b/spacy_llm/tests/tasks/test_textcat.py @@ -318,7 +318,7 @@ def test_textcat_binary_labels_are_correct(text, response, expected_score): nlp = spacy.blank("en") doc = nlp(text) - pred = list(llm_textcat.parse_responses([doc], [response]))[0] + pred = list(llm_textcat.parse_responses([[doc]], [[response]]))[0] assert list(pred.cats.keys())[0] == label assert list(pred.cats.values())[0] == expected_score @@ -350,7 +350,7 @@ def test_textcat_multilabel_labels_are_correct( ) nlp = spacy.blank("en") doc = nlp.make_doc(text) - pred = list(llm_textcat.parse_responses([doc], [response]))[0] + pred = list(llm_textcat.parse_responses([[doc]], [[response]]))[0] # Take only those that have scores pred_cats = [cat for cat, score in pred.cats.items() if score == 1.0] assert set(pred_cats) == set(expected) @@ -636,9 +636,9 @@ def test_external_template_actually_loads(): def test_textcat_scoring(zeroshot_cfg_string, n_insults): @registry.llm_models("Dummy") def factory(): - def b(prompts: Iterable[str]) -> Iterable[str]: + def b(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: for _ in prompts: - yield "POS" + yield ["POS"] return b From e712f4105bf11c0dda1e59c5559c04fbbf76e059 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 16:20:02 +0100 Subject: [PATCH 15/51] Fix _CountTask. --- spacy_llm/tests/models/test_rest.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/spacy_llm/tests/models/test_rest.py b/spacy_llm/tests/models/test_rest.py index 488479b7..305732c6 100644 --- a/spacy_llm/tests/models/test_rest.py +++ b/spacy_llm/tests/models/test_rest.py @@ -26,12 +26,13 @@ def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: for doc in docs: - yield _CountTask._PROMPT_TEMPLATE.format(text=doc.text), [doc] + yield [_CountTask._PROMPT_TEMPLATE.format(text=doc.text)], [doc] def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - return docs + # Grab the first shard per doc + return [list(shards_for_doc)[0] for shards_for_doc in shards] @property def prompt_template(self) -> str: From 985fd68a345a5474b07cd08d82e89367db2c7001 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 16:45:32 +0100 Subject: [PATCH 16/51] Fix sentiment and summarization tasks and tests. --- spacy_llm/tasks/sentiment/task.py | 20 ++++++++++---------- spacy_llm/tasks/summarization/task.py | 17 +++++++++-------- spacy_llm/tests/tasks/test_sentiment.py | 2 +- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 9c42e1da..10c5f973 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -1,3 +1,4 @@ +from itertools import tee from typing import Callable, Iterable, List, Optional, Type from spacy.language import Language @@ -68,23 +69,22 @@ def initialize( ) def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_doc_extension() - shards: List[Doc] = [] + shards_teed = tee(shards, 2) - for responses_for_doc in responses: - for shard, sentiment_score in zip( - docs, self._parse_responses(self, docs, responses_for_doc) - ): + for shards_for_doc, scores_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, score in zip(shards_for_doc, scores_for_doc): try: - setattr(shard._, self._field, sentiment_score) + setattr(shard._, self._field, score) except ValueError: setattr(shard._, self._field, None) - shards.append(shard) - - yield self._shard_reducer(shards) + yield self._shard_reducer(shards_for_doc) @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 69ed68c9..5234d2c6 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -1,4 +1,5 @@ import warnings +from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language @@ -92,18 +93,18 @@ def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return {"max_n_words": self._max_n_words} def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[Iterable[str]] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards: List[Doc] = [] + shards_teed = tee(shards, 2) - for responses_for_doc in responses: - for shard, summary in zip( - docs, self._parse_responses(self, docs, responses_for_doc) - ): + for shards_for_doc, summaries_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, summary in zip(shards_for_doc, summaries_for_doc): setattr(shard._, self._field, summary) - shards.append(shard) - yield self._shard_reducer(shards) + yield self._shard_reducer(shards_for_doc) @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index 799f8bc8..034e09e4 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -144,7 +144,7 @@ def test_sentiment_predict(cfg_string, request): ("zeroshot_cfg_string", "sentiment_x"), ], ) -def test_lemma_io(cfg_string_field, request): +def test_sentiment_io(cfg_string_field, request): cfg_string, field = cfg_string_field cfg = request.getfixturevalue(cfg_string) orig_config = Config().from_str(cfg) From 98842a21331ca9dc8b0f2124003bb26e67739af4 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 17:04:27 +0100 Subject: [PATCH 17/51] Fix Azure connection URL. Fix Model test pings. --- spacy_llm/models/rest/anthropic/model.py | 2 +- spacy_llm/models/rest/azure/model.py | 4 ++-- spacy_llm/models/rest/azure/registry.py | 2 +- spacy_llm/models/rest/cohere/model.py | 2 +- spacy_llm/models/rest/palm/model.py | 2 +- spacy_llm/tasks/lemma/task.py | 5 ++++- spacy_llm/tasks/rel/task.py | 4 +++- 7 files changed, 13 insertions(+), 8 deletions(-) diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index 06013cb4..269b3209 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -40,7 +40,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: # Execute a dummy prompt. If the API setup is incorrect, we should fail at initialization time. try: - self(["test"]) + self([["test"]]) except ValueError as err: if "authentication_error" in str(err): warnings.warn( diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index 93d850b3..436deabe 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -50,7 +50,7 @@ def endpoint(self) -> str: return ( self._endpoint + ("" if self._endpoint.endswith("/") else "/") - + f"openai/deployments/{self._name}/{'' if self._model_type == ModelType.COMPLETION else 'chat/'}" + + f"openai/deployments/{self._deployment_name}/{'' if self._model_type == ModelType.COMPLETION else 'chat/'}" f"completions" ) @@ -73,7 +73,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["test"]) + self([["test"]]) except ValueError as err: raise err diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index d69117f8..850cc43f 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -21,7 +21,7 @@ def azure_openai( max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME, api_version: str = "2023-05-15", ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: - """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. + """Returns Azure OpenAI instance for models deployed on Azure's OpenAI service using REST to prompt API. Docs on OpenAI models supported by Azure: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#model-summary-table-and-region-availability. diff --git a/spacy_llm/models/rest/cohere/model.py b/spacy_llm/models/rest/cohere/model.py index 202d55c4..8d06f350 100644 --- a/spacy_llm/models/rest/cohere/model.py +++ b/spacy_llm/models/rest/cohere/model.py @@ -29,7 +29,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["test"]) + self([["test"]]) except ValueError as err: if "invalid api token" in str(err): warnings.warn( diff --git a/spacy_llm/models/rest/palm/model.py b/spacy_llm/models/rest/palm/model.py index 547d972d..b1a2657d 100644 --- a/spacy_llm/models/rest/palm/model.py +++ b/spacy_llm/models/rest/palm/model.py @@ -31,7 +31,7 @@ def credentials(self) -> Dict[str, str]: def _verify_auth(self) -> None: try: - self(["What's 2+2?"]) + self([["What's 2+2?"]]) except ValueError as err: if "API key not valid" in str(err): warnings.warn( diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index 5a67c69a..208442c6 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -1,3 +1,4 @@ +from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy import Language @@ -46,8 +47,10 @@ def __init__( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: + shards_teed = tee(shards, 2) + for shards_for_doc, lemmas_for_doc in zip( - shards, self._parse_responses(self, shards, responses) + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): updated_shards_for_doc: List[Doc] = [] diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index a109599e..8bf38f3c 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -1,3 +1,4 @@ +from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from spacy.language import Language @@ -95,9 +96,10 @@ def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_extension(self._field) + shards_teed = tee(shards, 2) for shards_for_doc, rel_items_for_doc in zip( - shards, self._parse_responses(self, shards, responses) + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): updated_shards_for_doc: List[Doc] = [] for shard, rel_items in zip(shards_for_doc, rel_items_for_doc): From b54a3d92d91e6c15e985f2f097f01136ed1e4066 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 17:18:47 +0100 Subject: [PATCH 18/51] Fix Lemma parsing. --- spacy_llm/tasks/lemma/parser.py | 21 ++++++++++++++------- spacy_llm/tasks/lemma/task.py | 1 - spacy_llm/tests/tasks/test_lemma.py | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/spacy_llm/tasks/lemma/parser.py b/spacy_llm/tasks/lemma/parser.py index 086a1eff..d9ff7c1e 100644 --- a/spacy_llm/tasks/lemma/parser.py +++ b/spacy_llm/tasks/lemma/parser.py @@ -7,23 +7,30 @@ def parse_responses_v1( task: LemmaTask, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] -) -> Iterable[Iterable[List[List[str]]]]: +) -> Iterable[List[List[List[str]]]]: """Parses LLM responses for spacy.Lemma.v1. task (LemmaTask): Task instance. shards (Iterable[Iterable[Doc]]): Doc shards. responses (Iterable[Iterable[str]]): LLM responses. - RETURNS (Iterable[List[str]]): Lists of 2-lists (token: lemmatized token) per doc/response. + RETURNS (Iterable[List[List[List[str]]]]): Lists of 2-lists per token (token: lemmatized token) and shard/response + and doc. """ for responses_for_doc in responses: results_for_doc: List[List[List[str]]] = [] for response in responses_for_doc: + results_for_shard = [ + [pr_part.strip() for pr_part in pr.split(":")] + for pr in response.replace("Lemmatized text:", "") + .replace("'''", "") + .strip() + .split("\n") + ] results_for_doc.append( + # Malformed responses might have a length != 2, in which case they are discarded. [ - [pr_part.strip() for pr_part in pr.split(":")] - for pr in response.replace("Lemmatized text:", "") - .replace("'''", "") - .strip() - .split("\n") + result_for_token + for result_for_token in results_for_shard + if len(result_for_token) == 2 ] ) diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index 208442c6..56e4e43b 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -48,7 +48,6 @@ def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: shards_teed = tee(shards, 2) - for shards_for_doc, lemmas_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): diff --git a/spacy_llm/tests/tasks/test_lemma.py b/spacy_llm/tests/tasks/test_lemma.py index 87e7ad48..20d21618 100644 --- a/spacy_llm/tests/tasks/test_lemma.py +++ b/spacy_llm/tests/tasks/test_lemma.py @@ -141,8 +141,8 @@ def test_lemma_config(cfg_string, request): @pytest.mark.parametrize( "cfg_string", [ - "zeroshot_cfg_string", - "fewshot_cfg_string", + # "zeroshot_cfg_string", + # "fewshot_cfg_string", "ext_template_cfg_string", ], ) From 9bf365d1c9312a178386f48b1b7b202d1f381a62 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 3 Nov 2023 17:44:39 +0100 Subject: [PATCH 19/51] Start work on doc-to-shard property copying. --- spacy_llm/tasks/rel/parser.py | 6 +++--- spacy_llm/tasks/rel/task.py | 5 ++--- spacy_llm/tasks/util/sharding.py | 13 +++++++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/spacy_llm/tasks/rel/parser.py b/spacy_llm/tasks/rel/parser.py index 3f79b31e..27ede457 100644 --- a/spacy_llm/tasks/rel/parser.py +++ b/spacy_llm/tasks/rel/parser.py @@ -19,13 +19,13 @@ def parse_responses_v1( """ for responses_for_doc, shards_for_doc in zip(responses, shards): results_for_doc: List[List[RelationItem]] = [] - for response, doc in zip(responses_for_doc, shards_for_doc): + for response, shard in zip(responses_for_doc, shards_for_doc): relations: List[RelationItem] = [] for line in response.strip().split("\n"): try: rel_item = RelationItem.parse_raw(line) - if 0 <= rel_item.dep < len(doc.ents) and 0 <= rel_item.dest < len( - doc.ents + if 0 <= rel_item.dep < len(shard.ents) and 0 <= rel_item.dest < len( + shard.ents ): relations.append(rel_item) except ValidationError: diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index 8bf38f3c..dd42b246 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -101,12 +101,11 @@ def parse_responses( for shards_for_doc, rel_items_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): - updated_shards_for_doc: List[Doc] = [] + shards_for_doc = list(shards_for_doc) for shard, rel_items in zip(shards_for_doc, rel_items_for_doc): shard._.rel = rel_items - updated_shards_for_doc.append(shard) - yield self._shard_reducer(updated_shards_for_doc) + yield self._shard_reducer(shards_for_doc) def initialize( self, diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index c153f170..c6919a47 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -28,8 +28,8 @@ def make_shard_mapper( n_token_estimator (NTokenEstimator): Estimates number of tokens in a string. buffer_frac (float): Buffer to consider in assessment of whether prompt fits into context. E. g. if value is 1.1, prompt length * 1.1 will be compared with the context length. - # todo sharding would be better with sentences instead of tokens, but this requires some form of sentence - # splitting we can't rely one...maybe checking for sentences and/or as optional arg? + todo sharding would be better with sentences instead of tokens, but this requires some form of sentence + splitting we can't rely one...maybe checking for sentences and/or as optional arg? RETURNS (ShardMapper): Callable mapping doc to doc shards fitting within context length. """ n_tok_est: NTokenEstimator = n_token_estimator or make_n_token_estimator() @@ -54,6 +54,9 @@ def map_doc_to_shards( fraction = 0.5 start_idx = 0 + if n_tok_est(render_template(doc)) * buffer_frac <= context_length: + return [doc] + while remaining_doc is not None: fits_in_context = False shard: Optional[Doc] = None @@ -68,6 +71,12 @@ def map_doc_to_shards( ) fraction /= 2 + # todo doc properties, such as .ents, have to be included for some tasks (e. g. REL, EL) to work. how + # should this be done in cases where the properties transcend shard limits? + # - should sharding never cut across entities/other properties? + # - should entities or all other properties be dropped if they transcend shard properties? this seems + # like the most pragmatic solution for now. + # - which properties should be copied to shards other than .ents? assert shard is not None shards.append(shard) fraction = 1 From dddfaabdcf3b805e40e28d811dd5e56a940908fa Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 6 Nov 2023 16:24:37 +0100 Subject: [PATCH 20/51] Fix REL doc preprocessing. --- spacy_llm/tasks/rel/task.py | 22 +++++++++++++++++++--- spacy_llm/tasks/util/sharding.py | 3 --- spacy_llm/tests/tasks/test_rel.py | 2 +- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index dd42b246..b49c0486 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from spacy.language import Language -from spacy.tokens import Doc +from spacy.tokens import Doc, Span from spacy.training import Example from ...compat import Self @@ -61,7 +61,24 @@ def __init__( self._field = "rel" def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: - return [Doc(doc.vocab, words=RELTask._preannotate(doc).split()) for doc in docs] + preprocessed_docs: List[Doc] = [] + + for doc in docs: + preprocessed_docs.append( + Doc(doc.vocab, words=RELTask._preannotate(doc).split()) + ) + preprocessed_docs[-1].ents = [ + Span( + preprocessed_docs[-1], + ent.start, + ent.end, + label=ent.label_, + kb_id=ent.kb_id_, + ) + for ent in doc.ents + ] + + return preprocessed_docs def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: return { @@ -97,7 +114,6 @@ def parse_responses( ) -> Iterable[Doc]: self._check_extension(self._field) shards_teed = tee(shards, 2) - for shards_for_doc, rel_items_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index c6919a47..04ffac45 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -54,9 +54,6 @@ def map_doc_to_shards( fraction = 0.5 start_idx = 0 - if n_tok_est(render_template(doc)) * buffer_frac <= context_length: - return [doc] - while remaining_doc is not None: fits_in_context = False shard: Optional[Doc] = None diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index 48ac0eed..79103b2e 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -135,7 +135,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string"]) # "zeroshot_cfg_string", +@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string", "zeroshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable From 3af21b5b23e19e1b940533abebab1b31e80f5d84 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 6 Nov 2023 16:38:03 +0100 Subject: [PATCH 21/51] Remove comment on doc attribute handling during sharding, as this is done by spaCy's slicing directly. --- spacy_llm/tasks/util/sharding.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index 04ffac45..d1bedc41 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -68,12 +68,6 @@ def map_doc_to_shards( ) fraction /= 2 - # todo doc properties, such as .ents, have to be included for some tasks (e. g. REL, EL) to work. how - # should this be done in cases where the properties transcend shard limits? - # - should sharding never cut across entities/other properties? - # - should entities or all other properties be dropped if they transcend shard properties? this seems - # like the most pragmatic solution for now. - # - which properties should be copied to shards other than .ents? assert shard is not None shards.append(shard) fraction = 1 From fee9ca7def37d7ba497668bcea3231ffc8dc7394 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 8 Nov 2023 12:43:48 +0100 Subject: [PATCH 22/51] Add reducer implementations. --- spacy_llm/tasks/builtin_task.py | 15 +++++++++++---- spacy_llm/tasks/lemma/util.py | 4 ++-- spacy_llm/tasks/rel/task.py | 2 +- spacy_llm/tasks/sentiment/util.py | 3 ++- spacy_llm/tasks/span/task.py | 2 +- spacy_llm/tasks/summarization/task.py | 8 ++++++-- spacy_llm/tasks/summarization/util.py | 5 +++-- spacy_llm/tasks/textcat/task.py | 2 +- spacy_llm/tasks/util/sharding.py | 7 ++++--- spacy_llm/ty.py | 2 +- 10 files changed, 32 insertions(+), 18 deletions(-) diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index de5e185b..4902bcb4 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -1,4 +1,5 @@ import abc +from itertools import tee from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, cast @@ -66,16 +67,17 @@ def generate_prompts( environment = jinja2.Environment() _template = environment.from_string(self._template) - def render_template(shard: Doc) -> str: + def render_template(shard: Doc, n_shards: int) -> str: """Renders template for a given doc (shard). shard (Doc): Doc shard. Note that if the prompt is small enough to fit within the model's context window, there will only be one shard, which is identical to the original doc. + n_shards (int): Total number of shards. RETURNS (str): Rendered template. """ return _template.render( text=doc.text, prompt_examples=self._prompt_examples, - **self._get_prompt_data(shard), + **self._get_prompt_data(shard, n_shards), ) for doc in self._preprocess_docs_for_prompt(docs): @@ -85,12 +87,17 @@ def render_template(shard: Doc) -> str: if context_length is not None else [doc] ) - yield [render_template(shard) for shard in shards], shards + shards_teed = tee(shards, 3) + yield [ + render_template(shard, len(list(shards_teed[0]))) + for shard in shards_teed[1] + ], shards_teed[2] - def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data returned by this might be static (i. e. the same for all doc shards) or dynamic (contingent on the doc shard). shard (Doc): Doc (shard) for which prompt data should be fetched. + n_shards (int): Total number of shards. RETURNS (Dict[str, Any]): Data injected into prompt template. """ return {} diff --git a/spacy_llm/tasks/lemma/util.py b/spacy_llm/tasks/lemma/util.py index c04c2696..d906b3db 100644 --- a/spacy_llm/tasks/lemma/util.py +++ b/spacy_llm/tasks/lemma/util.py @@ -32,5 +32,5 @@ def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + # Lemmas are token-specific, so we can just merge docs. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index b49c0486..b6e4d704 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -80,7 +80,7 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: return preprocessed_docs - def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 45e309ee..2654e2a6 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -25,5 +25,6 @@ def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. + # todo make generic, pass task to shard reducers (necessary e. g. for tasks with dynamic fields in docs) + # todo Weight-average sentiment scores over shards based on token count per shard. return list(shards)[0] diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index c313bfc7..f737b086 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -71,7 +71,7 @@ def __init__( if self._prompt_examples: self._prompt_examples = list(self._check_label_consistency(self)) - def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: return { "description": self._description, "labels": list(self._label_dict.values()), diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 5234d2c6..76274198 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -85,12 +85,16 @@ def _check_prompt_example_summary_len(self) -> None: f"LLM will likely produce responses that are too long." ) - def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: if self._check_example_summaries: self._check_prompt_example_summary_len() self._check_example_summaries = False - return {"max_n_words": self._max_n_words} + return { + "max_n_words": int(self._max_n_words / n_shards) + if self._max_n_words is not None + else None + } def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] diff --git a/spacy_llm/tasks/summarization/util.py b/spacy_llm/tasks/summarization/util.py index ff68ddf6..c5f6cbd7 100644 --- a/spacy_llm/tasks/summarization/util.py +++ b/spacy_llm/tasks/summarization/util.py @@ -25,5 +25,6 @@ def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + # Summaries are per shard, so we can merge. Number of shards is considered in max. number of words. This means that + # the resulting summaries will be per shard, which should be an approximately correct summary still. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index 3370966d..b06b12ed 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -90,7 +90,7 @@ def __init__( ) self._exclusive_classes = True - def _get_prompt_data(self, shard: Doc) -> Dict[str, Any]: + def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index d1bedc41..0b264222 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -35,9 +35,9 @@ def make_shard_mapper( n_tok_est: NTokenEstimator = n_token_estimator or make_n_token_estimator() def map_doc_to_shards( - doc: Doc, context_length: int, render_template: Callable[[Doc], str] + doc: Doc, context_length: int, render_template: Callable[[Doc, int], str] ) -> Union[Iterable[Doc], Doc]: - prompt = render_template(doc) + prompt = render_template(doc, 1) # If prompt with complete doc too long: split in shards. if n_tok_est(prompt) * buffer_frac > context_length: @@ -63,7 +63,8 @@ def map_doc_to_shards( end_idx = start_idx + int(len(remaining_doc) * fraction) shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) fits_in_context = ( - n_tok_est(render_template(shard)) * buffer_frac + n_tok_est(render_template(shard, int(1 / fraction))) + * buffer_frac <= context_length ) fraction /= 2 diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index a5c5af62..0c075f3b 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -26,7 +26,7 @@ NTokenEstimator = Callable[[str], int] ShardMapper = Callable[ # Requires doc, context length and callable for rendering template from doc shard text. - [Doc, int, Callable[[Doc], str]], + [Doc, int, Callable[[Doc, int], str]], # Returns each shard as a doc. Iterable[Doc], ] From e508499c1929cf14c0d5665a5a5ca07fb1589e7d Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 14 Nov 2023 15:42:05 +0100 Subject: [PATCH 23/51] Implement outstanding task reducers. --- spacy_llm/tasks/builtin_task.py | 20 +++++-- spacy_llm/tasks/lemma/task.py | 9 ++- spacy_llm/tasks/lemma/util.py | 5 +- spacy_llm/tasks/ner/task.py | 4 +- spacy_llm/tasks/ner/util.py | 7 ++- spacy_llm/tasks/rel/__init__.py | 3 +- spacy_llm/tasks/rel/examples.py | 31 ---------- spacy_llm/tasks/rel/items.py | 19 ++++++ spacy_llm/tasks/rel/registry.py | 3 +- spacy_llm/tasks/rel/task.py | 11 ++-- spacy_llm/tasks/rel/util.py | 79 +++++++++++++++++-------- spacy_llm/tasks/sentiment/task.py | 9 ++- spacy_llm/tasks/sentiment/util.py | 28 +++++++-- spacy_llm/tasks/span/task.py | 8 +-- spacy_llm/tasks/spancat/task.py | 4 +- spacy_llm/tasks/spancat/util.py | 7 ++- spacy_llm/tasks/summarization/task.py | 9 ++- spacy_llm/tasks/summarization/util.py | 3 +- spacy_llm/tasks/textcat/task.py | 9 ++- spacy_llm/tasks/textcat/util.py | 22 +++++-- spacy_llm/tests/tasks/test_sentiment.py | 5 +- spacy_llm/ty.py | 10 +++- 22 files changed, 185 insertions(+), 120 deletions(-) delete mode 100644 spacy_llm/tasks/rel/examples.py create mode 100644 spacy_llm/tasks/rel/items.py diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 4902bcb4..3527d509 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -36,7 +36,7 @@ def __init__( template: str, prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], ): """Initializes task. parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. @@ -44,7 +44,7 @@ def __init__( template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. """ self._parse_responses = parse_responses self._prompt_examples = prompt_examples or [] @@ -265,6 +265,18 @@ def _check_extension(cls, extension: str) -> None: if not Doc.has_extension(extension): Doc.set_extension(extension, default=[]) + @staticmethod + def _tee_2d_iterable( + data: Iterable[Iterable[Any]], n: int + ) -> Tuple[Iterable[List[Doc]], ...]: + """Tees two-dimensional Iterable. As Iterables in the nested iterables get consumed with the first access, we + need to materialize them - this is done by converting them to a list. + data (Iterable[Iterable[Any]]): Data to tee. + n (int): Number of tees to return. + RETURNS (Tuple[Iterable[List[Doc]], ...]): n-sized tuple of Iterables with inner Iterables converted to Lists. + """ + return tee((list(inner_data) for inner_data in data), n) + class BuiltinTaskWithLabels(BuiltinTask, abc.ABC): """Built-in tasks with labels.""" @@ -276,7 +288,7 @@ def __init__( template: str, prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], labels: List[str], label_definitions: Optional[Dict[str, str]], normalizer: Optional[Callable[[str], str]], @@ -288,7 +300,7 @@ def __init__( template (str): Prompt template passed to the model. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. labels (List[str]): List of labels to pass to the template. Leave empty to (optionally) populate it at initialization time. label_definitions (Optional[Dict[str, str]]): Map of label -> description diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index 56e4e43b..c24d82c1 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -1,4 +1,3 @@ -from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy import Language @@ -21,7 +20,7 @@ def __init__( prompt_examples: Optional[List[FewshotExample[Self]]], template: str, shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], scorer: Scorer, ): """Default lemmatization task. @@ -31,7 +30,7 @@ def __init__( prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. """ super().__init__( @@ -47,7 +46,7 @@ def __init__( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) for shards_for_doc, lemmas_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): @@ -67,7 +66,7 @@ def parse_responses( updated_shards_for_doc.append(shard) - yield self._shard_reducer(updated_shards_for_doc) + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def initialize( self, diff --git a/spacy_llm/tasks/lemma/util.py b/spacy_llm/tasks/lemma/util.py index d906b3db..a77f8507 100644 --- a/spacy_llm/tasks/lemma/util.py +++ b/spacy_llm/tasks/lemma/util.py @@ -27,10 +27,11 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return Scorer.score_token_attr(examples, "lemma") -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +def reduce_shards_to_doc(task: LemmaTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for LemmaTask. + task (LemmaTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # Lemmas are token-specific, so we can just merge docs. + # Lemmas are token-specific, so we can just merge shards. return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/ner/task.py b/spacy_llm/tasks/ner/task.py index 1577c536..af5f7892 100644 --- a/spacy_llm/tasks/ner/task.py +++ b/spacy_llm/tasks/ner/task.py @@ -26,7 +26,7 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -43,7 +43,7 @@ def __init__( parse_responses (TaskResponseParser[SpanTask]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. label_definitions (Optional[Dict[str, str]]): Map of label -> description of the label to help the language model output the entities wanted. It is usually easier to provide these definitions rather than diff --git a/spacy_llm/tasks/ner/util.py b/spacy_llm/tasks/ner/util.py index 4bc8d09d..b1ce44a2 100644 --- a/spacy_llm/tasks/ner/util.py +++ b/spacy_llm/tasks/ner/util.py @@ -38,10 +38,11 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: return get_ner_prf(examples) -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +def reduce_shards_to_doc(task: NERTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for NERTask. + task (NERTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + # NERTask only affects span-specific information, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/rel/__init__.py b/spacy_llm/tasks/rel/__init__.py index f35171a4..324126a8 100644 --- a/spacy_llm/tasks/rel/__init__.py +++ b/spacy_llm/tasks/rel/__init__.py @@ -1,7 +1,6 @@ -from .examples import RELExample from .registry import make_rel_task from .task import DEFAULT_REL_TEMPLATE, RELTask -from .util import RelationItem +from .util import RelationItem, RELExample __all__ = [ "DEFAULT_REL_TEMPLATE", diff --git a/spacy_llm/tasks/rel/examples.py b/spacy_llm/tasks/rel/examples.py deleted file mode 100644 index 3467976c..00000000 --- a/spacy_llm/tasks/rel/examples.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import List, Optional - -from spacy.training import Example - -from ...compat import Self -from ...ty import FewshotExample -from .task import RELTask -from .util import EntityItem, RelationItem - - -class RELExample(FewshotExample[RELTask]): - text: str - ents: List[EntityItem] - relations: List[RelationItem] - - @classmethod - def generate(cls, example: Example, task: RELTask) -> Optional[Self]: - entities = [ - EntityItem( - start_char=ent.start_char, - end_char=ent.end_char, - label=ent.label_, - ) - for ent in example.reference.ents - ] - - return cls( - text=example.reference.text, - ents=entities, - relations=example.reference._.rel, - ) diff --git a/spacy_llm/tasks/rel/items.py b/spacy_llm/tasks/rel/items.py new file mode 100644 index 00000000..7426d8b8 --- /dev/null +++ b/spacy_llm/tasks/rel/items.py @@ -0,0 +1,19 @@ +from ...compat import BaseModel, validator + + +class RelationItem(BaseModel): + dep: int + dest: int + relation: str + + @validator("dep", "dest", pre=True) + def clean_ent(cls, value): + if isinstance(value, str): + value = value.strip("ENT") + return value + + +class EntityItem(BaseModel): + start_char: int + end_char: int + label: str diff --git a/spacy_llm/tasks/rel/registry.py b/spacy_llm/tasks/rel/registry.py index f7142255..2399b65d 100644 --- a/spacy_llm/tasks/rel/registry.py +++ b/spacy_llm/tasks/rel/registry.py @@ -5,10 +5,9 @@ from ...ty import TaskResponseParser from ...util import split_labels from ..util.sharding import make_shard_mapper -from .examples import RELExample from .parser import parse_responses_v1 from .task import DEFAULT_REL_TEMPLATE, RELTask -from .util import reduce_shards_to_doc +from .util import RELExample, reduce_shards_to_doc @registry.llm_misc("spacy.RELShardReducer.v1") diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index b6e4d704..d7ebae8b 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -1,4 +1,3 @@ -from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union from spacy.language import Language @@ -9,7 +8,7 @@ from ...ty import FewshotExample, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTaskWithLabels from ..templates import read_template -from .util import EntityItem, RelationItem +from .items import EntityItem, RelationItem DEFAULT_REL_TEMPLATE: str = read_template("rel.v1") @@ -24,7 +23,7 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], verbose: bool, ): @@ -53,7 +52,7 @@ def __init__( prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. verbose (bool): Controls the verbosity of the task. """ @@ -113,7 +112,7 @@ def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_extension(self._field) - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) for shards_for_doc, rel_items_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): @@ -121,7 +120,7 @@ def parse_responses( for shard, rel_items in zip(shards_for_doc, rel_items_for_doc): shard._.rel = rel_items - yield self._shard_reducer(shards_for_doc) + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] def initialize( self, diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index 956d4233..68392ca6 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -1,32 +1,59 @@ -from typing import Iterable +import warnings +from typing import Iterable, List, Optional from spacy.tokens import Doc - -from ...compat import BaseModel, validator - - -class RelationItem(BaseModel): - dep: int - dest: int - relation: str - - @validator("dep", "dest", pre=True) - def clean_ent(cls, value): - if isinstance(value, str): - value = value.strip("ENT") - return value - - -class EntityItem(BaseModel): - start_char: int - end_char: int - label: str - - -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +from spacy.training import Example + +from ...compat import Self +from ...ty import FewshotExample +from .items import EntityItem, RelationItem +from .task import RELTask + + +class RELExample(FewshotExample[RELTask]): + text: str + ents: List[EntityItem] + relations: List[RelationItem] + + @classmethod + def generate(cls, example: Example, task: RELTask) -> Optional[Self]: + entities = [ + EntityItem( + start_char=ent.start_char, + end_char=ent.end_char, + label=ent.label_, + ) + for ent in example.reference.ents + ] + + return cls( + text=example.reference.text, + ents=entities, + relations=example.reference._.rel, + ) + + +def reduce_shards_to_doc(task: RELTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for RELTask. + task (RELTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + shards = list(shards) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=f".*Skipping Doc custom extension '{task.field}' while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + + # REL information from shards can be simply appended. + setattr( + doc._, + task.field, + [rel_items for shard in shards for rel_items in getattr(shard._, task.field)], + ) + + return doc diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 10c5f973..095d78a2 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -1,4 +1,3 @@ -from itertools import tee from typing import Callable, Iterable, List, Optional, Type from spacy.language import Language @@ -21,7 +20,7 @@ def __init__( field: str, prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], ): """Sentiment analysis task. @@ -32,7 +31,7 @@ def __init__( prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. """ super().__init__( parse_responses=parse_responses, @@ -72,7 +71,7 @@ def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: self._check_doc_extension() - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) for shards_for_doc, scores_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) @@ -84,7 +83,7 @@ def parse_responses( except ValueError: setattr(shard._, self._field, None) - yield self._shard_reducer(shards_for_doc) + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 2654e2a6..b323cdae 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -1,3 +1,4 @@ +import warnings from typing import Iterable, Optional from spacy.tokens import Doc @@ -20,11 +21,28 @@ def generate(cls, example: Example, task: SentimentTask) -> Optional[Self]: ) -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: - """Reduces shards to docs for SentimentTask. +def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for SentimentTask by computing an average sentiment score weighted by shard lengths. + task (SentimentTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo make generic, pass task to shard reducers (necessary e. g. for tasks with dynamic fields in docs) - # todo Weight-average sentiment scores over shards based on token count per shard. - return list(shards)[0] + shards = list(shards) + weights = [len(shard) for shard in shards] + weights = [n_tokens / sum(weights) for n_tokens in weights] + sent_scores = [getattr(shard._, task.field) for shard in shards] + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=f".*Skipping Doc custom extension '{task.field}' while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + setattr( + doc._, + task.field, + sum([score * weight for score, weight in zip(sent_scores, weights)]), + ) + + return doc diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index f737b086..56b46724 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -1,5 +1,4 @@ import abc -from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type, TypeVar, Union from typing import cast @@ -35,7 +34,7 @@ def __init__( Union[List[SpanExample[Self]], List[SpanCoTExample[Self]]] ], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], description: Optional[str], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], # noqa: F821 @@ -103,7 +102,8 @@ def assign_spans( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) + for shards_for_doc, spans_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): @@ -111,7 +111,7 @@ def parse_responses( for shard, spans in zip(shards_for_doc, spans_for_doc): self.assign_spans(shard, spans) - yield self._shard_reducer(shards_for_doc) + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/spancat/task.py b/spacy_llm/tasks/spancat/task.py index a7e9695d..b25f39e4 100644 --- a/spacy_llm/tasks/spancat/task.py +++ b/spacy_llm/tasks/spancat/task.py @@ -26,7 +26,7 @@ def __init__( spans_key: str, prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], alignment_mode: Literal["strict", "contract", "expand"], case_sensitive_matching: bool, @@ -49,7 +49,7 @@ def __init__( spans_key (str): Key of the `Doc.spans` dict to save under. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): optional normalizer function. alignment_mode (str): "strict", "contract" or "expand". case_sensitive_matching (bool): Whether to search without case sensitivity. diff --git a/spacy_llm/tasks/spancat/util.py b/spacy_llm/tasks/spancat/util.py index c83e5fe3..23eec817 100644 --- a/spacy_llm/tasks/spancat/util.py +++ b/spacy_llm/tasks/spancat/util.py @@ -44,10 +44,11 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: ) -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +def reduce_shards_to_doc(task: SpanCatTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for SpanCatTask. + task (SpanCatTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + # SpanCatTask only affects span-specific information, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index 76274198..f144b302 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -1,5 +1,4 @@ import warnings -from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language @@ -21,7 +20,7 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], template: str, shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], max_n_words: Optional[int], field: str, prompt_examples: Optional[List[FewshotExample[Self]]], @@ -32,7 +31,7 @@ def __init__( parse_responses (TaskResponseParser[Self]): Callable for parsing LLM responses for this task. prompt_example_type (Type[FewshotExample[Self]): Type to use for fewshot examples. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. max_n_words (Optional[int]): Max. number of words to use in summary. field (str): The name of the doc extension in which to store the summary. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. @@ -99,7 +98,7 @@ def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) for shards_for_doc, summaries_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) @@ -108,7 +107,7 @@ def parse_responses( for shard, summary in zip(shards_for_doc, summaries_for_doc): setattr(shard._, self._field, summary) - yield self._shard_reducer(shards_for_doc) + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] @property def _cfg_keys(self) -> List[str]: diff --git a/spacy_llm/tasks/summarization/util.py b/spacy_llm/tasks/summarization/util.py index c5f6cbd7..4ef711af 100644 --- a/spacy_llm/tasks/summarization/util.py +++ b/spacy_llm/tasks/summarization/util.py @@ -20,8 +20,9 @@ def generate(cls, example: Example, task: SummarizationTask) -> Optional[Self]: ) -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +def reduce_shards_to_doc(task: SummarizationTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for SummarizationTask. + task (SummarizationTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index b06b12ed..ef29d79e 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -1,4 +1,3 @@ -from itertools import tee from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language @@ -26,7 +25,7 @@ def __init__( label_definitions: Optional[Dict[str, str]], prompt_examples: Optional[List[FewshotExample[Self]]], shard_mapper: ShardMapper, - shard_reducer: ShardReducer, + shard_reducer: ShardReducer[Self], normalizer: Optional[Callable[[str], str]], exclusive_classes: bool, allow_none: bool, @@ -57,7 +56,7 @@ def __init__( These descriptions are added to the prompt to help instruct the LLM on what to extract. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. - shard_reducer (ShardReducer): Reduces doc shards back into one doc instance. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. normalizer (Optional[Callable[[str], str]]): Optional normalizer function. exclusive_classes (bool): If True, require the language model to suggest only one label per class. This is automatically set when using binary classification. @@ -101,7 +100,7 @@ def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - shards_teed = tee(shards, 2) + shards_teed = self._tee_2d_iterable(shards, 2) for shards_for_doc, cats_for_doc in zip( shards_teed[0], self._parse_responses(self, shards_teed[1], responses) ): @@ -111,7 +110,7 @@ def parse_responses( shard.cats = cats updated_shards_for_doc.append(shard) - yield self._shard_reducer(updated_shards_for_doc) + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def scorer( self, diff --git a/spacy_llm/tasks/textcat/util.py b/spacy_llm/tasks/textcat/util.py index 7c4de62d..fd69d4b7 100644 --- a/spacy_llm/tasks/textcat/util.py +++ b/spacy_llm/tasks/textcat/util.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Iterable, Optional +from collections import defaultdict +from typing import Any, DefaultDict, Dict, Iterable, Optional from spacy.scorer import Scorer from spacy.tokens import Doc @@ -49,10 +50,23 @@ def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: ) -def reduce_shards_to_doc(shards: Iterable[Doc]) -> Doc: +def reduce_shards_to_doc(task: TextCatTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for TextCatTask. + task (TextCatTask): Task. shards (Iterable[Doc]): Shards to reduce to single doc instance. RETURNS (Doc): Fused doc instance. """ - # todo this is yet a dummy implementation that will only return the first doc shard. - return list(shards)[0] + shards = list(shards) + + # Compute average sum per category weighted by shard length. + weights = [len(shard) for shard in shards] + weights = [n_tokens / sum(weights) for n_tokens in weights] + all_cats: DefaultDict[str, float] = defaultdict(lambda: 0) + for weight, shard in zip(weights, shards): + for cat, cat_score in shard.cats.items(): + all_cats[cat] += cat_score * weight + + doc = Doc.from_docs(shards, ensure_whitespace=True) + doc.cats = all_cats + + return doc diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index 034e09e4..a15b192f 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -129,9 +129,10 @@ def test_sentiment_predict(cfg_string, request): orig_config = Config().from_str(cfg) nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) if cfg_string != "ext_template_cfg_string": - assert nlp("This is horrible.")._.sentiment == 0 + # with pytest.warns() as record: + assert nlp("This is horrible.")._.sentiment == 0.0 assert 0 < nlp("This is meh.")._.sentiment <= 0.5 - assert nlp("This is perfect.")._.sentiment == 1 + assert nlp("This is perfect.")._.sentiment == 1.0 @pytest.mark.external diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 0c075f3b..2f3f9506 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -30,7 +30,6 @@ # Returns each shard as a doc. Iterable[Doc], ] -ShardReducer = Callable[[Iterable[Doc]], Doc] @runtime_checkable @@ -123,6 +122,15 @@ def parse_responses( TaskContraT = TypeVar("TaskContraT", bound=LLMTask, contravariant=True) +@runtime_checkable +class ShardReducer(Protocol[TaskContraT]): + """Generic protocol for tasks' shard reducer.""" + + def __call__(self, task: TaskContraT, shards: Iterable[Doc]) -> Doc: + """Merges shard to single Doc.""" + ... + + class FewshotExample(GenericModel, abc.ABC, Generic[TaskContraT]): @classmethod @abc.abstractmethod From c104387059c18d0d79f8808ac50577e22b36998a Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 20 Nov 2023 17:26:38 +0100 Subject: [PATCH 24/51] Add shardable/non-shardable LLM task typing distinction. Add support for handling both types of tasks. Update tests. --- spacy_llm/cache.py | 6 +- spacy_llm/pipeline/llm.py | 64 ++++++++++----- spacy_llm/tasks/__init__.py | 4 +- spacy_llm/tasks/builtin_task.py | 6 +- spacy_llm/tasks/noop.py | 26 +++++- spacy_llm/tests/pipeline/test_llm.py | 14 +++- spacy_llm/tests/tasks/legacy/test_ner.py | 4 +- spacy_llm/tests/tasks/legacy/test_spancat.py | 4 +- spacy_llm/tests/tasks/test_ner.py | 4 +- spacy_llm/tests/tasks/test_rel.py | 4 +- spacy_llm/tests/tasks/test_spancat.py | 4 +- spacy_llm/tests/tasks/test_summarization.py | 4 +- spacy_llm/tests/tasks/test_textcat.py | 4 +- spacy_llm/ty.py | 85 ++++++++++++++++---- 14 files changed, 168 insertions(+), 65 deletions(-) diff --git a/spacy_llm/cache.py b/spacy_llm/cache.py index 3e1559e1..cd6bd8a9 100644 --- a/spacy_llm/cache.py +++ b/spacy_llm/cache.py @@ -8,7 +8,7 @@ from spacy.vocab import Vocab from .registry import registry -from .ty import LLMTask, PromptTemplateProvider +from .ty import PromptTemplateProvider, ShardingLLMTask @registry.llm_misc("spacy.BatchCache.v1") @@ -68,11 +68,11 @@ def __init__( self._init_cache_dir() - def initialize(self, vocab: Vocab, task: LLMTask) -> None: + def initialize(self, vocab: Vocab, task: ShardingLLMTask) -> None: """ Initialize cache with data not available at construction time. vocab (Vocab): Vocab object. - task (LLMTask): Task. + task (ShardingLLMTask): Task. """ self._vocab = vocab if isinstance(task, PromptTemplateProvider): diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 267c8981..f4a76911 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -16,8 +16,8 @@ from .. import registry # noqa: F401 from ..compat import TypedDict from ..ty import Cache, LabeledTask, LLMTask, ModelWithContextLength -from ..ty import PromptExecutorType, ScorableTask, Serializable -from ..ty import validate_type_consistency +from ..ty import PromptExecutorType, ScorableTask, Serializable, ShardingLLMTask +from ..ty import supports_sharding, validate_type_consistency logger = logging.getLogger("spacy_llm") logger.addHandler(logging.NullHandler()) @@ -35,6 +35,7 @@ DEFAULT_SAVE_IO = False DEFAULT_VALIDATE_TYPES = True +_LLMTask = Union[LLMTask, ShardingLLMTask] class CacheConfigType(TypedDict): @@ -58,7 +59,7 @@ class CacheConfigType(TypedDict): def make_llm( nlp: Language, name: str, - task: Optional[LLMTask], + task: Optional[_LLMTask], model: PromptExecutorType, cache: Cache, save_io: bool, @@ -69,7 +70,7 @@ def make_llm( nlp (Language): Pipeline. name (str): The component instance name, used to add entries to the losses during training. - task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses into + task (Optional[_LLMTask]): An _LLMTask can generate prompts for given docs, and can parse the LLM's responses into structured information and set that back on the docs. model (Callable[[Iterable[Any]], Iterable[Any]]]): Callable querying the specified LLM API. cache (Cache): Cache to use for caching prompts and responses per doc (batch). @@ -102,7 +103,7 @@ def __init__( name: str = "LLMWrapper", *, vocab: Vocab, - task: LLMTask, + task: _LLMTask, model: PromptExecutorType, cache: Cache, save_io: bool, @@ -113,8 +114,8 @@ def __init__( name (str): The component instance name, used to add entries to the losses during training. vocab (Vocab): Pipeline vocabulary. - task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses into - structured information and set that back on the docs. + task (Optional[_LLMTask]): An _LLMTask can generate prompts for given docs, and can parse the LLM's responses + into structured information and set that back on the docs. model (Callable[[Iterable[Any]], Iterable[Any]]]): Callable querying the specified LLM API. cache (Cache): Cache to use for caching prompts and responses per doc (batch). save_io (bool): Whether to save LLM I/O (prompts and responses) in the `Doc._.llm_io` custom extension. @@ -150,7 +151,7 @@ def clear(self) -> None: return self._task.clear() @property - def task(self) -> LLMTask: + def task(self) -> _LLMTask: return self._task def __call__(self, doc: Doc) -> Doc: @@ -198,6 +199,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: docs (List[Doc]): Input batch of docs RETURNS (List[Doc]): Processed batch of docs with task annotations set """ + has_shards = supports_sharding(self._task) is_cached = [doc in self._cache for doc in docs] noncached_doc_batch = [doc for i, doc in enumerate(docs) if not is_cached[i]] if len(noncached_doc_batch) < len(docs): @@ -214,27 +216,42 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: if isinstance(self._model, ModelWithContextLength): context_length = self._model.context_length + # Only pass context length if this is a sharding task. prompts_iters = tee( - self._task.generate_prompts(noncached_doc_batch, context_length), + self._task.generate_prompts(noncached_doc_batch, context_length) # type: ignore[call-arg] + if has_shards + else self._task.generate_prompts(noncached_doc_batch), n_iters + 1, ) responses_iters = tee( - self._model((elem[0] for elem in prompts_iters[0])), n_iters + self._model( + # Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped + # in a list to conform to the nested structure. + (elem[0] if has_shards else [elem] for elem in prompts_iters[0]) + ), + n_iters, ) - for prompts_and_shards, response, doc in zip( + for prompt_data, response, doc in zip( prompts_iters[1], responses_iters[0], noncached_doc_batch ): logger.debug( - "Generated prompt for doc: %s\n%s", doc.text, prompts_and_shards[0] + "Generated prompt for doc: %s\n%s", + doc.text, + prompt_data[0] if has_shards else prompt_data, ) logger.debug("LLM response for doc: %s\n%s", doc.text, response) - modified_docs = iter( + resp = list( self._task.parse_responses( - (elem[1] for elem in prompts_iters[2]), responses_iters[1] + ( + elem[1] if has_shards else noncached_doc_batch[i] + for i, elem in enumerate(prompts_iters[2]) + ), + responses_iters[1], ) ) + modified_docs = iter(resp) final_docs: List[Doc] = [] for i, doc in enumerate(docs): @@ -252,7 +269,10 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: "llm_io", defaultdict(dict) ) llm_io = doc.user_data["llm_io"][self._name] - llm_io["prompt"] = str(next(prompts_iters[-1])[0]) + next_prompt = next(prompts_iters[-1]) + llm_io["prompt"] = str( + next_prompt[0] if has_shards else next_prompt + ) llm_io["response"] = str(next(responses_iters[-1])) self._cache.add(doc) @@ -274,7 +294,7 @@ def to_bytes( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda: self._task.to_bytes(exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda: self._task.to_bytes(exclude=exclude) # type: ignore[attr-defined, union-attr] if isinstance(self._model, Serializable): serialize["model"] = lambda: self._model.to_bytes(exclude=exclude) # type: ignore[attr-defined] @@ -296,9 +316,9 @@ def from_bytes( deserialize = {} if isinstance(self._task, Serializable): - deserialize["task"] = lambda b: self._task.from_bytes(b, exclude=exclude) # type: ignore[attr-defined] + deserialize["task"] = lambda b: self._task.from_bytes(b, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - deserialize["model"] = lambda b: self._model.from_bytes(b, exclude=exclude) # type: ignore[attr-defined] + deserialize["model"] = lambda b: self._model.from_bytes(b, exclude=exclude) # type: ignore[attr-defined,union-attr] util.from_bytes(bytes_data, deserialize, exclude) return self @@ -314,9 +334,9 @@ def to_disk( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda p: self._task.to_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda p: self._task.to_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - serialize["model"] = lambda p: self._model.to_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["model"] = lambda p: self._model.to_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] util.to_disk(path, serialize, exclude) @@ -332,9 +352,9 @@ def from_disk( serialize = {} if isinstance(self._task, Serializable): - serialize["task"] = lambda p: self._task.from_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["task"] = lambda p: self._task.from_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] if isinstance(self._model, Serializable): - serialize["model"] = lambda p: self._model.from_disk(p, exclude=exclude) # type: ignore[attr-defined] + serialize["model"] = lambda p: self._model.from_disk(p, exclude=exclude) # type: ignore[attr-defined,union-attr] util.from_disk(path, serialize, exclude) return self diff --git a/spacy_llm/tasks/__init__.py b/spacy_llm/tasks/__init__.py index 6551c536..e30995db 100644 --- a/spacy_llm/tasks/__init__.py +++ b/spacy_llm/tasks/__init__.py @@ -5,7 +5,7 @@ from .builtin_task import BuiltinTask from .lemma import LemmaTask, make_lemma_task from .ner import NERTask, make_ner_task_v3 -from .noop import NoopTask, make_noop_task +from .noop import NoopTask, ShardingNoopTask, make_noop_task, make_noopnoshards_task from .rel import RELTask, make_rel_task from .sentiment import SentimentTask, make_sentiment_task from .spancat import SpanCatTask, make_spancat_task_v3 @@ -39,6 +39,7 @@ "make_lemma_task", "make_ner_task_v3", "make_noop_task", + "make_noopnoshards_task", "make_rel_task", "make_sentiment_task", "make_spancat_task_v3", @@ -50,6 +51,7 @@ "NoopTask", "RELTask", "SentimentTask", + "ShardingNoopTask", "SpanCatTask", "SummarizationTask", "TextCatTask", diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 7acab421..11909162 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -23,10 +23,10 @@ class BuiltinTask(abc.ABC): - initializable (in line with other spaCy components) - (de-)serialization - On the relation of BuiltinTask to LLMTask: the latter specifies the minimal contract a task implementation + On the relation of BuiltinTask to ShardingLLMTask: the latter specifies the minimal contract a task implementation has to fulfill, whereas a BuiltinTask requires (and offers) functionality beyond that. The rationale behind that is - that built-in tasks should provide as smooth a usage experience as possible while still making it as easy as possible - for users to write their own, custom tasks. + that built-in tasks should provide as smooth a usage experience as possible while still making it as easy as + possible for users to write their own, custom tasks. """ def __init__( diff --git a/spacy_llm/tasks/noop.py b/spacy_llm/tasks/noop.py index 9d11ee64..12d101f9 100644 --- a/spacy_llm/tasks/noop.py +++ b/spacy_llm/tasks/noop.py @@ -9,10 +9,15 @@ @registry.llm_tasks("spacy.NoOp.v1") def make_noop_task(): + return ShardingNoopTask() + + +@registry.llm_tasks("spacy.NoOpNoShards.v1") +def make_noopnoshards_task(): return NoopTask() -class NoopTask: +class ShardingNoopTask: def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None ) -> Iterable[Tuple[Iterable[str], Iterable[Doc]]]: @@ -22,7 +27,6 @@ def generate_prompts( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - # Grab the first shard per doc return [list(shards_for_doc)[0] for shards_for_doc in shards] @property @@ -31,3 +35,21 @@ def prompt_template(self) -> str: This is the NoOp prompt template """ + + +class NoopTask: + def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: + for doc in docs: + yield _NOOP_PROMPT + + def parse_responses( + self, docs: Iterable[Doc], responses: Iterable[str] + ) -> Iterable[Doc]: + return docs + + @property + def prompt_template(self) -> str: + return """ + This is the NoOp + prompt template + """ diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index c2a6137d..13d6482e 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -51,11 +51,20 @@ def test_llm_init(nlp): @pytest.mark.parametrize("n_process", [1, 2]) -def test_llm_pipe(nlp: Language, n_process: int): +@pytest.mark.parametrize("shard", [True, False]) +def test_llm_pipe(noop_config: Dict[str, Any], n_process: int, shard: bool): """Test call .pipe().""" + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpNoShards.v1"}}} + if not shard + else noop_config, + ) ops = get_current_ops() if not isinstance(ops, NumpyOps) and n_process != 1: pytest.skip("Only test multiple processes on CPU") + docs = list( nlp.pipe(texts=["This is a test", "This is another test"], n_process=n_process) ) @@ -63,8 +72,7 @@ def test_llm_pipe(nlp: Language, n_process: int): for doc in docs: llm_io = doc.user_data["llm_io"] - - assert llm_io["llm"]["prompt"] == str([_NOOP_PROMPT]) + assert llm_io["llm"]["prompt"] == str([_NOOP_PROMPT] if shard else _NOOP_PROMPT) assert llm_io["llm"]["response"] == str([_NOOP_RESPONSE]) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 7b25a577..b17848d0 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -15,7 +15,7 @@ from spacy_llm.registry import strip_normalizer from spacy_llm.tasks.ner import NERTask, make_ner_task_v2 from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...compat import has_openai_key @@ -195,7 +195,7 @@ def test_ner_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = orig_config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) diff --git a/spacy_llm/tests/tasks/legacy/test_spancat.py b/spacy_llm/tests/tasks/legacy/test_spancat.py index bd6ead2f..4b70ee46 100644 --- a/spacy_llm/tests/tasks/legacy/test_spancat.py +++ b/spacy_llm/tests/tasks/legacy/test_spancat.py @@ -12,7 +12,7 @@ from spacy_llm.registry import fewshot_reader, lowercase_normalizer, strip_normalizer from spacy_llm.tasks.spancat import SpanCatTask, make_spancat_task_v2 from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...compat import has_openai_key @@ -85,7 +85,7 @@ def test_spancat_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = orig_config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index c1da7659..c4204769 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -20,7 +20,7 @@ from spacy_llm.tasks.span import SpanReason from spacy_llm.tasks.span.parser import _extract_span_reasons_cot from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -205,7 +205,7 @@ def test_ner_config(config: Config): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index 75b48996..241aacda 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -10,7 +10,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.tasks.rel import DEFAULT_REL_TEMPLATE, RelationItem, RELTask -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ...tasks import make_rel_task @@ -122,7 +122,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) task = pipe.task labels = orig_config["components"]["llm"]["task"]["labels"] diff --git a/spacy_llm/tests/tasks/test_spancat.py b/spacy_llm/tests/tasks/test_spancat.py index 1ba9c11e..eb19f6ab 100644 --- a/spacy_llm/tests/tasks/test_spancat.py +++ b/spacy_llm/tests/tasks/test_spancat.py @@ -14,7 +14,7 @@ from spacy_llm.tasks import make_spancat_task_v3 from spacy_llm.tasks.spancat import SpanCatTask from spacy_llm.tasks.util import find_substrings -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -148,7 +148,7 @@ def test_spancat_config(config: Config): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = config["components"]["llm"]["task"]["labels"] labels = split_labels(labels) diff --git a/spacy_llm/tests/tasks/test_summarization.py b/spacy_llm/tests/tasks/test_summarization.py index c51ddfb9..b3b923b1 100644 --- a/spacy_llm/tests/tasks/test_summarization.py +++ b/spacy_llm/tests/tasks/test_summarization.py @@ -8,7 +8,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.registry import fewshot_reader, file_reader -from spacy_llm.ty import LLMTask +from spacy_llm.ty import ShardingLLMTask from spacy_llm.util import assemble_from_config from ...tasks import make_summarization_task @@ -152,7 +152,7 @@ def test_summarization_config(cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) @pytest.mark.external diff --git a/spacy_llm/tests/tasks/test_textcat.py b/spacy_llm/tests/tasks/test_textcat.py index 0e9b9828..e4ddad3a 100644 --- a/spacy_llm/tests/tasks/test_textcat.py +++ b/spacy_llm/tests/tasks/test_textcat.py @@ -13,7 +13,7 @@ from spacy_llm.registry import fewshot_reader, file_reader, lowercase_normalizer from spacy_llm.registry import registry from spacy_llm.tasks.textcat import TextCatTask, make_textcat_task_v3 -from spacy_llm.ty import LabeledTask, LLMTask +from spacy_llm.ty import LabeledTask, ShardingLLMTask from spacy_llm.util import assemble_from_config, split_labels from ..compat import has_openai_key @@ -204,7 +204,7 @@ def test_textcat_config(task, cfg_string, request): pipe = nlp.get_pipe("llm") assert isinstance(pipe, LLMWrapper) - assert isinstance(pipe.task, LLMTask) + assert isinstance(pipe.task, ShardingLLMTask) labels = split_labels(labels) task = pipe.task diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 52f39ba7..d0d28bb2 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -92,8 +92,18 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: """ +# todo +# x change to llmtask +# x add llmtask +# x fix task typing structures +# x fix model data handling +# x don't expect doc back from nonsharding tasks +# x run tests with to sharding and non-sharding nooptask +# - fix inevitable typing check issues + + @runtime_checkable -class LLMTask(Protocol): +class ShardingLLMTask(Protocol): def generate_prompts( self, docs: Iterable[Doc], context_length: Optional[int] = None ) -> Iterable[Tuple[Iterable[_PromptType], Iterable[Doc]]]: @@ -119,7 +129,26 @@ def parse_responses( """ -TaskContraT = TypeVar("TaskContraT", bound=LLMTask, contravariant=True) +@runtime_checkable +class LLMTask(Protocol): + def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]: + """Generate prompts from docs. + docs (Iterable[Doc]): Docs to generate prompts from. + RETURNS (Iterable[_PromptType]): Iterable with one prompt per doc. + """ + + def parse_responses( + self, docs: Iterable[Doc], responses: Iterable[_ResponseType] + ) -> Iterable[Doc]: + """ + Parses LLM responses. + docs (Iterable[Doc]): Docs to map responses into. + responses ([Iterable[_ResponseType]]): LLM responses. + RETURNS (Iterable[Doc]]): Updated docs. + """ + + +TaskContraT = TypeVar("TaskContraT", bound=ShardingLLMTask, contravariant=True) @runtime_checkable @@ -178,11 +207,11 @@ def clear(self) -> None: class Cache(Protocol): """Defines minimal set of operations a cache implementiation needs to support.""" - def initialize(self, vocab: Vocab, task: LLMTask) -> None: + def initialize(self, vocab: Vocab, task: Union[LLMTask, ShardingLLMTask]) -> None: """ Initialize cache with data not available at construction time. vocab (Vocab): Vocab object. - task (LLMTask): Task. + task (Union[LLMTask, ShardingLLMTask]): Task. """ def add(self, doc: Doc) -> None: @@ -295,15 +324,36 @@ def _extract_model_call_signature(model: PromptExecutorType) -> Dict[str, Any]: return signature -def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: +def supports_sharding(task: Union[LLMTask, ShardingLLMTask]) -> bool: + """Determines task type, as isinstance(instance, Protocol) only checks for method names. This also considers + argument and return types. Raises an exception if task is neither. + Note that this is not as thorough as validate_type_consistency() and relies on clues to determine which task type + a given, type-validated task type is. This doesn't guarantee that a task has valid typing. This method should only + be in conjunction with validate_type_consistency(). + task (Union[LLMTask, ShardingLLMTask]): Task to check. + RETURNS (bool): True if task supports sharding, False if not. + """ + prompt_ret_type = typing.get_type_hints(task.generate_prompts)["return"].__args__[0] + return ( + hasattr(prompt_ret_type, "_name") + and prompt_ret_type._name == "Tuple" + and len(prompt_ret_type.__args__) == 2 + ) + + +def validate_type_consistency( + task: Union[LLMTask, ShardingLLMTask], model: PromptExecutorType +) -> None: """Check whether the types of the task and model signatures match. - task (LLMTask): Specified task. + task (ShardingLLMTask): Specified task. model (PromptExecutor): Specified model. """ # Raises an error or prints a warning if something looks wrong/odd. + # todo update error messages if not isinstance(task, LLMTask): raise ValueError( - f"A task needs to be of type 'LLMTask' but found {type(task)} instead" + f"A task needs to adhere to the interface of either 'LLMTask' or 'ShardingLLMTask', but {type(task)} " + f"doesn't." ) if not hasattr(task, "generate_prompts"): raise ValueError( @@ -368,29 +418,30 @@ def validate_type_consistency(task: LLMTask, model: PromptExecutorType) -> None: # Ensure that template/prompt generator output is Iterable of 2-Tuple, the second of which fits doc shards type. template_out_type = template_out.__args__[0] - if not ( + if ( hasattr(template_out_type, "_name") and template_out_type._name == "Tuple" and len(template_out_type.__args__) == 2 ): - warnings.warn( - f"Type in `Iterable` returned from `task.generate_prompts()` (`{template_out_type}`) has to be a 2-tuple " - f"(prompts, doc shards)." - ) - template_out_prompt_type = template_out_type.__args__[0] + has_shards = True + template_out_type = template_out_type.__args__[0] + else: + has_shards = False # Ensure that the template returns the same type as expected by the model assert model_in is not None if not _do_args_match( - template_out_prompt_type, model_in.__args__[0], 1 + template_out_type if has_shards else typing.Iterable[template_out_type], # type: ignore[valid-type] + model_in.__args__[0], + 1, ): # type: ignore[arg-type] warnings.warn( - f"First type in `Iterable[Tuple[...]] returned from `task.generate_prompts()` " - f"(`{template_out_prompt_type}`) doesn't match type expected by `model` (`{model_in.__args__[0]}`)." + f"First type in value returned from `task.generate_prompts()` (`{template_out_type}`) doesn't match type " + f"expected by `model` (`{model_in.__args__[0]}`)." ) # Ensure that the parser expects the same type as returned by the model - if not _do_args_match(model_out, parse_in, 2): # type: ignore[arg-type] + if not _do_args_match(model_out, parse_in if has_shards else typing.Iterable[parse_in], 2): # type: ignore[arg-type,valid-type] warnings.warn( f"Type returned from `model` (`{model_out}`) doesn't match type expected by " f"`task.parse_responses()` (`{parse_in}`)." From 2502c4d38df5e1297fa0c87f369851463885de3d Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 23 Nov 2023 11:04:35 +0100 Subject: [PATCH 25/51] Fix EL task. --- spacy_llm/models/rest/anthropic/registry.py | 7 + spacy_llm/models/rest/azure/model.py | 4 +- spacy_llm/models/rest/azure/registry.py | 1 + spacy_llm/models/rest/base.py | 11 +- spacy_llm/models/rest/cohere/registry.py | 1 + spacy_llm/models/rest/noop/model.py | 1 + spacy_llm/models/rest/openai/registry.py | 36 ++- spacy_llm/models/rest/palm/registry.py | 5 +- spacy_llm/pipeline/llm.py | 8 + spacy_llm/tasks/builtin_task.py | 20 +- spacy_llm/tasks/entity_linker/parser.py | 67 +++--- spacy_llm/tasks/entity_linker/registry.py | 18 +- spacy_llm/tasks/entity_linker/task.py | 238 +++++++++++++------ spacy_llm/tasks/entity_linker/util.py | 20 +- spacy_llm/tasks/rel/task.py | 4 +- spacy_llm/tasks/span/task.py | 4 +- spacy_llm/tasks/summarization/task.py | 4 +- spacy_llm/tasks/textcat/task.py | 4 +- spacy_llm/tasks/util/sharding.py | 13 +- spacy_llm/tests/models/test_anthropic.py | 3 + spacy_llm/tests/models/test_cohere.py | 4 + spacy_llm/tests/models/test_langchain.py | 6 +- spacy_llm/tests/pipeline/test_llm.py | 4 +- spacy_llm/tests/tasks/legacy/test_ner.py | 3 +- spacy_llm/tests/tasks/legacy/test_spancat.py | 3 +- spacy_llm/tests/tasks/test_entity_linker.py | 10 +- spacy_llm/tests/tasks/test_textcat.py | 3 +- spacy_llm/tests/test_combinations.py | 20 +- spacy_llm/ty.py | 33 +-- 29 files changed, 398 insertions(+), 157 deletions(-) diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py index 89c9157a..4598146e 100644 --- a/spacy_llm/models/rest/anthropic/registry.py +++ b/spacy_llm/models/rest/anthropic/registry.py @@ -38,6 +38,7 @@ def anthropic_claude_2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -72,6 +73,7 @@ def anthropic_claude_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -108,6 +110,7 @@ def anthropic_claude_instant_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -144,6 +147,7 @@ def anthropic_claude_instant_1_1( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -178,6 +182,7 @@ def anthropic_claude_1_0( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -212,6 +217,7 @@ def anthropic_claude_1_2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -246,4 +252,5 @@ def anthropic_claude_1_3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index 436deabe..32adc0bb 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized +from typing import Any, Dict, Iterable, List, Optional, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -27,6 +27,7 @@ def __init__( interval: float, max_request_time: float, model_type: ModelType, + context_length: Optional[int], api_version: str = "2023-05-15", ): self._model_type = model_type @@ -40,6 +41,7 @@ def __init__( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @property diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index 850cc43f..fcf6bd72 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -59,4 +59,5 @@ def azure_openai( max_request_time=max_request_time, api_version=api_version, model_type=model_type, + context_length=None, ) diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index fbdc8823..12bdefdd 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -33,6 +33,7 @@ def __init__( max_tries: int, interval: float, max_request_time: float, + context_length: Optional[int], ): """Initializes new instance of REST-based model. name (str): Model name. @@ -47,6 +48,8 @@ def __init__( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context + length natively provided by spacy-llm. """ self._name = name self._endpoint = endpoint @@ -56,6 +59,7 @@ def __init__( self._interval = interval self._max_request_time = max_request_time self._credentials = self.credentials + self._context_length = context_length assert self._max_tries >= 1 assert self._interval > 0 @@ -78,12 +82,11 @@ def _get_context_lengths() -> Dict[str, int]: """ @property - def context_length(self) -> int: + def context_length(self) -> Optional[int]: """Returns context length in number of tokens for this model. - RETURNS (int): Max. number of tokens in allowed in prompt for the current model. + RETURNS (Optional[int]): Max. number of tokens in allowed in prompt for the current model. None if unknown. """ - # todo if context length not available in dict: accept param, otherwise fail? - return self._get_context_lengths()[self._name] + return self._get_context_lengths().get(self._name, self._context_length) @property @abc.abstractmethod diff --git a/spacy_llm/models/rest/cohere/registry.py b/spacy_llm/models/rest/cohere/registry.py index 06adeedc..d2bc8997 100644 --- a/spacy_llm/models/rest/cohere/registry.py +++ b/spacy_llm/models/rest/cohere/registry.py @@ -39,4 +39,5 @@ def cohere_command( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index 7a6a2111..5364438a 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -21,6 +21,7 @@ def __init__(self): max_tries=1, interval=1, max_request_time=1, + context_length=None, ) @property diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py index 1ab8694c..62802687 100644 --- a/spacy_llm/models/rest/openai/registry.py +++ b/spacy_llm/models/rest/openai/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from confection import SimpleFrozenDict @@ -24,18 +24,21 @@ @registry.llm_models("spacy.GPT-4.v3") def openai_gpt_4_v3( config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), - name: str = "gpt-4", # noqa: F722 + name: str = "gpt-4", strict: bool = OpenAI.DEFAULT_STRICT, max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, ) -> OpenAI: """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (str): Model name to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-4', "gpt-4-1106-preview", .... - RETURNS (OpenAI): OpenAI instance for 'gpt-4' model + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (OpenAI): OpenAI instance for 'gpt-4' model. DOCS: https://spacy.io/api/large-language-models#models """ @@ -47,6 +50,7 @@ def openai_gpt_4_v3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -77,6 +81,7 @@ def openai_gpt_4_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -108,6 +113,7 @@ def openai_gpt_4( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -119,12 +125,15 @@ def openai_gpt_3_5_v3( max_tries: int = OpenAI.DEFAULT_MAX_TRIES, interval: float = OpenAI.DEFAULT_INTERVAL, max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, ) -> OpenAI: """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. name (str): Name of model to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-3.5', "gpt-3.5-turbo", .... + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. RETURNS (OpenAI): OpenAI instance for 'gpt-3.5' model DOCS: https://spacy.io/api/large-language-models#models @@ -139,6 +148,7 @@ def openai_gpt_3_5_v3( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=context_length, ) @@ -177,6 +187,7 @@ def openai_gpt_3_5_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -215,6 +226,7 @@ def openai_gpt_3_5( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -247,6 +259,7 @@ def openai_text_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -277,6 +290,7 @@ def openai_text_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -307,6 +321,7 @@ def openai_code_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -335,6 +350,7 @@ def openai_code_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -365,6 +381,7 @@ def openai_text_curie_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -393,6 +410,7 @@ def openai_text_curie( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -423,6 +441,7 @@ def openai_text_babbage_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -451,6 +470,7 @@ def openai_text_babbage( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -481,6 +501,7 @@ def openai_text_ada_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -509,6 +530,7 @@ def openai_text_ada( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -539,6 +561,7 @@ def openai_davinci_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -567,6 +590,7 @@ def openai_davinci( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -597,6 +621,7 @@ def openai_curie_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -625,6 +650,7 @@ def openai_curie( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -655,6 +681,7 @@ def openai_babbage_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -683,6 +710,7 @@ def openai_babbage( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -713,6 +741,7 @@ def openai_ada_v2( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) @@ -741,4 +770,5 @@ def openai_ada( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/models/rest/palm/registry.py b/spacy_llm/models/rest/palm/registry.py index ed2b396a..7ec8e65c 100644 --- a/spacy_llm/models/rest/palm/registry.py +++ b/spacy_llm/models/rest/palm/registry.py @@ -27,7 +27,9 @@ def palm_bison( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Cohere instance for 'command' model using REST to prompt API. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (PaLM): PaLM instance for Bison model. """ return PaLM( name=name, @@ -39,4 +41,5 @@ def palm_bison( max_tries=max_tries, interval=interval, max_request_time=max_request_time, + context_length=None, ) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index f4a76911..756836b6 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -1,4 +1,5 @@ import logging +import warnings from collections import defaultdict from itertools import tee from pathlib import Path @@ -215,6 +216,13 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: context_length: Optional[int] = None if isinstance(self._model, ModelWithContextLength): context_length = self._model.context_length + if has_shards and context_length is None: + warnings.warn( + "Task supports sharding, but model does not provide context length. Data won't be sharded, prompt " + "might exceed the model's context length. Set context length in your config. If you think spacy-llm" + " should provide the context length for this models automatically, report this to " + "https://github.com/explosion/spacy-llm/issues." + ) # Only pass context length if this is a sharding task. prompts_iters = tee( diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 2bb5df4c..4b9a7ddb 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -67,36 +67,42 @@ def generate_prompts( environment = jinja2.Environment() _template = environment.from_string(self._template) - def render_template(shard: Doc, n_shards: int) -> str: + def render_template(shard: Doc, i_shard: int, i_doc: int, n_shards: int) -> str: """Renders template for a given doc (shard). shard (Doc): Doc shard. Note that if the prompt is small enough to fit within the model's context window, there will only be one shard, which is identical to the original doc. + i_shard (int): Shard index (w.r.t. shard's Doc instance). + i_doc (int): Doc index. n_shards (int): Total number of shards. RETURNS (str): Rendered template. """ return _template.render( text=doc.text, prompt_examples=self._prompt_examples, - **self._get_prompt_data(shard, n_shards), + **self._get_prompt_data(shard, i_shard, i_doc, n_shards), ) - for doc in self._preprocess_docs_for_prompt(docs): + for i_doc, doc in enumerate(self._preprocess_docs_for_prompt(docs)): # If no context length provided (e. g. because models don't provide it): don't shard. shards = ( - self._shard_mapper(doc, context_length, render_template) + self._shard_mapper(doc, i_doc, context_length, render_template) if context_length is not None else [doc] ) shards_teed = tee(shards, 3) yield [ - render_template(shard, len(list(shards_teed[0]))) - for shard in shards_teed[1] + render_template(shard, i_shard, i_doc, len(list(shards_teed[0]))) + for i_shard, shard in enumerate(shards_teed[1]) ], shards_teed[2] - def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: """Returns data injected into prompt template. No-op if not overridden by inheriting task class. The data returned by this might be static (i. e. the same for all doc shards) or dynamic (contingent on the doc shard). shard (Doc): Doc (shard) for which prompt data should be fetched. + i_shard (int): Shard index (w.r.t. shard's Doc instance). + i_doc (int): Doc index. n_shards (int): Total number of shards. RETURNS (Dict[str, Any]): Data injected into prompt template. """ diff --git a/spacy_llm/tasks/entity_linker/parser.py b/spacy_llm/tasks/entity_linker/parser.py index b3c4076a..380ee837 100644 --- a/spacy_llm/tasks/entity_linker/parser.py +++ b/spacy_llm/tasks/entity_linker/parser.py @@ -8,36 +8,47 @@ def parse_responses_v1( - task: EntityLinkerTask, docs: Iterable[Doc], responses: Iterable[str] -) -> Iterable[List[Span]]: + task: EntityLinkerTask, + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[List[List[Span]]]: """Parses LLM responses for spacy.EntityLinker.v1. task (EntityLinkerTask): Task instance. - docs (Iterable[Doc]): Corresponding Doc instances. - responses (Iterable[str]): LLM responses. - RETURNS (Iterable[List[Span]]): Entity spans per doc. + shards (Iterable[Iterable[Doc]]): Doc shards. + responses (Iterable[Iterable[str]]): LLM responses. + RETURNS (Iterable[List[List[Span]]): Entity spans per shard. """ - for i_doc, (doc, prompt_response) in enumerate(zip(docs, responses)): - solutions = [ - sol.replace("::: ", "")[1:-1] - for sol in re.findall(r"::: <.*>", prompt_response) - ] - - # Set ents anew by copying them and specifying the KB ID. - ents = [ - ent - for i_ent, ent in enumerate(doc.ents) - if task.has_ent_cands[i_doc][i_ent] - ] - yield [ - Span( - doc=doc, - start=ent.start, - end=ent.end, - label=ent.label, - vector=ent.vector, - vector_norm=ent.vector_norm, - kb_id=solution if solution != "NIL" else EntityLinker.NIL, + for i_doc, (shards_for_doc, responses_for_doc) in enumerate(zip(shards, responses)): + results_for_doc: List[List[Span]] = [] + for i_shard, (shard, response) in enumerate( + zip(shards_for_doc, responses_for_doc) + ): + solutions = [ + sol.replace("::: ", "")[1:-1] + for sol in re.findall(r"::: <.*>", response) + ] + + # Set ents anew by copying them and specifying the KB ID. + ents = [ + ent + for i_ent, ent in enumerate(shard.ents) + if task.has_ent_cands[i_doc][i_ent] + ] + + results_for_doc.append( + [ + Span( + doc=shard, + start=ent.start, + end=ent.end, + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=solution if solution != "NIL" else EntityLinker.NIL, + ) + for ent, solution in zip(ents, solutions) + ] ) - for ent, solution in zip(ents, solutions) - ] + + yield results_for_doc diff --git a/spacy_llm/tasks/entity_linker/registry.py b/spacy_llm/tasks/entity_linker/registry.py index 10e34ed0..df98b6f2 100644 --- a/spacy_llm/tasks/entity_linker/registry.py +++ b/spacy_llm/tasks/entity_linker/registry.py @@ -5,12 +5,15 @@ from spacy.scorer import Scorer from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, ShardMapper, ShardReducer +from ...ty import TaskResponseParser +from ..util.sharding import make_shard_mapper from .candidate_selector import KBCandidateSelector from .parser import parse_responses_v1 from .task import DEFAULT_EL_TEMPLATE_V1, EntityLinkerTask from .ty import EntDescReader, InMemoryLookupKBLoader -from .util import ELExample, KBFileLoader, KBObjectLoader, ent_desc_reader_csv, score +from .util import ELExample, KBFileLoader, KBObjectLoader, ent_desc_reader_csv +from .util import reduce_shards_to_doc, score @registry.llm_tasks("spacy.EntityLinker.v1") @@ -19,6 +22,8 @@ def make_entitylinker_task( parse_responses: Optional[TaskResponseParser[EntityLinkerTask]] = None, prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, + shard_mapper: Optional[ShardMapper] = None, + shard_reducer: Optional[ShardReducer] = None, scorer: Optional[Scorer] = None, ): """EntityLinker.v1 task factory. @@ -28,6 +33,8 @@ def make_entitylinker_task( prompt_example_type (Optional[Type[FewshotExample]]): Type to use for fewshot examples. examples (ExamplesConfigType): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. + shard_mapper (Optional[ShardMapper]): Maps docs to shards if they don't fit into the model context. + shard_reducer (Optional[ShardReducer]): Reduces doc shards back into one doc instance. scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples @@ -50,6 +57,8 @@ def make_entitylinker_task( parse_responses=parse_responses or parse_responses_v1, prompt_example_type=example_type, prompt_examples=examples, + shard_mapper=shard_mapper or make_shard_mapper(), + shard_reducer=shard_reducer or make_shard_reducer(), scorer=scorer or score, ) @@ -114,3 +123,8 @@ def make_kb_file_loader(path: Union[str, Path]) -> KBFileLoader: RETURNS (KBFileLoader): Loader instance. """ return KBFileLoader(path=path) + + +@registry.llm_misc("spacy.EntityLinkerShardReducer.v1") +def make_shard_reducer() -> ShardReducer: + return reduce_shards_to_doc diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index c17dfcd7..930cbbfc 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -1,13 +1,12 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type -import jinja2 from spacy import Language, Vocab from spacy.pipeline import EntityLinker from spacy.tokens import Doc, Span from spacy.training import Example from ...compat import Self -from ...ty import FewshotExample, Scorer, TaskResponseParser +from ...ty import FewshotExample, Scorer, ShardMapper, ShardReducer, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template from .ty import CandidateSelector, Entity, InitializableCandidateSelector @@ -22,6 +21,8 @@ def __init__( prompt_example_type: Type[FewshotExample[Self]], prompt_examples: Optional[List[FewshotExample[Self]]], template: str, + shard_mapper: ShardMapper, + shard_reducer: ShardReducer[Self], scorer: Scorer, ): """Default entity linking task. @@ -30,6 +31,8 @@ def __init__( prompt_example_type (Type[FewshotExample[Self]]): Type to use for fewshot examples. prompt_examples (Optional[List[FewshotExample[Self]]]): Optional list of few-shot examples to include in prompts. template (str): Prompt template passed to the model. + shard_mapper (ShardMapper): Maps docs to shards if they don't fit into the model context. + shard_reducer (ShardReducer[Self]): Reduces doc shards back into one doc instance. scorer (Scorer): Scorer function. """ super().__init__( @@ -37,14 +40,19 @@ def __init__( prompt_example_type=prompt_example_type, template=template, prompt_examples=prompt_examples, + shard_mapper=shard_mapper, + shard_reducer=shard_reducer, ) self._scorer = scorer self._candidate_selector: Optional[CandidateSelector] = None # Exclude mentions without candidates from prompt, if set. Mostly used for internal debugging. self._auto_nil = True - # Store, per doc and entity, whether candidates could be found. - self._has_ent_cands: List[List[bool]] = [] + # Store, per doc and entity, whether candidates could be found and candidates themselves. + self._has_ent_cands_by_doc: List[List[bool]] = [] + self._ents_cands_by_doc: List[List[List[Entity]]] = [] + self._has_ent_cands_by_shard: List[List[List[bool]]] = [] + self._ents_cands_by_shard: List[List[List[List[Entity]]]] = [] def initialize( self, @@ -86,80 +94,164 @@ def set_candidate_selector( if isinstance(self._candidate_selector, InitializableCandidateSelector): self._candidate_selector.initialize(vocab) - def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[str]: - environment = jinja2.Environment() - _template = environment.from_string(self._template) - # Reset auto-nil attributes for new batch of docs. - self._has_ent_cands = [] + def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: + ( + self._ents_cands_by_doc, + self._has_ent_cands_by_doc, + ) = self._find_entity_candidates(docs) + # Reset shard-wise candidate info. Will be set for each shard individually in _get_prompt_data(). We cannot + # update it here, as we don't know yet how the shards will look like. + self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + preprocessed_docs: List[Doc] = [] - for i_doc, doc in enumerate(docs): - cands_ents, _ = self.fetch_entity_info(doc) - # Determine which ents have candidates and should be included in prompt. - has_cands = [ - {cand_ent.id for cand_ent in cand_ents} != {EntityLinker.NIL} - or not self._auto_nil - for cand_ents in cands_ents + for i, doc in enumerate(docs): + preprocessed_doc = Doc( + doc.vocab, + words=EntityLinkerTask.highlight_ents_in_text( + doc, self._has_ent_cands_by_doc[i] + ).split(), + ) + preprocessed_doc.ents = [ + Span( + doc=preprocessed_doc, + start=ent.start, + end=ent.end, + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=EntityLinker.NIL, + ) + for ent in doc.ents ] - self._has_ent_cands.append(has_cands) - - # To improve: if a doc has no entities (with candidates), skip prompt altogether? - yield _template.render( - text=EntityLinkerTask.highlight_ents_in_text(doc, has_cands), - mentions_str=", ".join( - [f"*{mention}*" for hc, mention in zip(has_cands, doc.ents) if hc] - ), - mentions=[ent.text for hc, ent in zip(has_cands, doc.ents) if hc], - entity_descriptions=[ - [ent.description for ent in ents] - for hc, ents in zip(has_cands, cands_ents) - if hc - ], - entity_ids=[ - [ent.id for ent in ents] - for hc, ents in zip(has_cands, cands_ents) - if hc - ], - prompt_examples=self._prompt_examples, + preprocessed_docs.append(preprocessed_doc) + + return preprocessed_docs + + def _find_entity_candidates( + self, docs: Iterable[Doc] + ) -> Tuple[List[List[List[Entity]]], List[List[bool]]]: + """Determine entity candidates for all entity mentions in docs. + docs (Iterable[Doc]): Docs with entities to select candidates for. + RETURNS (Tuple[List[List[List[Entity]]], List[List[bool]]]): (1) list of candidate entities for each doc and + entity, (2) list of flag whether candidates could be found per each doc and entitiy. + """ + ents_cands: List[List[List[Entity]]] = [] + has_cands: List[List[bool]] = [] + + for doc in docs: + ents_cands.append(self.fetch_entity_info(doc)[0]) + # Determine which ents have candidates and should be included in prompt. + has_cands.append( + [ + {cand_ent.id for cand_ent in cand_ents} != {EntityLinker.NIL} + or not self._auto_nil + for cand_ents in ents_cands[-1] + ] + ) + + return ents_cands, has_cands + + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: + # It's not ideal that we have to run candidate selection again here - but due to (1) us wanting to know whether + # all entities have candidates before sharding and, more importantly, (2) some entities maybe being split up in + # the sharding process it's cleaner to look for candidates again. + if n_shards == 1: + # If only one shard: shard is identical to original doc, so we don't have to rerun candidate search. + ents_cands, has_cands = ( + self._ents_cands_by_doc[i_doc], + self._has_ent_cands_by_doc[i_doc], ) + else: + cands_info = self._find_entity_candidates([shard]) + ents_cands, has_cands = cands_info[0][0], cands_info[1][0] + + # Update shard-wise candidate info so it can be reused during parsing. + if len(self._ents_cands_by_shard[i_doc]) == 0: + self._ents_cands_by_shard[i_doc] = [[] * n_shards] + self._has_ent_cands_by_shard[i_doc] = [[] * n_shards] + self._ents_cands_by_shard[i_doc][i_shard] = ents_cands + self._has_ent_cands_by_shard[i_doc][i_shard] = has_cands + + return { + "mentions_str": ", ".join( + [mention.text for hc, mention in zip(has_cands, shard.ents) if hc] + ), + "mentions": [ + # Due to retokenization of doc with entity highlighting entity mentions are wrapped in "*". + ent.text + if not (ent.text.startswith("*") and ent.text.endswith("*")) + else ent.text[1:-1] + for hc, ent in zip(has_cands, shard.ents) + if hc + ], + "entity_descriptions": [ + [ent.description for ent in ents] + for hc, ents in zip(has_cands, ents_cands) + if hc + ], + "entity_ids": [ + [ent.id for ent in ents] + for hc, ents in zip(has_cands, ents_cands) + if hc + ], + } def parse_responses( - self, docs: Iterable[Doc], responses: Iterable[str] + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - for i_doc, (doc, ent_spans) in enumerate( - zip(docs, self._parse_responses(self, docs=docs, responses=responses)) + shards_teed = self._tee_2d_iterable(shards, 2) + parsed_responses = self._parse_responses(self, shards_teed[1], responses) + + for i_doc, (shards_for_doc, ent_spans_for_doc) in enumerate( + zip(shards_teed[0], parsed_responses) ): - gen_nil_span: Callable[[Span], Span] = lambda ent: Span( # noqa: E731 - doc=doc, - start=ent.start, - end=ent.end, - label=ent.label, - vector=ent.vector, - vector_norm=ent.vector_norm, - kb_id=EntityLinker.NIL, - ) + updated_shards_for_doc: List[Doc] = [] + for i_shard, (shard, ent_spans) in enumerate( + zip(shards_for_doc, ent_spans_for_doc) + ): + gen_nil_span: Callable[[Span], Span] = lambda ent: Span( # noqa: E731 + doc=shard, + start=ent.start, + end=ent.end, + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=EntityLinker.NIL, + ) - # If numbers of ents parsed from LLM response + ents without candidates and number of ents in doc don't - # align, skip doc (most likely LLM parsing failed, no guarantee KB IDs can be assigned to correct ents). - # This can happen when the LLM fails to list solutions for all entities. - all_entities_resolved = len(ent_spans) + sum( - [not is_in_prompt for is_in_prompt in self._has_ent_cands[i_doc]] - ) == len(doc.ents) - - # Fuse entities with (i. e. inferred by the LLM) and without candidates (i. e. auto-niled). - # If entity was not included in prompt, as there were no candidates - fill in NIL for this entity. - # If numbers of inferred and auto-niled entities don't line up with total number of entities, there is no - # guaranteed way to assign a partially resolved list of entities - # correctly. - # Else: entity had candidates and was included in prompt - fill in resolved KB ID. - ent_spans_iter = iter(ent_spans) - doc.ents = [ - gen_nil_span(ent) - if not (all_entities_resolved and self._has_ent_cands[i_doc][i_ent]) - else next(ent_spans_iter) - for i_ent, ent in enumerate(doc.ents) - ] + # If numbers of ents parsed from LLM response + ents without candidates and number of ents in doc don't + # align, skip doc (most likely LLM parsing failed, no guarantee KB IDs can be assigned to correct ents). + # This can happen when the LLM fails to list solutions for all entities. + all_entities_resolved = len(ent_spans) + sum( + [ + not is_in_prompt + for is_in_prompt in self._has_ent_cands_by_shard[i_doc][i_shard] + ] + ) == len(shard.ents) + + # Fuse entities with (i. e. inferred by the LLM) and without candidates (i. e. auto-niled). + # If entity was not included in prompt, as there were no candidates - fill in NIL for this entity. + # If numbers of inferred and auto-niled entities don't line up with total number of entities, there is + # no guaranteed way to assign a partially resolved list of entities + # correctly. + # Else: entity had candidates and was included in prompt - fill in resolved KB ID. + ent_spans_iter = iter(ent_spans) + shard.ents = [ + gen_nil_span(ent) + if not ( + all_entities_resolved + and self._has_ent_cands_by_shard[i_doc][i_shard][i_ent] + ) + else next(ent_spans_iter) + for i_ent, ent in enumerate(shard.ents) + ] - yield doc + updated_shards_for_doc.append(shard) + + yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] def scorer(self, examples: Iterable[Example]) -> Dict[str, Any]: return self._scorer(examples) @@ -192,7 +284,11 @@ def highlight_ents_in_text( text = ( text[: ent.start_char + i * 2] - + f"*{ent.text}*" + + ( + f"*{ent.text}*" + if not (ent.text.startswith("*") and ent.text.endswith("*")) + else ent.text + ) + text[ent.end_char + i * 2 :] ) i += 1 @@ -250,4 +346,4 @@ def has_ent_cands(self) -> List[List[bool]]: """Returns flags indicating whether documents' entities' have candidates in KB. RETURNS (List[List[bool]]): Flags indicating whether documents' entities' have candidates in KB. """ - return self._has_ent_cands + return self._has_ent_cands_by_doc diff --git a/spacy_llm/tasks/entity_linker/util.py b/spacy_llm/tasks/entity_linker/util.py index dce404b6..75511484 100644 --- a/spacy_llm/tasks/entity_linker/util.py +++ b/spacy_llm/tasks/entity_linker/util.py @@ -11,6 +11,7 @@ from spacy.kb import InMemoryLookupKB from spacy.pipeline import EntityLinker from spacy.scorer import Scorer +from spacy.tokens import Doc from spacy.training import Example from ...compat import Self @@ -34,7 +35,14 @@ def mentions_str(self) -> str: """Returns stringified version of all mentions. RETURNS (str): Stringified version of all mentions. """ - return ", ".join([f"*{mention}*" for mention in self.mentions]) + return ", ".join( + [ + f"*{mention}*" + if not (mention.startswith("*") and mention.endswith("*")) + else mention + for mention in self.mentions + ] + ) @classmethod def generate(cls, example: Example, task: EntityLinkerTask) -> Optional[Self]: @@ -196,3 +204,13 @@ def __call__(self, vocab: Vocab) -> Tuple[InMemoryLookupKB, DescFormat]: raise err return kb, {qid: entities[qid].get("desc") for qid in qids} + + +def reduce_shards_to_doc(task: EntityLinkerTask, shards: Iterable[Doc]) -> Doc: + """Reduces shards to docs for LemmaTask. + task (EntityLinkerTask): Task. + shards (Iterable[Doc]): Shards to reduce to single doc instance. + RETURNS (Doc): Fused doc instance. + """ + # Entities are additive, so we can just merge shards. + return Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index b90c99cf..b86e5ad9 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -79,7 +79,9 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: return preprocessed_docs - def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/span/task.py b/spacy_llm/tasks/span/task.py index 56b46724..e8bcb407 100644 --- a/spacy_llm/tasks/span/task.py +++ b/spacy_llm/tasks/span/task.py @@ -70,7 +70,9 @@ def __init__( if self._prompt_examples: self._prompt_examples = list(self._check_label_consistency(self)) - def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "description": self._description, "labels": list(self._label_dict.values()), diff --git a/spacy_llm/tasks/summarization/task.py b/spacy_llm/tasks/summarization/task.py index f144b302..c6900ce0 100644 --- a/spacy_llm/tasks/summarization/task.py +++ b/spacy_llm/tasks/summarization/task.py @@ -84,7 +84,9 @@ def _check_prompt_example_summary_len(self) -> None: f"LLM will likely produce responses that are too long." ) - def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: if self._check_example_summaries: self._check_prompt_example_summary_len() self._check_example_summaries = False diff --git a/spacy_llm/tasks/textcat/task.py b/spacy_llm/tasks/textcat/task.py index ef29d79e..9e21238a 100644 --- a/spacy_llm/tasks/textcat/task.py +++ b/spacy_llm/tasks/textcat/task.py @@ -89,7 +89,9 @@ def __init__( ) self._exclusive_classes = True - def _get_prompt_data(self, shard: Doc, n_shards: int) -> Dict[str, Any]: + def _get_prompt_data( + self, shard: Doc, i_shard: int, i_doc: int, n_shards: int + ) -> Dict[str, Any]: return { "labels": list(self._label_dict.values()), "label_definitions": self._label_definitions, diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index 0b264222..b5f0f31b 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -35,9 +35,12 @@ def make_shard_mapper( n_tok_est: NTokenEstimator = n_token_estimator or make_n_token_estimator() def map_doc_to_shards( - doc: Doc, context_length: int, render_template: Callable[[Doc, int], str] + doc: Doc, + i_doc: int, + context_length: int, + render_template: Callable[[Doc, int, int, int], str], ) -> Union[Iterable[Doc], Doc]: - prompt = render_template(doc, 1) + prompt = render_template(doc, 0, i_doc, 1) # If prompt with complete doc too long: split in shards. if n_tok_est(prompt) * buffer_frac > context_length: @@ -63,7 +66,11 @@ def map_doc_to_shards( end_idx = start_idx + int(len(remaining_doc) * fraction) shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) fits_in_context = ( - n_tok_est(render_template(shard, int(1 / fraction))) + n_tok_est( + render_template( + shard, len(shards), i_doc, int(1 / fraction) + ) + ) * buffer_frac <= context_length ) diff --git a/spacy_llm/tests/models/test_anthropic.py b/spacy_llm/tests/models/test_anthropic.py index 6df49ba3..d0bfa794 100644 --- a/spacy_llm/tests/models/test_anthropic.py +++ b/spacy_llm/tests/models/test_anthropic.py @@ -20,6 +20,7 @@ def test_anthropic_api_response_is_correct(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" @@ -49,6 +50,7 @@ def test_anthropic_api_response_when_error(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) @@ -71,4 +73,5 @@ def test_anthropic_error_unsupported_model(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) diff --git a/spacy_llm/tests/models/test_cohere.py b/spacy_llm/tests/models/test_cohere.py index 82260026..dfcb432a 100644 --- a/spacy_llm/tests/models/test_cohere.py +++ b/spacy_llm/tests/models/test_cohere.py @@ -18,6 +18,7 @@ def test_cohere_api_response_is_correct(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" num_prompts = 3 # arbitrary number to check multiple inputs @@ -46,6 +47,7 @@ def test_cohere_api_response_n_generations(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) prompt = "Count the number of characters in this string: hello" @@ -73,6 +75,7 @@ def test_cohere_api_response_when_error(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) @@ -90,4 +93,5 @@ def test_cohere_error_unsupported_model(): max_tries=10, interval=5.0, max_request_time=20, + context_length=None, ) diff --git a/spacy_llm/tests/models/test_langchain.py b/spacy_llm/tests/models/test_langchain.py index fd48e0bb..95019504 100644 --- a/spacy_llm/tests/models/test_langchain.py +++ b/spacy_llm/tests/models/test_langchain.py @@ -23,7 +23,8 @@ def test_initialization(): """Test initialization and simple run""" nlp = spacy.blank("en") nlp.add_pipe("llm", config=PIPE_CFG) - nlp("This is a test.") + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp("This is a test.") @pytest.mark.external @@ -49,4 +50,5 @@ def test_initialization_azure_openai(): nlp = spacy.blank("en") nlp.add_pipe("llm", config=_pipe_cfg) - nlp("This is a test.") + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp("This is a test.") diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index ed150589..5c57def0 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -183,8 +183,8 @@ def parse_responses( assert len(record) == 2 assert ( str(record[0].message) - == "First type in `Iterable[Tuple[...]] returned from `task.generate_prompts()` (`typing.Iterable[int]`) " - "doesn't match type expected by `model` (`typing.Iterable[str]`)." + == "First type in value returned from `task.generate_prompts()` (`typing.Iterable[int]`) doesn't match type " + "expected by `model` (`typing.Iterable[str]`)." ) assert ( str(record[1].message) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index b17848d0..6da79570 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -722,7 +722,8 @@ def test_ner_scoring(noop_config, n_detections): examples.append(Example(predicted, reference)) - scores = nlp.evaluate(examples) + with pytest.warns(UserWarning, match="Task supports sharding"): + scores = nlp.evaluate(examples) assert scores["ents_p"] == n_detections / 2 diff --git a/spacy_llm/tests/tasks/legacy/test_spancat.py b/spacy_llm/tests/tasks/legacy/test_spancat.py index 4b70ee46..d146ead5 100644 --- a/spacy_llm/tests/tasks/legacy/test_spancat.py +++ b/spacy_llm/tests/tasks/legacy/test_spancat.py @@ -528,7 +528,8 @@ def test_spancat_scoring(noop_config, n_detections): examples.append(Example(predicted, reference)) - scores = nlp.evaluate(examples) + with pytest.warns(UserWarning, match="Task supports sharding"): + scores = nlp.evaluate(examples) assert scores["spans_sc_p"] == n_detections / 2 diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index 625d9105..dcc189f6 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -352,8 +352,8 @@ def make_doc() -> Doc: nlp.components[0][1]._task._auto_nil = False doc = nlp(make_doc()) assert ( - f"- For *Foo*:\n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" - in doc.user_data["llm_io"]["llm"]["prompt"] + f"- For *Foo*:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" + in doc.user_data["llm_io"]["llm"]["prompt"].replace("\\", "") ) assert doc.ents[0].kb_id_ == EntityLinker.NIL # Sometimes GPT-3.5 doesn't manage to include the NIL prediction, in which case all entities are set to NIL. @@ -427,7 +427,7 @@ def test_jinja_template_rendering_without_examples(tmp_path): ) ) el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) - prompt = list(el_task.generate_prompts([doc]))[0] + prompt = list(el_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip().replace(" \n", "\n") @@ -497,7 +497,7 @@ def test_jinja_template_rendering_with_examples(examples_path, tmp_path): ) ) el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) - prompt = list(el_task.generate_prompts([doc]))[0] + prompt = list(el_task.generate_prompts([doc]))[0][0][0] assert ( prompt.strip().replace(" \n", "\n") @@ -601,7 +601,7 @@ def test_external_template_actually_loads(tmp_path): el_task._candidate_selector.initialize(spacy.load(tmp_path).vocab) assert ( - list(el_task.generate_prompts([doc]))[0].strip() + list(el_task.generate_prompts([doc]))[0][0][0].strip() == f""" This is a test entity linking template. Here is the text: {text} diff --git a/spacy_llm/tests/tasks/test_textcat.py b/spacy_llm/tests/tasks/test_textcat.py index e4ddad3a..ff63fccf 100644 --- a/spacy_llm/tests/tasks/test_textcat.py +++ b/spacy_llm/tests/tasks/test_textcat.py @@ -658,7 +658,8 @@ def b(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: examples.append(Example(predicted, reference)) - scores = nlp.evaluate(examples) + with pytest.warns(UserWarning, match="Task supports sharding"): + scores = nlp.evaluate(examples) pos = n_insults / len(INSULTS) diff --git a/spacy_llm/tests/test_combinations.py b/spacy_llm/tests/test_combinations.py index bd4a871a..d4471ab6 100644 --- a/spacy_llm/tests/test_combinations.py +++ b/spacy_llm/tests/test_combinations.py @@ -49,7 +49,19 @@ def test_combinations(model: str, task: str, n_process: int): assert name == "llm" assert isinstance(component, LLMWrapper) - nlp("This is a test.") - list( - nlp.pipe(["This is a second test", "This is a third test"], n_process=n_process) - ) + if model.startswith("langchain"): + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp("This is a test.") + list( + nlp.pipe( + ["This is a second test", "This is a third test"], + n_process=n_process, + ) + ) + else: + nlp("This is a test.") + list( + nlp.pipe( + ["This is a second test", "This is a third test"], n_process=n_process + ) + ) diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index d0d28bb2..0889c1c2 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -25,8 +25,8 @@ ] NTokenEstimator = Callable[[str], int] ShardMapper = Callable[ - # Requires doc, context length and callable for rendering template from doc shard text. - [Doc, int, Callable[[Doc, int], str]], + # Requires doc, doc index, context length and callable for rendering template from doc shard text. + [Doc, int, int, Callable[[Doc, int, int, int], str]], # Returns each shard as a doc. Iterable[Doc], ] @@ -92,16 +92,6 @@ def __call__(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: """ -# todo -# x change to llmtask -# x add llmtask -# x fix task typing structures -# x fix model data handling -# x don't expect doc back from nonsharding tasks -# x run tests with to sharding and non-sharding nooptask -# - fix inevitable typing check issues - - @runtime_checkable class ShardingLLMTask(Protocol): def generate_prompts( @@ -130,7 +120,7 @@ def parse_responses( @runtime_checkable -class LLMTask(Protocol): +class NonshardingLLMTask(Protocol): def generate_prompts(self, docs: Iterable[Doc]) -> Iterable[_PromptType]: """Generate prompts from docs. docs (Iterable[Doc]): Docs to generate prompts from. @@ -148,14 +138,25 @@ def parse_responses( """ -TaskContraT = TypeVar("TaskContraT", bound=ShardingLLMTask, contravariant=True) +@runtime_checkable +class LLMTask(Protocol): + generate_prompts: Callable[..., Iterable[Any]] + parse_responses: Callable[..., Iterable[Doc]] + + +TaskContraT = TypeVar( + "TaskContraT", bound=Union[ShardingLLMTask, LLMTask], contravariant=True +) +ShardingTaskContraT = TypeVar( + "ShardingTaskContraT", bound=ShardingLLMTask, contravariant=True +) @runtime_checkable -class ShardReducer(Protocol[TaskContraT]): +class ShardReducer(Protocol[ShardingTaskContraT]): """Generic protocol for tasks' shard reducer.""" - def __call__(self, task: TaskContraT, shards: Iterable[Doc]) -> Doc: + def __call__(self, task: ShardingTaskContraT, shards: Iterable[Doc]) -> Doc: """Merges shard to single Doc.""" ... From 03055c5a2a975c3930cf5c23bebf47c6e4541f3f Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 23 Nov 2023 18:48:23 +0100 Subject: [PATCH 26/51] Fix EL tokenization and highlighting partially. --- spacy_llm/tasks/entity_linker/parser.py | 2 +- spacy_llm/tasks/entity_linker/task.py | 151 +++++++++++------- spacy_llm/tasks/entity_linker/util.py | 11 +- .../tasks/templates/entity_linker.v1.jinja | 6 +- spacy_llm/tests/tasks/test_entity_linker.py | 4 +- usage_examples/tests/test_readme_examples.py | 3 +- usage_examples/tests/test_usage_examples.py | 7 +- 7 files changed, 108 insertions(+), 76 deletions(-) diff --git a/spacy_llm/tasks/entity_linker/parser.py b/spacy_llm/tasks/entity_linker/parser.py index 380ee837..54d1c19c 100644 --- a/spacy_llm/tasks/entity_linker/parser.py +++ b/spacy_llm/tasks/entity_linker/parser.py @@ -33,7 +33,7 @@ def parse_responses_v1( ents = [ ent for i_ent, ent in enumerate(shard.ents) - if task.has_ent_cands[i_doc][i_ent] + if task.has_ent_cands_by_shard[i_doc][i_shard][i_ent] ] results_for_doc.append( diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index 930cbbfc..5c4ebf55 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type from spacy import Language, Vocab from spacy.pipeline import EntityLinker @@ -103,30 +103,11 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: # update it here, as we don't know yet how the shards will look like. self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] - preprocessed_docs: List[Doc] = [] - - for i, doc in enumerate(docs): - preprocessed_doc = Doc( - doc.vocab, - words=EntityLinkerTask.highlight_ents_in_text( - doc, self._has_ent_cands_by_doc[i] - ).split(), - ) - preprocessed_doc.ents = [ - Span( - doc=preprocessed_doc, - start=ent.start, - end=ent.end, - label=ent.label, - vector=ent.vector, - vector_norm=ent.vector_norm, - kb_id=EntityLinker.NIL, - ) - for ent in doc.ents - ] - preprocessed_docs.append(preprocessed_doc) - return preprocessed_docs + return [ + EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i]) + for i, doc in enumerate(docs) + ] def _find_entity_candidates( self, docs: Iterable[Doc] @@ -179,14 +160,7 @@ def _get_prompt_data( "mentions_str": ", ".join( [mention.text for hc, mention in zip(has_cands, shard.ents) if hc] ), - "mentions": [ - # Due to retokenization of doc with entity highlighting entity mentions are wrapped in "*". - ent.text - if not (ent.text.startswith("*") and ent.text.endswith("*")) - else ent.text[1:-1] - for hc, ent in zip(has_cands, shard.ents) - if hc - ], + "mentions": [ent.text for hc, ent in zip(has_cands, shard.ents) if hc], "entity_descriptions": [ [ent.description for ent in ents] for hc, ents in zip(has_cands, ents_cands) @@ -249,7 +223,10 @@ def parse_responses( for i_ent, ent in enumerate(shard.ents) ] - updated_shards_for_doc.append(shard) + # Remove entity highlights in shards. + updated_shards_for_doc.append( + EntityLinkerTask.unhighlight_ents_in_doc(shard) + ) yield self._shard_reducer(self, updated_shards_for_doc) # type: ignore[arg-type] @@ -261,39 +238,99 @@ def _cfg_keys(self) -> List[str]: return ["_template"] @staticmethod - def highlight_ents_in_text( + def highlight_ents_in_doc( doc: Doc, include_ents: Optional[List[bool]] = None - ) -> str: - """Highlights entities in doc text with **. + ) -> Doc: + """Highlights entities in doc by wrapping them in **. doc (Doc): Doc whose entities are to be highlighted. include_ents (Optional[List[bool]]): Whether to include entities with the corresponding indices. If None, all are included. - RETURNS (str): Text with highlighted entities. + RETURNS (Doc): Doc with highlighted entities. """ if include_ents is not None and len(include_ents) != len(doc.ents): raise ValueError( f"`include_ents` has {len(include_ents)} entries, but {len(doc.ents)} are required." ) - text = doc.text - i = 0 - for ent in doc.ents: - # Skip if ent is not supposed to be included. - if include_ents is not None and not include_ents[i]: - continue - - text = ( - text[: ent.start_char + i * 2] - + ( - f"*{ent.text}*" - if not (ent.text.startswith("*") and ent.text.endswith("*")) - else ent.text + ents_to_highlight_idx = [ + i + for i, ent in enumerate(doc.ents) + if (include_ents is None or include_ents[i]) + ] + ents_idx = [(ent.start, ent.end) for ent in doc.ents] + + # Include *-marker as tokens. Update entity indices. + i_ent = 0 + new_ent_idx: List[Tuple[int, int]] = [] + token_texts: List[str] = [] + to_highlight = i_ent in ents_to_highlight_idx + offset = 0 + for token in doc: + if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][1]: + if to_highlight: + token_texts.append("*") + offset += 1 + i_ent += 1 + to_highlight = i_ent in ents_to_highlight_idx + if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][0]: + if to_highlight: + token_texts.append("*") + offset += 1 + new_ent_idx.append( + (ents_idx[i_ent][0] + offset, ents_idx[i_ent][1] + offset) ) - + text[ent.end_char + i * 2 :] + token_texts.append(token.text) + + # Create doc with new tokens and entities. + highlighted_doc = Doc(doc.vocab, words=token_texts) + highlighted_doc.ents = [ + Span( + doc=highlighted_doc, + start=new_ent_idx[i][0], + end=new_ent_idx[i][1], + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=ent.kb_id_, + ) + for i, ent in enumerate(doc.ents) + ] + + return highlighted_doc + + @staticmethod + def unhighlight_ents_in_doc(doc: Doc) -> Doc: + """Remove entity highlighting (* wrapping) in doc. + doc (Doc): Doc whose entities are to be highlighted. + RETURNS (Doc): Doc with highlighted entities. + """ + highlight_idx: Set[int] = {ent.start - 1 for ent in doc.ents} | { + ent.end for ent in doc.ents + } + ent_idx = [ + (ent.start - i * 2 - 1, ent.end - i * 2 - 1) + for i, ent in enumerate(doc.ents) + ] + + # Create doc with new tokens and entities. + unhighlighted_doc = Doc( + doc.vocab, + words=[token.text for token in doc if token.i not in highlight_idx], + ) + unhighlighted_doc.ents = [ + Span( + doc=unhighlighted_doc, + start=ent_idx[i][0], + end=ent_idx[i][1], + label=ent.label, + vector=ent.vector, + vector_norm=ent.vector_norm, + kb_id=ent.kb_id_, ) - i += 1 + for i, ent in enumerate(doc.ents) + ] - return text + return unhighlighted_doc def _require_candidate_selector(self) -> None: """Raises an error if candidate selector is not available.""" @@ -342,8 +379,8 @@ def fetch_entity_info( return cand_entity_info, correct_ent_ids @property - def has_ent_cands(self) -> List[List[bool]]: - """Returns flags indicating whether documents' entities' have candidates in KB. - RETURNS (List[List[bool]]): Flags indicating whether documents' entities' have candidates in KB. + def has_ent_cands_by_shard(self) -> List[List[List[bool]]]: + """Returns flags indicating whether shards' entities' have candidates in KB. + RETURNS (List[List[List[bool]]]): Flags indicating whether shards' entities' have candidates in KB. """ - return self._has_ent_cands_by_doc + return self._has_ent_cands_by_shard diff --git a/spacy_llm/tasks/entity_linker/util.py b/spacy_llm/tasks/entity_linker/util.py index 75511484..57c4a95a 100644 --- a/spacy_llm/tasks/entity_linker/util.py +++ b/spacy_llm/tasks/entity_linker/util.py @@ -35,14 +35,7 @@ def mentions_str(self) -> str: """Returns stringified version of all mentions. RETURNS (str): Stringified version of all mentions. """ - return ", ".join( - [ - f"*{mention}*" - if not (mention.startswith("*") and mention.endswith("*")) - else mention - for mention in self.mentions - ] - ) + return ", ".join([f"*{mention}*" for mention in self.mentions]) @classmethod def generate(cls, example: Example, task: EntityLinkerTask) -> Optional[Self]: @@ -68,7 +61,7 @@ def generate(cls, example: Example, task: EntityLinkerTask) -> Optional[Self]: assert all([sol is not None for sol in solutions]) return ELExample( - text=EntityLinkerTask.highlight_ents_in_text(example.reference), + text=EntityLinkerTask.highlight_ents_in_doc(example.reference).text, mentions=mentions, entity_descriptions=[ [ent.description for ent in ents] for ents in cands_ents diff --git a/spacy_llm/tasks/templates/entity_linker.v1.jinja b/spacy_llm/tasks/templates/entity_linker.v1.jinja index 7958efc7..c7b3647c 100644 --- a/spacy_llm/tasks/templates/entity_linker.v1.jinja +++ b/spacy_llm/tasks/templates/entity_linker.v1.jinja @@ -21,7 +21,7 @@ MENTIONS: {{ example.mention_str }} ENTITIES: {%- for ent_descs in example.entity_descriptions -%} {% set mention_i = loop.index0 %} -- For *{{ example.mentions[loop.index0] }}*: +- For * {{ example.mentions[loop.index0] }} *: {%- for ent_desc in ent_descs -%} {# whitespace #} {{ example.entity_ids[mention_i][loop.index0] }}. {{ ent_desc }} @@ -51,7 +51,7 @@ REASONING: SOLUTION: {%- for solution in example.solutions -%} {# whitespace #} -*{{ example.mentions[loop.index0] }}* ::: <{{ solution }}> +* {{ example.mentions[loop.index0] }} * ::: <{{ solution }}> {%- endfor -%} {# whitespace #} {# whitespace #} @@ -69,7 +69,7 @@ MENTIONS: {{ mentions_str }} ENTITIES: {%- for ent_descs in entity_descriptions -%} {% set mention_i = loop.index0 %} -- For *{{ mentions[loop.index0] }}*: +- For * {{ mentions[loop.index0] }} *: {%- for ent_desc in ent_descs -%} {# whitespace #} {{ entity_ids[mention_i][loop.index0] }}. {{ ent_desc }} diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index dcc189f6..60efeff7 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -352,7 +352,7 @@ def make_doc() -> Doc: nlp.components[0][1]._task._auto_nil = False doc = nlp(make_doc()) assert ( - f"- For *Foo*:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" + f"- For * Foo *:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" in doc.user_data["llm_io"]["llm"]["prompt"].replace("\\", "") ) assert doc.ents[0].kb_id_ == EntityLinker.NIL @@ -678,7 +678,7 @@ def test_ent_highlighting(): ] assert ( - EntityLinkerTask.highlight_ents_in_text(doc) + EntityLinkerTask.highlight_ents_in_doc(doc) == "Alice goes to *Boston* to see the *Boston Celtics* game." ) diff --git a/usage_examples/tests/test_readme_examples.py b/usage_examples/tests/test_readme_examples.py index 72a74d2d..828d1fd9 100644 --- a/usage_examples/tests/test_readme_examples.py +++ b/usage_examples/tests/test_readme_examples.py @@ -200,4 +200,5 @@ def _classify(prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: text_file.write(cfg_str) nlp = assemble(tmpdir / "cfg") - nlp("i'd like a large margherita pizza please") + with pytest.warns(UserWarning, match="Task supports sharding"): + nlp("i'd like a large margherita pizza please") diff --git a/usage_examples/tests/test_usage_examples.py b/usage_examples/tests/test_usage_examples.py index 1b3fa170..9a04e2bc 100644 --- a/usage_examples/tests/test_usage_examples.py +++ b/usage_examples/tests/test_usage_examples.py @@ -118,9 +118,10 @@ def test_ner_v3_openai(): @pytest.mark.filterwarnings("ignore::DeprecationWarning") def test_ner_langchain_openai(): """Test NER LangChain OpenAI usage example.""" - ner_langchain_openai.run_pipeline( - "text", _USAGE_EXAMPLE_PATH / "ner_langchain_openai" / "ner.cfg", False - ) + with pytest.warns(UserWarning, match="Task supports sharding"): + ner_langchain_openai.run_pipeline( + "text", _USAGE_EXAMPLE_PATH / "ner_langchain_openai" / "ner.cfg", False + ) @pytest.mark.external From 4e4a2cdfac96558c11418a1fc8ab86c000e5ac76 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 24 Nov 2023 09:15:17 +0100 Subject: [PATCH 27/51] Fix tokenization and whitespaces for EL task. --- spacy_llm/tasks/entity_linker/task.py | 58 +++++++++++++++---- .../tasks/templates/entity_linker.v1.jinja | 6 +- spacy_llm/tests/tasks/test_entity_linker.py | 11 +++- 3 files changed, 60 insertions(+), 15 deletions(-) diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index 5c4ebf55..76437e64 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type from spacy import Language, Vocab from spacy.pipeline import EntityLinker @@ -158,7 +158,11 @@ def _get_prompt_data( return { "mentions_str": ", ".join( - [mention.text for hc, mention in zip(has_cands, shard.ents) if hc] + [ + f"*{mention.text}*" + for hc, mention in zip(has_cands, shard.ents) + if hc + ] ), "mentions": [ent.text for hc, ent in zip(has_cands, shard.ents) if hc], "entity_descriptions": [ @@ -263,26 +267,39 @@ def highlight_ents_in_doc( i_ent = 0 new_ent_idx: List[Tuple[int, int]] = [] token_texts: List[str] = [] + spaces: List[bool] = [] to_highlight = i_ent in ents_to_highlight_idx offset = 0 + for token in doc: if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][1]: if to_highlight: token_texts.append("*") + spaces.append(spaces[-1]) + spaces[-2] = False offset += 1 i_ent += 1 to_highlight = i_ent in ents_to_highlight_idx if i_ent < len(ents_idx) and token.i == ents_idx[i_ent][0]: if to_highlight: token_texts.append("*") + spaces.append(False) offset += 1 new_ent_idx.append( (ents_idx[i_ent][0] + offset, ents_idx[i_ent][1] + offset) ) token_texts.append(token.text) + spaces.append(token.whitespace_ != "") + + # Cover edge case of doc ending with entity, in which case we need to close the * wrapping. + if len(ents_to_highlight_idx) and doc.ents[ + ents_to_highlight_idx[-1] + ].end == len(doc): + token_texts.append("*") + spaces.append(False) # Create doc with new tokens and entities. - highlighted_doc = Doc(doc.vocab, words=token_texts) + highlighted_doc = Doc(doc.vocab, words=token_texts, spaces=spaces) highlighted_doc.ents = [ Span( doc=highlighted_doc, @@ -304,19 +321,40 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc: doc (Doc): Doc whose entities are to be highlighted. RETURNS (Doc): Doc with highlighted entities. """ - highlight_idx: Set[int] = {ent.start - 1 for ent in doc.ents} | { - ent.end for ent in doc.ents + highlight_start_idx = { + ent.start - 1 + for ent in doc.ents + if ent.start - 1 > 0 and doc[ent.start - 1].text == "*" } - ent_idx = [ - (ent.start - i * 2 - 1, ent.end - i * 2 - 1) - for i, ent in enumerate(doc.ents) - ] + highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"} + highlight_idx = highlight_start_idx | highlight_end_idx + + # Compute entity indices with removed highlights. + ent_idx: List[Tuple[int, int]] = [] + offset = 0 + for ent in doc.ents: + is_highlighted = ent.start - 1 in highlight_start_idx + ent_idx.append( + (ent.start + offset - is_highlighted, ent.end + offset - is_highlighted) + ) + offset -= 2 * is_highlighted # Create doc with new tokens and entities. + tokens = [ + token + for token in doc + if not (token.i in highlight_idx and token.text == "*") + ] unhighlighted_doc = Doc( doc.vocab, - words=[token.text for token in doc if token.i not in highlight_idx], + words=[token.text for token in tokens], + # Use original token space, if token doesn't appear after * highlight. If so, insert space unconditionally. + spaces=[ + token.whitespace_ != "" or token.i + 1 in highlight_idx + for i, token in enumerate(tokens) + ], ) + unhighlighted_doc.ents = [ Span( doc=unhighlighted_doc, diff --git a/spacy_llm/tasks/templates/entity_linker.v1.jinja b/spacy_llm/tasks/templates/entity_linker.v1.jinja index c7b3647c..7958efc7 100644 --- a/spacy_llm/tasks/templates/entity_linker.v1.jinja +++ b/spacy_llm/tasks/templates/entity_linker.v1.jinja @@ -21,7 +21,7 @@ MENTIONS: {{ example.mention_str }} ENTITIES: {%- for ent_descs in example.entity_descriptions -%} {% set mention_i = loop.index0 %} -- For * {{ example.mentions[loop.index0] }} *: +- For *{{ example.mentions[loop.index0] }}*: {%- for ent_desc in ent_descs -%} {# whitespace #} {{ example.entity_ids[mention_i][loop.index0] }}. {{ ent_desc }} @@ -51,7 +51,7 @@ REASONING: SOLUTION: {%- for solution in example.solutions -%} {# whitespace #} -* {{ example.mentions[loop.index0] }} * ::: <{{ solution }}> +*{{ example.mentions[loop.index0] }}* ::: <{{ solution }}> {%- endfor -%} {# whitespace #} {# whitespace #} @@ -69,7 +69,7 @@ MENTIONS: {{ mentions_str }} ENTITIES: {%- for ent_descs in entity_descriptions -%} {% set mention_i = loop.index0 %} -- For * {{ mentions[loop.index0] }} *: +- For *{{ mentions[loop.index0] }}*: {%- for ent_desc in ent_descs -%} {# whitespace #} {{ entity_ids[mention_i][loop.index0] }}. {{ ent_desc }} diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index 60efeff7..fdfc80e4 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -352,7 +352,7 @@ def make_doc() -> Doc: nlp.components[0][1]._task._auto_nil = False doc = nlp(make_doc()) assert ( - f"- For * Foo *:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" + f"- For *Foo*:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" in doc.user_data["llm_io"]["llm"]["prompt"].replace("\\", "") ) assert doc.ents[0].kb_id_ == EntityLinker.NIL @@ -678,9 +678,16 @@ def test_ent_highlighting(): ] assert ( - EntityLinkerTask.highlight_ents_in_doc(doc) + EntityLinkerTask.highlight_ents_in_doc(doc).text == "Alice goes to *Boston* to see the *Boston Celtics* game." ) + assert ( + EntityLinkerTask.unhighlight_ents_in_doc( + EntityLinkerTask.highlight_ents_in_doc(doc) + ).text + == doc.text + == text + ) @pytest.mark.external From 694d5dabf14fe72e9e59bd19fd8ab1fbb5701a7e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 24 Nov 2023 11:56:56 +0100 Subject: [PATCH 28/51] Add new registry handlers (with context length and arbitrary model names) for all REST models. --- spacy_llm/models/langchain/model.py | 16 +- spacy_llm/models/rest/anthropic/registry.py | 282 ++++++++++++++++++-- spacy_llm/models/rest/azure/registry.py | 62 ++++- spacy_llm/models/rest/cohere/registry.py | 41 ++- spacy_llm/models/rest/openai/registry.py | 34 +++ 5 files changed, 414 insertions(+), 21 deletions(-) diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 2ff4643e..8da40361 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -18,9 +18,9 @@ def __init__( api: str, config: Dict[Any, Any], query: Callable[ - ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], - Iterable[Iterable[Any]], + ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]] ], + context_length: Optional[int], ): """Initializes model instance for integration APIs. name (str): Name of LangChain model to instantiate. @@ -28,9 +28,12 @@ def __init__( config (Dict[Any, Any]): Config passed on to LangChain model. query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing LLM prompts when supplied with the model instance. + context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no no context + length provided, prompts can't be sharded. """ self._langchain_model = LangChain._init_langchain_model(name, api, config) self.query = query + self._context_length = context_length self._check_installation() @classmethod @@ -115,6 +118,7 @@ def langchain_model( ] ] = None, config: Dict[Any, Any] = SimpleFrozenDict(), + context_length: Optional[int] = None, langchain_class_id: str = class_id, ) -> Optional[Callable[[Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]]: try: @@ -123,6 +127,7 @@ def langchain_model( api=langchain_class_id, config=config, query=query_langchain() if query is None else query, + context_length=context_length, ) except ImportError as err: raise ValueError( @@ -132,6 +137,13 @@ def langchain_model( return langchain_model + @property + def context_length(self) -> Optional[int]: + """Returns context length in number of tokens for this model. + RETURNS (Optional[int]): Max. number of tokens in allowed in prompt for the current model. None if unknown. + """ + return self._context_length + @staticmethod def register_models() -> None: """Registers APIs supported by langchain (one API is registered as one model). diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py index 4598146e..dc44eb7e 100644 --- a/spacy_llm/models/rest/anthropic/registry.py +++ b/spacy_llm/models/rest/anthropic/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,6 +7,43 @@ from .model import Anthropic, Endpoints +@registry.llm_models("spacy.Claude-2.v2") +def anthropic_claude_2_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-2", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Anthropic: + """Returns Anthropic instance for 'claude-2' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e.g. "claude-2" or "claude-2-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-2' model. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-2.v1") def anthropic_claude_2( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -27,8 +64,7 @@ def anthropic_claude_2( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. """ return Anthropic( name=name, @@ -42,6 +78,43 @@ def anthropic_claude_2( ) +@registry.llm_models("spacy.Claude-1.v2") +def anthropic_claude_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1" or "claude-1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-1.v1") def anthropic_claude_1( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -62,8 +135,7 @@ def anthropic_claude_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-1'. """ return Anthropic( name=name, @@ -77,6 +149,43 @@ def anthropic_claude_1( ) +@registry.llm_models("spacy.Claude-instant-1.v2") +def anthropic_claude_instant_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-instant-1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-instant-1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-instant-1" or "claude-instant-1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-instant-1.v1") def anthropic_claude_instant_1( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -99,8 +208,7 @@ def anthropic_claude_instant_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-instant-1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1'. """ return Anthropic( name=name, @@ -114,6 +222,43 @@ def anthropic_claude_instant_1( ) +@registry.llm_models("spacy.Claude-instant-1-1.v2") +def anthropic_claude_instant_1_1_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-instant-1.1", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-instant-1.1' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-instant-1.1" or "claude-instant-1.1-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1.1'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-instant-1-1.v1") def anthropic_claude_instant_1_1( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -136,8 +281,7 @@ def anthropic_claude_instant_1_1( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-instant-1.1' model using REST to - prompt API. + RETURNS (Anthropic): Anthropic instance for 'claude-instant-1.1' model. """ return Anthropic( name=name, @@ -151,6 +295,43 @@ def anthropic_claude_instant_1_1( ) +@registry.llm_models("spacy.Claude-1-0.v2") +def anthropic_claude_1_0_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.0", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.0' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1.0". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.0'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-1-0.v1") def anthropic_claude_1_0( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -171,8 +352,7 @@ def anthropic_claude_1_0( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.0' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.0' model. """ return Anthropic( name=name, @@ -186,6 +366,43 @@ def anthropic_claude_1_0( ) +@registry.llm_models("spacy.Claude-1-2.v2") +def anthropic_claude_1_2_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.2", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.2' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model to use, e. g. "claude-1.2". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.2'. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-1-2.v1") def anthropic_claude_1_2( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -206,8 +423,7 @@ def anthropic_claude_1_2( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.2' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.2' model. """ return Anthropic( name=name, @@ -221,6 +437,43 @@ def anthropic_claude_1_2( ) +@registry.llm_models("spacy.Claude-1-3.v2") +def anthropic_claude_1_3_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "claude-1.3", + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Anthropic instance for 'claude-1.3' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (str): Name of model variant to use, e. g. "claude-1.3" or "claude-1.3-100k". + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Anthropic): Anthropic instance for 'claude-1.3' model. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Claude-1-3.v1") def anthropic_claude_1_3( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -241,8 +494,7 @@ def anthropic_claude_1_3( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1.3' model using REST to prompt - API. + RETURNS (Anthropic): Anthropic instance for 'claude-1.3' model. """ return Anthropic( name=name, diff --git a/spacy_llm/models/rest/azure/registry.py b/spacy_llm/models/rest/azure/registry.py index fcf6bd72..38df5cb9 100644 --- a/spacy_llm/models/rest/azure/registry.py +++ b/spacy_llm/models/rest/azure/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -8,6 +8,64 @@ _DEFAULT_TEMPERATURE = 0.0 +@registry.llm_models("spacy.Azure.v2") +def azure_openai_v2( + deployment_name: str, + name: str, + base_url: str, + model_type: ModelType, + config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), + strict: bool = AzureOpenAI.DEFAULT_STRICT, + max_tries: int = AzureOpenAI.DEFAULT_MAX_TRIES, + interval: float = AzureOpenAI.DEFAULT_INTERVAL, + max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME, + api_version: str = "2023-05-15", + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Azure OpenAI instance for models deployed on Azure's OpenAI service using REST to prompt API. + + Docs on OpenAI models supported by Azure: + https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#model-summary-table-and-region-availability. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + deployment_name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the + model used by that deployment, as deployment names in Azure OpenAI can be arbitrary. + name (str): Name of the model used by this deployment. This is required to infer the context length that can be + assumed for prompting. + endpoint (str): The URL for your Azure OpenAI endpoint. This is usually something like + "https://{prefix}.openai.azure.com/". + model_type (ModelType): Whether the deployed model is a text completetion model (e. g. + text-davinci-003) or a chat model (e. g. gpt-4). + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + api_version (str): API version to use. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (AzureOpenAI): AzureOpenAI instance for deployed model. + + DOCS: https://spacy.io/api/large-language-models#models + """ + return AzureOpenAI( + deployment_name=deployment_name, + name=name, + endpoint=base_url, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + api_version=api_version, + model_type=model_type, + context_length=context_length, + ) + + @registry.llm_models("spacy.Azure.v1") def azure_openai( deployment_name: str, @@ -44,7 +102,7 @@ def azure_openai( at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. api_version (str): API version to use. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + RETURNS (AzureOpenAI): AzureOpenAI instance for deployed model. DOCS: https://spacy.io/api/large-language-models#models """ diff --git a/spacy_llm/models/rest/cohere/registry.py b/spacy_llm/models/rest/cohere/registry.py index d2bc8997..79c711e1 100644 --- a/spacy_llm/models/rest/cohere/registry.py +++ b/spacy_llm/models/rest/cohere/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,6 +7,43 @@ from .model import Cohere, Endpoints +@registry.llm_models("spacy.Command.v2") +def cohere_command_v2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: str = "command", + strict: bool = Cohere.DEFAULT_STRICT, + max_tries: int = Cohere.DEFAULT_MAX_TRIES, + interval: float = Cohere.DEFAULT_INTERVAL, + max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: + """Returns Cohere instance for 'command' model using REST to prompt API. + name (str): Name of model to use, e. g. "command" or "command-light". + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (Cohere): Cohere instance for 'command' model. + """ + return Cohere( + name=name, + endpoint=Endpoints.COMPLETION.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Command.v1") def cohere_command( config: Dict[Any, Any] = SimpleFrozenDict(), @@ -29,7 +66,7 @@ def cohere_command( interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff at each retry. max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. - RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Cohere instance for 'command' model using REST to prompt API. + RETURNS (Cohere): Cohere instance for 'command' model. """ return Cohere( name=name, diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py index 62802687..772a4579 100644 --- a/spacy_llm/models/rest/openai/registry.py +++ b/spacy_llm/models/rest/openai/registry.py @@ -230,6 +230,40 @@ def openai_gpt_3_5( ) +@registry.llm_models("spacy.Text-Davinci.v3") +def openai_text_davinci_v3( + config: Dict[Any, Any] = SimpleFrozenDict( + max_tokens=1000, temperature=_DEFAULT_TEMPERATURE + ), + name: str = "text-davinci-003", + strict: bool = OpenAI.DEFAULT_STRICT, + max_tries: int = OpenAI.DEFAULT_MAX_TRIES, + interval: float = OpenAI.DEFAULT_INTERVAL, + max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, +) -> OpenAI: + """Returns OpenAI instance for 'text-davinci' model using REST to prompt API. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + name (str): Name of model to use, e. g. "text-davinci-002" or "text-davinci-003". + context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length + natively provided by spacy-llm. + RETURNS (OpenAI): OpenAI instance for 'text-davinci' model + + DOCS: https://spacy.io/api/large-language-models#models + """ + return OpenAI( + name=name, + endpoint=Endpoints.NON_CHAT.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + @registry.llm_models("spacy.Text-Davinci.v2") def openai_text_davinci_v2( config: Dict[Any, Any] = SimpleFrozenDict( From 52954003ac75551ced5fc974cc54457bde732624 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 24 Nov 2023 14:45:49 +0100 Subject: [PATCH 29/51] Add sharding test with simple count task. --- spacy_llm/models/rest/anthropic/__init__.py | 17 +++- spacy_llm/models/rest/azure/__init__.py | 4 +- spacy_llm/models/rest/cohere/__init__.py | 4 +- spacy_llm/models/rest/openai/__init__.py | 12 ++- spacy_llm/models/rest/palm/__init__.py | 4 +- spacy_llm/models/rest/palm/registry.py | 43 +++++++++- spacy_llm/pipeline/llm.py | 4 +- spacy_llm/tests/sharding/__init__.py | 0 spacy_llm/tests/sharding/test_sharding.py | 48 ++++++++++++ spacy_llm/tests/sharding/util.py | 87 +++++++++++++++++++++ spacy_llm/ty.py | 20 ++--- 11 files changed, 216 insertions(+), 27 deletions(-) create mode 100644 spacy_llm/tests/sharding/__init__.py create mode 100644 spacy_llm/tests/sharding/test_sharding.py create mode 100644 spacy_llm/tests/sharding/util.py diff --git a/spacy_llm/models/rest/anthropic/__init__.py b/spacy_llm/models/rest/anthropic/__init__.py index 745a0fbe..ca6c99b8 100644 --- a/spacy_llm/models/rest/anthropic/__init__.py +++ b/spacy_llm/models/rest/anthropic/__init__.py @@ -1,15 +1,26 @@ from .model import Anthropic, Endpoints -from .registry import anthropic_claude_1, anthropic_claude_1_0, anthropic_claude_1_2 -from .registry import anthropic_claude_1_3, anthropic_claude_instant_1 -from .registry import anthropic_claude_instant_1_1 +from .registry import anthropic_claude_1, anthropic_claude_1_0, anthropic_claude_1_0_v2 +from .registry import anthropic_claude_1_2, anthropic_claude_1_2_v2 +from .registry import anthropic_claude_1_3, anthropic_claude_1_3_v2 +from .registry import anthropic_claude_1_v2, anthropic_claude_2, anthropic_claude_2_v2 +from .registry import anthropic_claude_instant_1, anthropic_claude_instant_1_1 +from .registry import anthropic_claude_instant_1_1_v2, anthropic_claude_instant_1_v2 __all__ = [ "Anthropic", "Endpoints", "anthropic_claude_1", + "anthropic_claude_1_v2", "anthropic_claude_1_0", + "anthropic_claude_1_0_v2", "anthropic_claude_1_2", + "anthropic_claude_1_2_v2", "anthropic_claude_1_3", + "anthropic_claude_1_3_v2", "anthropic_claude_instant_1", + "anthropic_claude_instant_1_v2", "anthropic_claude_instant_1_1", + "anthropic_claude_instant_1_1_v2", + "anthropic_claude_2", + "anthropic_claude_2_v2", ] diff --git a/spacy_llm/models/rest/azure/__init__.py b/spacy_llm/models/rest/azure/__init__.py index 142972a5..f59e8679 100644 --- a/spacy_llm/models/rest/azure/__init__.py +++ b/spacy_llm/models/rest/azure/__init__.py @@ -1,4 +1,4 @@ from .model import AzureOpenAI -from .registry import azure_openai +from .registry import azure_openai, azure_openai_v2 -__all__ = ["AzureOpenAI", "azure_openai"] +__all__ = ["AzureOpenAI", "azure_openai", "azure_openai_v2"] diff --git a/spacy_llm/models/rest/cohere/__init__.py b/spacy_llm/models/rest/cohere/__init__.py index f5319ec4..8ce0b194 100644 --- a/spacy_llm/models/rest/cohere/__init__.py +++ b/spacy_llm/models/rest/cohere/__init__.py @@ -1,4 +1,4 @@ from .model import Cohere, Endpoints -from .registry import cohere_command +from .registry import cohere_command, cohere_command_v2 -__all__ = ["Cohere", "Endpoints", "cohere_command"] +__all__ = ["Cohere", "Endpoints", "cohere_command", "cohere_command_v2"] diff --git a/spacy_llm/models/rest/openai/__init__.py b/spacy_llm/models/rest/openai/__init__.py index e1782596..3cde8bef 100644 --- a/spacy_llm/models/rest/openai/__init__.py +++ b/spacy_llm/models/rest/openai/__init__.py @@ -2,10 +2,11 @@ from .registry import openai_ada, openai_ada_v2, openai_babbage, openai_babbage_v2 from .registry import openai_code_davinci, openai_code_davinci_v2, openai_curie from .registry import openai_curie_v2, openai_davinci, openai_davinci_v2 -from .registry import openai_gpt_3_5, openai_gpt_3_5_v2, openai_gpt_4, openai_gpt_4_v2 -from .registry import openai_text_ada, openai_text_ada_v2, openai_text_babbage -from .registry import openai_text_babbage_v2, openai_text_curie, openai_text_curie_v2 -from .registry import openai_text_davinci, openai_text_davinci_v2 +from .registry import openai_gpt_3_5, openai_gpt_3_5_v2, openai_gpt_3_5_v3 +from .registry import openai_gpt_4, openai_gpt_4_v2, openai_gpt_4_v3, openai_text_ada +from .registry import openai_text_ada_v2, openai_text_babbage, openai_text_babbage_v2 +from .registry import openai_text_curie, openai_text_curie_v2, openai_text_davinci +from .registry import openai_text_davinci_v2, openai_text_davinci_v3 __all__ = [ "OpenAI", @@ -22,8 +23,10 @@ "openai_davinci_v2", "openai_gpt_3_5", "openai_gpt_3_5_v2", + "openai_gpt_3_5_v3", "openai_gpt_4", "openai_gpt_4_v2", + "openai_gpt_4_v3", "openai_text_ada", "openai_text_ada_v2", "openai_text_babbage", @@ -32,4 +35,5 @@ "openai_text_curie_v2", "openai_text_davinci", "openai_text_davinci_v2", + "openai_text_davinci_v3", ] diff --git a/spacy_llm/models/rest/palm/__init__.py b/spacy_llm/models/rest/palm/__init__.py index 1255be2f..23fe28ec 100644 --- a/spacy_llm/models/rest/palm/__init__.py +++ b/spacy_llm/models/rest/palm/__init__.py @@ -1,4 +1,4 @@ from .model import Endpoints, PaLM -from .registry import palm_bison +from .registry import palm_bison, palm_bison_v2 -__all__ = ["palm_bison", "PaLM", "Endpoints"] +__all__ = ["palm_bison", "palm_bison_v2", "PaLM", "Endpoints"] diff --git a/spacy_llm/models/rest/palm/registry.py b/spacy_llm/models/rest/palm/registry.py index 7ec8e65c..1e68faed 100644 --- a/spacy_llm/models/rest/palm/registry.py +++ b/spacy_llm/models/rest/palm/registry.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable +from typing import Any, Callable, Dict, Iterable, Optional from confection import SimpleFrozenDict @@ -7,14 +7,15 @@ from .model import Endpoints, PaLM -@registry.llm_models("spacy.PaLM.v1") -def palm_bison( +@registry.llm_models("spacy.PaLM.v2") +def palm_bison_v2( config: Dict[Any, Any] = SimpleFrozenDict(temperature=0), name: Literal["chat-bison-001", "text-bison-001"] = "text-bison-001", # noqa: F821 strict: bool = PaLM.DEFAULT_STRICT, max_tries: int = PaLM.DEFAULT_MAX_TRIES, interval: float = PaLM.DEFAULT_INTERVAL, max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME, + context_length: Optional[int] = None, ) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: """Returns Google instance for PaLM Bison model using REST to prompt API. name (Literal["chat-bison-001", "text-bison-001"]): Model to use. @@ -31,6 +32,42 @@ def palm_bison( natively provided by spacy-llm. RETURNS (PaLM): PaLM instance for Bison model. """ + return PaLM( + name=name, + endpoint=Endpoints.TEXT.value + if name in {"text-bison-001"} + else Endpoints.MSG.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + context_length=context_length, + ) + + +@registry.llm_models("spacy.PaLM.v1") +def palm_bison( + config: Dict[Any, Any] = SimpleFrozenDict(temperature=0), + name: Literal["chat-bison-001", "text-bison-001"] = "text-bison-001", # noqa: F821 + strict: bool = PaLM.DEFAULT_STRICT, + max_tries: int = PaLM.DEFAULT_MAX_TRIES, + interval: float = PaLM.DEFAULT_INTERVAL, + max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME, +) -> PaLM: + """Returns Google instance for PaLM Bison model using REST to prompt API. + name (Literal["chat-bison-001", "text-bison-001"]): Model to use. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + RETURNS (PaLM): PaLM instance for Bison model. + """ return PaLM( name=name, endpoint=Endpoints.TEXT.value diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 756836b6..6acde4a2 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -16,7 +16,7 @@ from .. import registry # noqa: F401 from ..compat import TypedDict -from ..ty import Cache, LabeledTask, LLMTask, ModelWithContextLength +from ..ty import Cache, LabeledTask, ModelWithContextLength, NonshardingLLMTask from ..ty import PromptExecutorType, ScorableTask, Serializable, ShardingLLMTask from ..ty import supports_sharding, validate_type_consistency @@ -36,7 +36,7 @@ DEFAULT_SAVE_IO = False DEFAULT_VALIDATE_TYPES = True -_LLMTask = Union[LLMTask, ShardingLLMTask] +_LLMTask = Union[NonshardingLLMTask, ShardingLLMTask] class CacheConfigType(TypedDict): diff --git a/spacy_llm/tests/sharding/__init__.py b/spacy_llm/tests/sharding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py new file mode 100644 index 00000000..b5f8ba60 --- /dev/null +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -0,0 +1,48 @@ +import pytest +from confection import Config + +from spacy_llm.tests.compat import has_openai_key +from spacy_llm.util import assemble_from_config + +from .util import ShardingCountTask # noqa: F401 + + +@pytest.fixture +def config(): + return Config().from_str( + """ + [nlp] + lang = "en" + pipeline = ["llm"] + + [components] + + [components.llm] + factory = "llm" + + [components.llm.task] + @llm_tasks = "spacy.CountWithSharding.v1" + + [components.llm.model] + @llm_models = "spacy.GPT-3-5.v3" + context_length = 20 + """ + ) + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) +def test_with_count_task(config, model: str): + """Tests whether tasks shard data as expected.""" + config["components"]["llm"]["model"]["@llm_models"] = model + nlp = assemble_from_config(config) + # todo add tests for sharding correctness checks + nlp("This is a first shot.") + + +@pytest.mark.parametrize("model", ("spacy.GPT-3.5.v3",)) +@pytest.mark.parametrize("task", ("spacy.Lemma.v1",)) +def test_with_all_tasks(config, model: str, task: str): + # todo add task-specific sharding tests in task test files? + pass diff --git a/spacy_llm/tests/sharding/util.py b/spacy_llm/tests/sharding/util.py new file mode 100644 index 00000000..23be74bf --- /dev/null +++ b/spacy_llm/tests/sharding/util.py @@ -0,0 +1,87 @@ +import warnings +from typing import Iterable, List, Optional + +from spacy.tokens import Doc +from spacy.training import Example + +from spacy_llm.compat import Self +from spacy_llm.registry import registry +from spacy_llm.tasks import BuiltinTask +from spacy_llm.tasks.util.sharding import make_shard_mapper +from spacy_llm.ty import FewshotExample, ShardReducer + + +def parse_responses( + task: "ShardingCountTask", + shards: Iterable[Iterable[Doc]], + responses: Iterable[Iterable[str]], +) -> Iterable[Iterable[int]]: + for responses_for_doc, shards_for_doc in zip(responses, shards): + results_for_doc: List[int] = [] + for response, shard in zip(responses_for_doc, shards_for_doc): + results_for_doc.append(int(response)) + + yield results_for_doc + + +def reduce_shards_to_doc(task: "ShardingCountExample", shards: Iterable[Doc]) -> Doc: + shards = list(shards) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping unsupported user data", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) + doc.user_data["count"] = sum([shard.user_data["count"] for shard in shards]) + return doc + + +class ShardingCountExample(FewshotExample): + @classmethod + def generate(cls, example: Example, task: "ShardingCountTask") -> Optional[Self]: + return None + + +@registry.llm_tasks("spacy.CountWithSharding.v1") +class ShardingCountTask(BuiltinTask): + _PROMPT_TEMPLATE = "Reply with the number of characters in this string (and nothing else): '{{ text }}'" + + def __init__(self): + assert isinstance(reduce_shards_to_doc, ShardReducer) + super().__init__( + parse_responses=parse_responses, + prompt_example_type=ShardingCountExample, + template=self._PROMPT_TEMPLATE, + prompt_examples=[], + shard_mapper=make_shard_mapper(), + shard_reducer=reduce_shards_to_doc, + ) + + # def generate_prompts( + # self, docs: Iterable[Doc], context_length: Optional[int] = None + # ) -> Iterable[Tuple[Iterable[Any], Iterable[Doc]]]: + # x = super().generate_prompts(docs, context_length) + # return x + + def parse_responses( + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] + ) -> Iterable[Doc]: + shards_teed = self._tee_2d_iterable(shards, 2) + + for shards_for_doc, counts_for_doc in zip( + shards_teed[0], self._parse_responses(self, shards_teed[1], responses) + ): + shards_for_doc = list(shards_for_doc) + for shard, count in zip(shards_for_doc, counts_for_doc): + shard.user_data["count"] = count + + yield self._shard_reducer(self, shards_for_doc) # type: ignore[arg-type] + + @property + def prompt_template(self) -> str: + return self._PROMPT_TEMPLATE + + @property + def _cfg_keys(self) -> List[str]: + return [] diff --git a/spacy_llm/ty.py b/spacy_llm/ty.py index 0889c1c2..5cf630b0 100644 --- a/spacy_llm/ty.py +++ b/spacy_llm/ty.py @@ -138,14 +138,14 @@ def parse_responses( """ -@runtime_checkable -class LLMTask(Protocol): - generate_prompts: Callable[..., Iterable[Any]] - parse_responses: Callable[..., Iterable[Doc]] +# @runtime_checkable +# class LLMTask(Protocol): +# generate_prompts: Callable[..., Iterable[Any]] +# parse_responses: Callable[..., Iterable[Doc]] TaskContraT = TypeVar( - "TaskContraT", bound=Union[ShardingLLMTask, LLMTask], contravariant=True + "TaskContraT", bound=Union[ShardingLLMTask, NonshardingLLMTask], contravariant=True ) ShardingTaskContraT = TypeVar( "ShardingTaskContraT", bound=ShardingLLMTask, contravariant=True @@ -208,7 +208,9 @@ def clear(self) -> None: class Cache(Protocol): """Defines minimal set of operations a cache implementiation needs to support.""" - def initialize(self, vocab: Vocab, task: Union[LLMTask, ShardingLLMTask]) -> None: + def initialize( + self, vocab: Vocab, task: Union[NonshardingLLMTask, ShardingLLMTask] + ) -> None: """ Initialize cache with data not available at construction time. vocab (Vocab): Vocab object. @@ -325,7 +327,7 @@ def _extract_model_call_signature(model: PromptExecutorType) -> Dict[str, Any]: return signature -def supports_sharding(task: Union[LLMTask, ShardingLLMTask]) -> bool: +def supports_sharding(task: Union[NonshardingLLMTask, ShardingLLMTask]) -> bool: """Determines task type, as isinstance(instance, Protocol) only checks for method names. This also considers argument and return types. Raises an exception if task is neither. Note that this is not as thorough as validate_type_consistency() and relies on clues to determine which task type @@ -343,7 +345,7 @@ def supports_sharding(task: Union[LLMTask, ShardingLLMTask]) -> bool: def validate_type_consistency( - task: Union[LLMTask, ShardingLLMTask], model: PromptExecutorType + task: Union[NonshardingLLMTask, ShardingLLMTask], model: PromptExecutorType ) -> None: """Check whether the types of the task and model signatures match. task (ShardingLLMTask): Specified task. @@ -351,7 +353,7 @@ def validate_type_consistency( """ # Raises an error or prints a warning if something looks wrong/odd. # todo update error messages - if not isinstance(task, LLMTask): + if not isinstance(task, NonshardingLLMTask): raise ValueError( f"A task needs to adhere to the interface of either 'LLMTask' or 'ShardingLLMTask', but {type(task)} " f"doesn't." From 70e364373b749d85d8be11a68c9255155a6602fd Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 24 Nov 2023 17:12:56 +0100 Subject: [PATCH 30/51] Fix sharding algorithm. --- spacy_llm/models/rest/base.py | 6 +++++- spacy_llm/tasks/builtin_task.py | 13 +++++++------ spacy_llm/tasks/util/sharding.py | 10 ++++++++++ spacy_llm/tests/sharding/test_sharding.py | 4 +++- 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index 12bdefdd..df089961 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -86,7 +86,11 @@ def context_length(self) -> Optional[int]: """Returns context length in number of tokens for this model. RETURNS (Optional[int]): Max. number of tokens in allowed in prompt for the current model. None if unknown. """ - return self._get_context_lengths().get(self._name, self._context_length) + return ( + self._context_length + if self._context_length + else self._get_context_lengths().get(self._name, None) # type: ignore[arg-type] + ) @property @abc.abstractmethod diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 4b9a7ddb..0d760f45 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -77,22 +77,23 @@ def render_template(shard: Doc, i_shard: int, i_doc: int, n_shards: int) -> str: RETURNS (str): Rendered template. """ return _template.render( - text=doc.text, + text=shard.text, prompt_examples=self._prompt_examples, **self._get_prompt_data(shard, i_shard, i_doc, n_shards), ) - for i_doc, doc in enumerate(self._preprocess_docs_for_prompt(docs)): + for _i_doc, _doc in enumerate(self._preprocess_docs_for_prompt(docs)): # If no context length provided (e. g. because models don't provide it): don't shard. shards = ( - self._shard_mapper(doc, i_doc, context_length, render_template) + self._shard_mapper(_doc, _i_doc, context_length, render_template) if context_length is not None - else [doc] + else [_doc] ) + shards = list(shards) shards_teed = tee(shards, 3) yield [ - render_template(shard, i_shard, i_doc, len(list(shards_teed[0]))) - for i_shard, shard in enumerate(shards_teed[1]) + render_template(_shard, _i_shard, _i_doc, len(list(shards_teed[0]))) + for _i_shard, _shard in enumerate(shards_teed[1]) ], shards_teed[2] def _get_prompt_data( diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index b5f0f31b..72748f70 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -61,6 +61,7 @@ def map_doc_to_shards( fits_in_context = False shard: Optional[Doc] = None end_idx = -1 + n_tries = 0 while fits_in_context is False: end_idx = start_idx + int(len(remaining_doc) * fraction) @@ -75,6 +76,15 @@ def map_doc_to_shards( <= context_length ) fraction /= 2 + n_tries += 1 + + # If prompt is too large even with shard of a single token, raise error - we can't shard any more + # than this. This is an edge case and will most likely never occur. + if len(shard) == 1 and not fits_in_context: + raise ValueError( + "Prompt size doesn't allow for the inclusion for shard of length 1. Please " + "review your prompt and reduce its size." + ) assert shard is not None shards.append(shard) diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index b5f8ba60..9d24ae35 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -38,7 +38,9 @@ def test_with_count_task(config, model: str): config["components"]["llm"]["model"]["@llm_models"] = model nlp = assemble_from_config(config) # todo add tests for sharding correctness checks - nlp("This is a first shot.") + nlp( + "Do one thing every day that scares you. The only thing we have to fear is fear itself." + ) @pytest.mark.parametrize("model", ("spacy.GPT-3.5.v3",)) From 4321483f06ea9662e921903e62a902d8d346a513 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 27 Nov 2023 11:52:43 +0100 Subject: [PATCH 31/51] Add test with simple count task. --- spacy_llm/models/hf/base.py | 1 - spacy_llm/tests/models/test_dolly.py | 3 -- spacy_llm/tests/sharding/test_sharding.py | 36 ++++++++++++++++------- spacy_llm/tests/sharding/util.py | 10 ++----- 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index 6dc7fdc7..e889097d 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -24,7 +24,6 @@ def __init__( name (str): Name of HF model to load (without account name). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. config_run (Optional[Dict[str, Any]]): HF config for running the model. - inference_config (Dict[Any, Any]): HF config for model run. """ self._name = name if self.hf_account in name else f"{self.hf_account}/{name}" default_cfg_init, default_cfg_run = self.compile_default_configs() diff --git a/spacy_llm/tests/models/test_dolly.py b/spacy_llm/tests/models/test_dolly.py index 4b70179d..14f48afe 100644 --- a/spacy_llm/tests/models/test_dolly.py +++ b/spacy_llm/tests/models/test_dolly.py @@ -25,9 +25,6 @@ [components] -[components.llm] -factory = "llm" -save_io = True [components.llm.task] @llm_tasks = "spacy.NoOp.v1" diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 9d24ae35..621bf72c 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -6,11 +6,13 @@ from .util import ShardingCountTask # noqa: F401 +_CONTEXT_LENGTH = 20 + @pytest.fixture def config(): return Config().from_str( - """ + f""" [nlp] lang = "en" pipeline = ["llm"] @@ -19,13 +21,14 @@ def config(): [components.llm] factory = "llm" + save_io = True [components.llm.task] @llm_tasks = "spacy.CountWithSharding.v1" [components.llm.model] @llm_models = "spacy.GPT-3-5.v3" - context_length = 20 + context_length = {_CONTEXT_LENGTH} """ ) @@ -37,14 +40,27 @@ def test_with_count_task(config, model: str): """Tests whether tasks shard data as expected.""" config["components"]["llm"]["model"]["@llm_models"] = model nlp = assemble_from_config(config) - # todo add tests for sharding correctness checks - nlp( + doc = nlp( "Do one thing every day that scares you. The only thing we have to fear is fear itself." ) - -@pytest.mark.parametrize("model", ("spacy.GPT-3.5.v3",)) -@pytest.mark.parametrize("task", ("spacy.Lemma.v1",)) -def test_with_all_tasks(config, model: str, task: str): - # todo add task-specific sharding tests in task test files? - pass + # With a context length of 20 we expect the doc to be split into five prompts. + prompts = [ + pr.replace('"', "").strip() + for pr in doc.user_data["llm_io"]["llm"]["prompt"][1:-1].split('",') + ] + prompt_texts = [pr[65:].replace("'", "").strip() for pr in prompts] + responses = [ + int(r.replace("'", "")) + for r in doc.user_data["llm_io"]["llm"]["response"][1:-1].split("',") + ] + assert prompt_texts == [ + "Do one thing every day", + "that scares you", + ". The only", + "thing we have to", + "fear is fear itself.", + ] + assert all( + [response == len(pr.split()) for response, pr in zip(responses, prompt_texts)] + ) diff --git a/spacy_llm/tests/sharding/util.py b/spacy_llm/tests/sharding/util.py index 23be74bf..87b21f37 100644 --- a/spacy_llm/tests/sharding/util.py +++ b/spacy_llm/tests/sharding/util.py @@ -45,7 +45,9 @@ def generate(cls, example: Example, task: "ShardingCountTask") -> Optional[Self] @registry.llm_tasks("spacy.CountWithSharding.v1") class ShardingCountTask(BuiltinTask): - _PROMPT_TEMPLATE = "Reply with the number of characters in this string (and nothing else): '{{ text }}'" + _PROMPT_TEMPLATE = ( + "Reply with the number of words in this string (and nothing else): '{{ text }}'" + ) def __init__(self): assert isinstance(reduce_shards_to_doc, ShardReducer) @@ -58,12 +60,6 @@ def __init__(self): shard_reducer=reduce_shards_to_doc, ) - # def generate_prompts( - # self, docs: Iterable[Doc], context_length: Optional[int] = None - # ) -> Iterable[Tuple[Iterable[Any], Iterable[Doc]]]: - # x = super().generate_prompts(docs, context_length) - # return x - def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: From ef6e738d8cb3edd5521eaf80b673fdf8944b629f Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 27 Nov 2023 12:21:43 +0100 Subject: [PATCH 32/51] Add context length as init arg in HF models. --- spacy_llm/models/hf/base.py | 5 ++++- spacy_llm/models/hf/dolly.py | 8 +++----- spacy_llm/models/hf/falcon.py | 16 ++++++++++------ spacy_llm/models/hf/llama2.py | 16 ++++++++++------ spacy_llm/models/hf/mistral.py | 16 ++++++++++------ spacy_llm/models/hf/openllama.py | 16 ++++++++++------ spacy_llm/models/hf/stablelm.py | 16 ++++++++-------- 7 files changed, 55 insertions(+), 38 deletions(-) diff --git a/spacy_llm/models/hf/base.py b/spacy_llm/models/hf/base.py index e889097d..8f061f5d 100644 --- a/spacy_llm/models/hf/base.py +++ b/spacy_llm/models/hf/base.py @@ -17,6 +17,7 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): """Initializes HF model instance. query (Callable[[Any, Iterable[Any]], Iterable[Any]): Callable executing LLM prompts when @@ -24,8 +25,10 @@ def __init__( name (str): Name of HF model to load (without account name). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. config_run (Optional[Dict[str, Any]]): HF config for running the model. + context_length (int): Context length for this model. Necessary for sharding. """ self._name = name if self.hf_account in name else f"{self.hf_account}/{name}" + self._context_length = context_length default_cfg_init, default_cfg_run = self.compile_default_configs() self._config_init, self._config_run = default_cfg_init, default_cfg_run @@ -93,11 +96,11 @@ def get_model_names(cls) -> Tuple[str, ...]: return tuple(str(arg) for arg in cls.MODEL_NAMES.__args__) # type: ignore[attr-defined] @property - @abc.abstractmethod def context_length(self) -> int: """Returns context length in number of tokens for this model. RETURNS (int): Max. number of tokens in allowed in prompt for the current model. """ + return self._context_length @property @abc.abstractmethod diff --git a/spacy_llm/models/hf/dolly.py b/spacy_llm/models/hf/dolly.py index fc15dc67..95b2bc9a 100644 --- a/spacy_llm/models/hf/dolly.py +++ b/spacy_llm/models/hf/dolly.py @@ -50,10 +50,6 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run, ) - @property - def context_length(self) -> int: - return 2048 - @registry.llm_models("spacy.Dolly.v1") def dolly_hf( @@ -68,4 +64,6 @@ def dolly_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Dolly instance that can execute a set of prompts and return the raw responses. """ - return Dolly(name=name, config_init=config_init, config_run=config_run) + return Dolly( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/falcon.py b/spacy_llm/models/hf/falcon.py index 8d4f17b1..b8d299db 100644 --- a/spacy_llm/models/hf/falcon.py +++ b/spacy_llm/models/hf/falcon.py @@ -17,9 +17,15 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) self._config_run["pad_token_id"] = self._tokenizer.pad_token_id @@ -67,10 +73,6 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_run, ) - @property - def context_length(self) -> int: - return 2048 - @registry.llm_models("spacy.Falcon.v1") def falcon_hf( @@ -85,4 +87,6 @@ def falcon_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return the raw responses. """ - return Falcon(name=name, config_init=config_init, config_run=config_run) + return Falcon( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/llama2.py b/spacy_llm/models/hf/llama2.py index ab5e1063..c76bb7ac 100644 --- a/spacy_llm/models/hf/llama2.py +++ b/spacy_llm/models/hf/llama2.py @@ -17,8 +17,14 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) # Instantiate GenerationConfig object from config dict. self._hf_config_run = transformers.GenerationConfig.from_pretrained( self._name, @@ -54,10 +60,6 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: return HuggingFace.compile_default_configs() - @property - def context_length(self) -> int: - return 4096 - @registry.llm_models("spacy.Llama2.v1") def llama2_hf( @@ -72,4 +74,6 @@ def llama2_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Llama2 instance that can execute a set of prompts and return the raw responses. """ - return Llama2(name=name, config_init=config_init, config_run=config_run) + return Llama2( + name=name, config_init=config_init, config_run=config_run, context_length=4096 + ) diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index 918ad6ce..7da0df1a 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -15,10 +15,16 @@ def __init__( name: MODEL_NAMES, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None self._is_instruct = "instruct" in name - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) @@ -82,10 +88,6 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: return responses - @property - def context_length(self) -> int: - return 8000 - @registry.llm_models("spacy.Mistral.v1") def mistral_hf( @@ -100,4 +102,6 @@ def mistral_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return the raw responses. """ - return Mistral(name=name, config_init=config_init, config_run=config_run) + return Mistral( + name=name, config_init=config_init, config_run=config_run, context_length=8000 + ) diff --git a/spacy_llm/models/hf/openllama.py b/spacy_llm/models/hf/openllama.py index cc166741..391e4361 100644 --- a/spacy_llm/models/hf/openllama.py +++ b/spacy_llm/models/hf/openllama.py @@ -20,9 +20,15 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) def init_model(self) -> "transformers.AutoModelForCausalLM": """Sets up HF model and needed utilities. @@ -85,10 +91,6 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: {**default_cfg_run, "max_new_tokens": 32}, ) - @property - def context_length(self) -> int: - return 2048 - @registry.llm_models("spacy.OpenLLaMA.v1") def openllama_hf( @@ -103,4 +105,6 @@ def openllama_hf( RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenLLaMA instance that can execute a set of prompts and return the raw responses. """ - return OpenLLaMA(name=name, config_init=config_init, config_run=config_run) + return OpenLLaMA( + name=name, config_init=config_init, config_run=config_run, context_length=2048 + ) diff --git a/spacy_llm/models/hf/stablelm.py b/spacy_llm/models/hf/stablelm.py index 14eec8bb..e5b7c36e 100644 --- a/spacy_llm/models/hf/stablelm.py +++ b/spacy_llm/models/hf/stablelm.py @@ -39,10 +39,16 @@ def __init__( name: str, config_init: Optional[Dict[str, Any]], config_run: Optional[Dict[str, Any]], + context_length: int, ): self._tokenizer: Optional["transformers.AutoTokenizer"] = None self._is_tuned = "tuned" in name - super().__init__(name=name, config_init=config_init, config_run=config_run) + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) def init_model(self) -> "transformers.AutoModelForCausalLM": """Sets up HF model and needed utilities. @@ -115,10 +121,6 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: }, ) - @property - def context_length(self) -> int: - return 4096 - @registry.llm_models("spacy.StableLM.v1") def stablelm_hf( @@ -138,7 +140,5 @@ def stablelm_hf( f"Expected one of {StableLM.get_model_names()}, but received {name}." ) return StableLM( - name=name, - config_init=config_init, - config_run=config_run, + name=name, config_init=config_init, config_run=config_run, context_length=4096 ) From e3ff37dd643c879fe4a63171f15f595f66263ed0 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 28 Nov 2023 17:25:57 +0100 Subject: [PATCH 33/51] Fix tests. Don't stringify IO lists if sharded. --- spacy_llm/pipeline/llm.py | 15 ++++-- spacy_llm/tasks/lemma/task.py | 1 + spacy_llm/tests/pipeline/test_llm.py | 4 +- spacy_llm/tests/sharding/test_sharding.py | 58 ++++++++++++++++----- spacy_llm/tests/tasks/test_entity_linker.py | 4 +- 5 files changed, 61 insertions(+), 21 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 6acde4a2..e27b1480 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -278,10 +278,17 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: ) llm_io = doc.user_data["llm_io"][self._name] next_prompt = next(prompts_iters[-1]) - llm_io["prompt"] = str( - next_prompt[0] if has_shards else next_prompt - ) - llm_io["response"] = str(next(responses_iters[-1])) + if has_shards: + llm_io["prompt"] = [ + str(shard_prompt) for shard_prompt in next_prompt[0] + ] + llm_io["response"] = [ + str(shard_response) + for shard_response in next(responses_iters[-1]) + ] + else: + llm_io["prompt"] = str(next_prompt) + llm_io["response"] = str(next(responses_iters[-1])) self._cache.add(doc) final_docs.append(doc) diff --git a/spacy_llm/tasks/lemma/task.py b/spacy_llm/tasks/lemma/task.py index c24d82c1..add263d2 100644 --- a/spacy_llm/tasks/lemma/task.py +++ b/spacy_llm/tasks/lemma/task.py @@ -58,6 +58,7 @@ def parse_responses( # match. if len(tokens) != len(lemmas): updated_shards_for_doc.append(shard) + continue # Assign lemmas. for token, lemma_info in zip(tokens, lemmas): diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 5c57def0..c8c3890c 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -73,8 +73,8 @@ def test_llm_pipe(noop_config: Dict[str, Any], n_process: int, shard: bool): for doc in docs: llm_io = doc.user_data["llm_io"] - assert llm_io["llm"]["prompt"] == str([_NOOP_PROMPT] if shard else _NOOP_PROMPT) - assert llm_io["llm"]["response"] == str([_NOOP_RESPONSE]) + assert llm_io["llm"]["prompt"] == [_NOOP_PROMPT] if shard else _NOOP_PROMPT + assert llm_io["llm"]["response"] == [_NOOP_RESPONSE] @pytest.mark.parametrize("n_process", [1, 2]) diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 621bf72c..24be10cc 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -36,7 +36,7 @@ def config(): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") @pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) -def test_with_count_task(config, model: str): +def test_sharding_count(config, model: str): """Tests whether tasks shard data as expected.""" config["components"]["llm"]["model"]["@llm_models"] = model nlp = assemble_from_config(config) @@ -45,22 +45,54 @@ def test_with_count_task(config, model: str): ) # With a context length of 20 we expect the doc to be split into five prompts. + marker = "(and nothing else): '" prompts = [ - pr.replace('"', "").strip() - for pr in doc.user_data["llm_io"]["llm"]["prompt"][1:-1].split('",') + pr[pr.index(marker) + len(marker) : -1] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] ] - prompt_texts = [pr[65:].replace("'", "").strip() for pr in prompts] - responses = [ - int(r.replace("'", "")) - for r in doc.user_data["llm_io"]["llm"]["response"][1:-1].split("',") - ] - assert prompt_texts == [ - "Do one thing every day", + responses = [int(r) for r in doc.user_data["llm_io"]["llm"]["response"]] + assert prompts == [ + "Do one thing every day ", "that scares you", - ". The only", - "thing we have to", + ". The only ", + "thing we have to ", "fear is fear itself.", ] assert all( - [response == len(pr.split()) for response, pr in zip(responses, prompt_texts)] + [response == len(pr.split()) for response, pr in zip(responses, prompts)] ) + assert sum(responses) == doc.user_data["count"] + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) +def test_sharding_lemma(config, model: str): + """Tests whether tasks shard data as expected.""" + context_length = 120 + config["components"]["llm"]["model"]["@llm_models"] = model + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Lemma.v1"} + + text = ( + "Do one thing every day that scares you. The only thing we have to fear is fear itself. Do one thing every " + "day that scares you. The only thing we have to fear is fear itself. " + ) + nlp = assemble_from_config(config) + doc = nlp(text) + + # With a context length of 120 we expect the doc to be split into four prompts. + marker = "to be lemmatized:\n'''\n" + prompts = [ + pr[pr.index(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + # Make sure lemmas are set (somme might not be because the LLM didn't return parsable a response). + assert any([t.lemma != 0 for t in doc]) + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is ", + "fear itself. Do one thing every day that scares you", + ". The only thing we have to fear is fear itself. ", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 4 diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index fdfc80e4..ac74e02d 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -352,8 +352,8 @@ def make_doc() -> Doc: nlp.components[0][1]._task._auto_nil = False doc = nlp(make_doc()) assert ( - f"- For *Foo*:n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" - in doc.user_data["llm_io"]["llm"]["prompt"].replace("\\", "") + f"- For *Foo*:\n {EntityLinker.NIL}. {UNAVAILABLE_ENTITY_DESC}" + in doc.user_data["llm_io"]["llm"]["prompt"][0] ) assert doc.ents[0].kb_id_ == EntityLinker.NIL # Sometimes GPT-3.5 doesn't manage to include the NIL prediction, in which case all entities are set to NIL. From 056730a63f847ce9c848af1c8946d643e72647f4 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 10:33:21 +0100 Subject: [PATCH 34/51] Fix tests. --- spacy_llm/pipeline/llm.py | 5 ++++- spacy_llm/tests/pipeline/test_llm.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index e27b1480..c051e889 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -288,7 +288,10 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: ] else: llm_io["prompt"] = str(next_prompt) - llm_io["response"] = str(next(responses_iters[-1])) + # Models always return nested responses. For non-sharding tasks this will always be a 1-list. + x = next(responses_iters[-1])[0] + llm_io["response"] = str(x) + x = 3 self._cache.add(doc) final_docs.append(doc) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index c8c3890c..0742dde9 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -73,8 +73,10 @@ def test_llm_pipe(noop_config: Dict[str, Any], n_process: int, shard: bool): for doc in docs: llm_io = doc.user_data["llm_io"] - assert llm_io["llm"]["prompt"] == [_NOOP_PROMPT] if shard else _NOOP_PROMPT - assert llm_io["llm"]["response"] == [_NOOP_RESPONSE] + assert llm_io["llm"]["prompt"] == ([_NOOP_PROMPT] if shard else _NOOP_PROMPT) + assert llm_io["llm"]["response"] == ( + [_NOOP_RESPONSE] if shard else _NOOP_RESPONSE + ) @pytest.mark.parametrize("n_process", [1, 2]) From 196c235d4c4490532b8783f2b19bc75a8b184559 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 10:53:59 +0100 Subject: [PATCH 35/51] Add NER sharding test. --- spacy_llm/pipeline/llm.py | 4 +- spacy_llm/tests/sharding/test_sharding.py | 52 +++++++++++++++-------- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index c051e889..85490f4e 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -289,9 +289,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: else: llm_io["prompt"] = str(next_prompt) # Models always return nested responses. For non-sharding tasks this will always be a 1-list. - x = next(responses_iters[-1])[0] - llm_io["response"] = str(x) - x = 3 + llm_io["response"] = str(next(responses_iters[-1])[0]) self._cache.add(doc) final_docs.append(doc) diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 24be10cc..2fdf23c5 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -7,6 +7,7 @@ from .util import ShardingCountTask # noqa: F401 _CONTEXT_LENGTH = 20 +_TEXT = "Do one thing every day that scares you. The only thing we have to fear is fear itself." @pytest.fixture @@ -37,14 +38,11 @@ def config(): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") @pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) def test_sharding_count(config, model: str): - """Tests whether tasks shard data as expected.""" + """Tests whether task shards data as expected.""" config["components"]["llm"]["model"]["@llm_models"] = model nlp = assemble_from_config(config) - doc = nlp( - "Do one thing every day that scares you. The only thing we have to fear is fear itself." - ) - # With a context length of 20 we expect the doc to be split into five prompts. + doc = nlp(_TEXT) marker = "(and nothing else): '" prompts = [ pr[pr.index(marker) + len(marker) : -1] @@ -68,20 +66,14 @@ def test_sharding_count(config, model: str): @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") @pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) def test_sharding_lemma(config, model: str): - """Tests whether tasks shard data as expected.""" + """Tests whether task shards data as expected.""" context_length = 120 config["components"]["llm"]["model"]["@llm_models"] = model config["components"]["llm"]["model"]["context_length"] = context_length config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Lemma.v1"} - - text = ( - "Do one thing every day that scares you. The only thing we have to fear is fear itself. Do one thing every " - "day that scares you. The only thing we have to fear is fear itself. " - ) nlp = assemble_from_config(config) - doc = nlp(text) - # With a context length of 120 we expect the doc to be split into four prompts. + doc = nlp(_TEXT) marker = "to be lemmatized:\n'''\n" prompts = [ pr[pr.index(marker) + len(marker) : -4] @@ -91,8 +83,34 @@ def test_sharding_lemma(config, model: str): assert any([t.lemma != 0 for t in doc]) assert prompts == [ "Do one thing every day that scares you. The ", - "only thing we have to fear is ", - "fear itself. Do one thing every day that scares you", - ". The only thing we have to fear is fear itself. ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) +def test_sharding_ner(config, model: str): + """Tests whether task shards data as expected.""" + context_length = 265 + config["components"]["llm"]["model"]["@llm_models"] = model + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.NER.v3", + "labels": ["LOCATION"], + } + nlp = assemble_from_config(config) + + doc = nlp(_TEXT + " Paris is a city.") + marker = "Paragraph: " + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) + assert prompts == [ + "Do one thing every day that scares you. The only thing ", + "we have to fear is fear itself. Paris is a city.", ] - assert len(doc.user_data["llm_io"]["llm"]["response"]) == 4 + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 From 1f51a4add926a1c1d3f08e0122e71b2e520631f2 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 14:26:28 +0100 Subject: [PATCH 36/51] Add REL and sentiment sharding tests. --- spacy_llm/tasks/rel/task.py | 76 ++++++++++++++--------- spacy_llm/tests/sharding/test_sharding.py | 65 ++++++++++++++++--- 2 files changed, 103 insertions(+), 38 deletions(-) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index b86e5ad9..a6683c5c 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union from spacy.language import Language from spacy.tokens import Doc, Span @@ -60,24 +60,7 @@ def __init__( self._field = "rel" def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: - preprocessed_docs: List[Doc] = [] - - for doc in docs: - preprocessed_docs.append( - Doc(doc.vocab, words=RELTask._preannotate(doc).split()) - ) - preprocessed_docs[-1].ents = [ - Span( - preprocessed_docs[-1], - ent.start, - ent.end, - label=ent.label_, - kb_id=ent.kb_id_, - ) - for ent in doc.ents - ] - - return preprocessed_docs + return [RELTask._preannotate(doc, True) for doc in docs] def _get_prompt_data( self, shard: Doc, i_shard: int, i_doc: int, n_shards: int @@ -89,24 +72,61 @@ def _get_prompt_data( } @staticmethod - def _preannotate(doc: Union[Doc, FewshotExample]) -> str: - """Creates a text version of the document with annotated entities.""" - offset = 0 - text = doc.text + def _preannotate( + doc: Union[Doc, FewshotExample], return_as_doc: bool = False + ) -> Union[str, Doc]: + """Creates a text version of the document with annotated entities. + doc (Union[Doc, FewshotExample]): Doc to preannotate. + return_as_doc (bool): Whether to return as doc (by default returned as text). + """ + words: List[str] = [] if len(doc.ents) else [t.text for t in doc] + spaces: List[bool] = [] if len(doc.ents) else [t.whitespace_ != "" for t in doc] + ent_indices: List[Tuple[int, int]] = [] if not hasattr(doc, "ents"): raise ValueError( "Prompt example type used in RELTask has to expose entities via an .ents attribute." ) + # Update token information for doc reconstruction. + last_ent_end = -1 for i, ent in enumerate(doc.ents): - end = ent.end_char - before, after = text[: end + offset], text[end + offset :] annotation = f"[ENT{i}:{ent.label_ if isinstance(doc, Doc) else ent.label}]" - offset += len(annotation) - text = f"{before}{annotation}{after}" + tokens_since_last_ent = [ + *[t for t in doc if last_ent_end <= t.i < ent.start], + *[t for t in ent], + ] + words.extend([*[t.text for t in tokens_since_last_ent], annotation]) + spaces.extend([t.whitespace_ != "" for t in tokens_since_last_ent]) + + # Adjust spaces w.r.t. added annotations, which should appear directly after entity. + spaces.append(spaces[-1]) + spaces[-2] = False + ent_indices.append((ent.start + i, ent.end + i)) + last_ent_end = ent.end + + # Include chars after last ent. + if len(doc.ents): + tokens_since_last_ent = [t for t in doc if last_ent_end <= t.i] + words.extend([t.text for t in tokens_since_last_ent]) + spaces.extend([t.whitespace_ != "" for t in tokens_since_last_ent]) + + # Reconstruct doc. + annotated_doc = Doc(words=words, spaces=spaces, vocab=doc.vocab) + annotated_doc.ents = [ + Span( # noqa: E731 + doc=annotated_doc, + start=ent_idx[0], + end=ent_idx[1], + label=doc.ents[i].label, + vector=doc.ents[i].vector, + vector_norm=doc.ents[i].vector_norm, + kb_id=doc.ents[i].kb_id_, + ) + for i, ent_idx in enumerate(ent_indices) + ] - return text + return annotated_doc.text if not return_as_doc else annotated_doc def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 2fdf23c5..0f964145 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -1,3 +1,5 @@ +import numbers + import pytest from confection import Config @@ -36,10 +38,8 @@ def config(): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) def test_sharding_count(config, model: str): """Tests whether task shards data as expected.""" - config["components"]["llm"]["model"]["@llm_models"] = model nlp = assemble_from_config(config) doc = nlp(_TEXT) @@ -64,11 +64,8 @@ def test_sharding_count(config, model: str): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) -def test_sharding_lemma(config, model: str): - """Tests whether task shards data as expected.""" +def test_sharding_lemma(config): context_length = 120 - config["components"]["llm"]["model"]["@llm_models"] = model config["components"]["llm"]["model"]["context_length"] = context_length config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Lemma.v1"} nlp = assemble_from_config(config) @@ -90,11 +87,8 @@ def test_sharding_lemma(config, model: str): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("model", ("spacy.GPT-3-5.v3",)) -def test_sharding_ner(config, model: str): - """Tests whether task shards data as expected.""" +def test_sharding_ner(config): context_length = 265 - config["components"]["llm"]["model"]["@llm_models"] = model config["components"]["llm"]["model"]["context_length"] = context_length config["components"]["llm"]["task"] = { "@llm_tasks": "spacy.NER.v3", @@ -114,3 +108,54 @@ def test_sharding_ner(config, model: str): "we have to fear is fear itself. Paris is a city.", ] assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_rel(config): + context_length = 100 + config["nlp"]["pipeline"] = ["ner", "llm"] + config["components"]["ner"] = {"source": "en_core_web_md"} + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.REL.v1", + "labels": "LivesIn,Visits", + } + config["initialize"] = {"vectors": "en_core_web_md"} + nlp = assemble_from_config(config) + + doc = nlp("Joey rents a place in New York City, which is in North America.") + marker = "Text:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) + assert hasattr(doc._, "rel") and len(doc._.rel) + assert prompts == [ + "Joey[ENT0:PERSON] rents a place in New York City", + "[ENT1:GPE], which is in North America[ENT2:LOC].", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_sentiment(config): + context_length = 50 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Sentiment.v1"} + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "Text:\n'''\n" + prompts = [ + pr[pr.index(marker) + len(marker) : pr.rindex("\n'''\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert hasattr(doc._, "sentiment") and isinstance(doc._.sentiment, numbers.Number) + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 From e18b30245053d39d54a5ff7dc165a1f73d03001d Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 14:40:20 +0100 Subject: [PATCH 37/51] Add summary sharding tests. --- spacy_llm/tests/sharding/test_sharding.py | 50 ++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 0f964145..54f5dc07 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -38,7 +38,7 @@ def config(): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -def test_sharding_count(config, model: str): +def test_sharding_count(config): """Tests whether task shards data as expected.""" nlp = assemble_from_config(config) @@ -159,3 +159,51 @@ def test_sharding_sentiment(config): "only thing we have to fear is fear itself.", ] assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_spancat(config): + context_length = 265 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.SpanCat.v3", + "labels": ["LOCATION"], + } + nlp = assemble_from_config(config) + + doc = nlp(_TEXT + " Paris is a city.") + marker = "Paragraph: " + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\nAnswer:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.spans.data["sc"]) + assert prompts == [ + "Do one thing every day that ", + "scares you. The only thing we have to ", + "fear is fear itself. Paris is a city.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 3 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_summary(config): + context_length = 50 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.Summarization.v1"} + nlp = assemble_from_config(config) + + doc = nlp(_TEXT) + marker = "needs to be summarized:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\n'''\nSummary:")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert hasattr(doc._, "summary") and doc._.summary + assert prompts == [ + "Do one thing every day that scares you. The ", + "only thing we have to fear is fear itself.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 From 7c092ca1ee871707fbab808ecea5a1d6fe223686 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 16:12:55 +0100 Subject: [PATCH 38/51] Add EL sharding task. Fix bug in shard mapper. --- spacy_llm/tasks/builtin_task.py | 7 +-- spacy_llm/tasks/entity_linker/task.py | 13 +++- spacy_llm/tasks/util/sharding.py | 8 +-- spacy_llm/tests/sharding/test_sharding.py | 75 +++++++++++++++++++++++ 4 files changed, 92 insertions(+), 11 deletions(-) diff --git a/spacy_llm/tasks/builtin_task.py b/spacy_llm/tasks/builtin_task.py index 0d760f45..82a182dd 100644 --- a/spacy_llm/tasks/builtin_task.py +++ b/spacy_llm/tasks/builtin_task.py @@ -90,11 +90,10 @@ def render_template(shard: Doc, i_shard: int, i_doc: int, n_shards: int) -> str: else [_doc] ) shards = list(shards) - shards_teed = tee(shards, 3) yield [ - render_template(_shard, _i_shard, _i_doc, len(list(shards_teed[0]))) - for _i_shard, _shard in enumerate(shards_teed[1]) - ], shards_teed[2] + render_template(_shard, _i_shard, _i_doc, len(shards)) + for _i_shard, _shard in enumerate(shards) + ], shards def _get_prompt_data( self, shard: Doc, i_shard: int, i_doc: int, n_shards: int diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index 76437e64..86426ed0 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -53,6 +53,7 @@ def __init__( self._ents_cands_by_doc: List[List[List[Entity]]] = [] self._has_ent_cands_by_shard: List[List[List[bool]]] = [] self._ents_cands_by_shard: List[List[List[List[Entity]]]] = [] + self._n_shards: Optional[int] = None def initialize( self, @@ -103,6 +104,7 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: # update it here, as we don't know yet how the shards will look like. self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._n_shards = None return [ EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i]) @@ -136,6 +138,13 @@ def _find_entity_candidates( def _get_prompt_data( self, shard: Doc, i_shard: int, i_doc: int, n_shards: int ) -> Dict[str, Any]: + # n_shards changes before reset happens in _preprocess_docs() whenever sharding mechanism varies number of + # shards. In this case we have to reset task state as well. + if n_shards != self._n_shards: + self._n_shards = n_shards + self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] + # It's not ideal that we have to run candidate selection again here - but due to (1) us wanting to know whether # all entities have candidates before sharding and, more importantly, (2) some entities maybe being split up in # the sharding process it's cleaner to look for candidates again. @@ -151,8 +160,8 @@ def _get_prompt_data( # Update shard-wise candidate info so it can be reused during parsing. if len(self._ents_cands_by_shard[i_doc]) == 0: - self._ents_cands_by_shard[i_doc] = [[] * n_shards] - self._has_ent_cands_by_shard[i_doc] = [[] * n_shards] + self._ents_cands_by_shard[i_doc] = [[] for _ in range(n_shards)] + self._has_ent_cands_by_shard[i_doc] = [[] for _ in range(n_shards)] self._ents_cands_by_shard[i_doc][i_shard] = ents_cands self._has_ent_cands_by_shard[i_doc][i_shard] = has_cands diff --git a/spacy_llm/tasks/util/sharding.py b/spacy_llm/tasks/util/sharding.py index 72748f70..4b9c9a9a 100644 --- a/spacy_llm/tasks/util/sharding.py +++ b/spacy_llm/tasks/util/sharding.py @@ -56,6 +56,7 @@ def map_doc_to_shards( remaining_doc: Optional[Doc] = doc.copy() fraction = 0.5 start_idx = 0 + n_shards = 1 while remaining_doc is not None: fits_in_context = False @@ -67,11 +68,7 @@ def map_doc_to_shards( end_idx = start_idx + int(len(remaining_doc) * fraction) shard = doc[start_idx:end_idx].as_doc(copy_user_data=True) fits_in_context = ( - n_tok_est( - render_template( - shard, len(shards), i_doc, int(1 / fraction) - ) - ) + n_tok_est(render_template(shard, len(shards), i_doc, n_shards)) * buffer_frac <= context_length ) @@ -89,6 +86,7 @@ def map_doc_to_shards( assert shard is not None shards.append(shard) fraction = 1 + n_shards = max(len(shards) + round(1 / fraction), 1) start_idx = end_idx # Set remaining_doc to None if shard contains all of it, i. e. entire original doc has been processed. remaining_doc = ( diff --git a/spacy_llm/tests/sharding/test_sharding.py b/spacy_llm/tests/sharding/test_sharding.py index 54f5dc07..19fc17c4 100644 --- a/spacy_llm/tests/sharding/test_sharding.py +++ b/spacy_llm/tests/sharding/test_sharding.py @@ -1,7 +1,10 @@ import numbers +from pathlib import Path import pytest from confection import Config +from spacy.pipeline import EntityLinker +from spacy.tokens import Span from spacy_llm.tests.compat import has_openai_key from spacy_llm.util import assemble_from_config @@ -207,3 +210,75 @@ def test_sharding_summary(config): "only thing we have to fear is fear itself.", ] assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_textcat(config): + context_length = 100 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = { + "@llm_tasks": "spacy.TextCat.v3", + "labels": "RECIPE", + "exclusive_classes": True, + } + nlp = assemble_from_config(config) + + doc = nlp( + "Fry an egg in a pan. Scramble it. Add some salt, pepper and truffle oil." + ) + marker = "Text:\n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : -4] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.cats) == 1 and "RECIPE" in doc.cats + assert prompts == [ + "Fry an egg in ", + "a pan. Scramble it. Add ", + "some salt, pepper and truffle oil.", + ] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 3 + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sharding_entity_linker(config): + context_length = 290 + config["components"]["llm"]["model"]["context_length"] = context_length + config["components"]["llm"]["task"] = {"@llm_tasks": "spacy.EntityLinker.v1"} + config["initialize"] = { + "components": { + "llm": { + "candidate_selector": { + "@llm_misc": "spacy.CandidateSelector.v1", + "kb_loader": { + "@llm_misc": "spacy.KBFileLoader.v1", + "path": "${paths.el_kb}", + }, + } + } + } + } + config["paths"] = { + "el_kb": str( + Path(__file__).resolve().parent.parent / "tasks" / "misc" / "el_kb_data.yml" + ) + } + nlp = assemble_from_config(config) + + doc = nlp.make_doc("Alice goes to Boston to see the Boston Celtics game.") + doc.ents = [ + Span(doc=doc, start=3, end=4, label="LOC"), # Q100 + Span(doc=doc, start=7, end=9, label="ORG"), # Q131371 + ] + doc = nlp(doc) + marker = "TEXT: \n'''\n" + prompts = [ + pr[pr.rindex(marker) + len(marker) : pr.rindex("\n'''")] + for pr in doc.user_data["llm_io"]["llm"]["prompt"] + ] + assert len(doc.ents) == 2 + assert all([ent.kb_id_ != EntityLinker.NIL for ent in doc.ents]) + assert prompts == ["Alice goes to *Boston* to ", "see the *Boston Celtics* game."] + assert len(doc.user_data["llm_io"]["llm"]["response"]) == 2 From 358ba7217b78687d49763159af62e64fc1a11c4e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 17:51:44 +0100 Subject: [PATCH 39/51] Fix REL error with RELExample parsing. --- spacy_llm/tasks/rel/task.py | 7 +++++ spacy_llm/tasks/rel/util.py | 47 +++++++++++++++++++++++++++++-- spacy_llm/tests/tasks/test_rel.py | 2 +- 3 files changed, 53 insertions(+), 3 deletions(-) diff --git a/spacy_llm/tasks/rel/task.py b/spacy_llm/tasks/rel/task.py index a6683c5c..1624455f 100644 --- a/spacy_llm/tasks/rel/task.py +++ b/spacy_llm/tasks/rel/task.py @@ -83,6 +83,12 @@ def _preannotate( spaces: List[bool] = [] if len(doc.ents) else [t.whitespace_ != "" for t in doc] ent_indices: List[Tuple[int, int]] = [] + # Convert RELExample into Doc for easier subsequent processing. + # todo Solve import cycle so we can expect RELExample here. + if not isinstance(doc, Doc): + assert hasattr(doc, "to_doc") and callable(doc.to_doc) + doc = doc.to_doc() + if not hasattr(doc, "ents"): raise ValueError( "Prompt example type used in RELTask has to expose entities via an .ents attribute." @@ -103,6 +109,7 @@ def _preannotate( spaces.append(spaces[-1]) spaces[-2] = False ent_indices.append((ent.start + i, ent.end + i)) + last_ent_end = ent.end # Include chars after last ent. diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index 68392ca6..79b05bd8 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -1,7 +1,8 @@ import warnings -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Tuple -from spacy.tokens import Doc +from spacy import Vocab +from spacy.tokens import Doc, Span from spacy.training import Example from ...compat import Self @@ -32,6 +33,48 @@ def generate(cls, example: Example, task: RELTask) -> Optional[Self]: relations=example.reference._.rel, ) + def to_doc(self) -> Doc: + """Returns Doc representation of example instance. Note that relations are in user_data["rel"]. + field (str): Doc field to store relations in. + RETURNS (Doc): Representation as doc. + """ + text = self.text + punct_chars = (",", ";", ":", ".", "!", "?") + for punct in punct_chars: + text = text.replace(punct, f" {punct} ") + doc_words = text.split() + doc_spaces = [ + i < len(doc_words) - 1 and doc_words[i + 1] not in punct_chars + for i, word in enumerate(doc_words) + ] + doc = Doc(words=doc_words, spaces=doc_spaces, vocab=Vocab(strings=doc_words)) + + # Set entities after finding correct indices. + conv_ent_indices: List[Tuple[int, int]] = [] + if len(self.ents): + ent_idx = 0 + for token in doc: + if token.idx == self.ents[ent_idx].start_char: + conv_ent_indices.append((token.i, -1)) + if token.idx + len(token.text) == self.ents[ent_idx].end_char: + conv_ent_indices[-1] = (conv_ent_indices[-1][0], token.i + 1) + ent_idx += 1 + if ent_idx == len(self.ents): + break + + doc.ents = [ + Span( # noqa: E731 + doc=doc, + start=ent_idx[0], + end=ent_idx[1], + label=self.ents[i].label, + ) + for i, ent_idx in enumerate(conv_ent_indices) + ] + doc.user_data["rel"] = self.relations + + return doc + def reduce_shards_to_doc(task: RELTask, shards: Iterable[Doc]) -> Doc: """Reduces shards to docs for RELTask. diff --git a/spacy_llm/tests/tasks/test_rel.py b/spacy_llm/tests/tasks/test_rel.py index 1b313380..1734a6ee 100644 --- a/spacy_llm/tests/tasks/test_rel.py +++ b/spacy_llm/tests/tasks/test_rel.py @@ -135,7 +135,7 @@ def test_rel_config(cfg_string, request: FixtureRequest): @pytest.mark.external @pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") -@pytest.mark.parametrize("cfg_string", ["zeroshot_cfg_string", "fewshot_cfg_string"]) +@pytest.mark.parametrize("cfg_string", ["fewshot_cfg_string"]) def test_rel_predict(task, cfg_string, request): """Use OpenAI to get REL results. Note that this test may fail randomly, as the LLM's output is unguaranteed to be consistent/predictable From 0c96fb64150ca3dc638b2db8ac6b34e9b10eac68 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 29 Nov 2023 18:17:01 +0100 Subject: [PATCH 40/51] Use regex for punctuation in REL conversion. --- spacy_llm/tasks/rel/util.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index 79b05bd8..d45298c4 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -1,3 +1,4 @@ +import re import warnings from typing import Iterable, List, Optional, Tuple @@ -38,13 +39,12 @@ def to_doc(self) -> Doc: field (str): Doc field to store relations in. RETURNS (Doc): Representation as doc. """ - text = self.text - punct_chars = (",", ";", ":", ".", "!", "?") - for punct in punct_chars: - text = text.replace(punct, f" {punct} ") + punct_chars_pattern = r'[]!"$%&\'()*+,./:;=#@?[\\^_`{|}~-]+' + text = re.sub(punct_chars_pattern, r" \g<0> ", self.text) doc_words = text.split() doc_spaces = [ - i < len(doc_words) - 1 and doc_words[i + 1] not in punct_chars + i < len(doc_words) - 1 + and not re.match(punct_chars_pattern, doc_words[i + 1]) for i, word in enumerate(doc_words) ] doc = Doc(words=doc_words, spaces=doc_spaces, vocab=Vocab(strings=doc_words)) From dc926bd086407501598cc9a2cce29fcd265b775b Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 1 Dec 2023 11:27:39 +0100 Subject: [PATCH 41/51] Maintain custom doc attributes, incl. test. --- spacy_llm/pipeline/llm.py | 18 +++++++- spacy_llm/tasks/noop.py | 13 +++++- spacy_llm/tasks/rel/util.py | 2 +- spacy_llm/tasks/sentiment/util.py | 2 +- spacy_llm/tasks/summarization/util.py | 2 +- spacy_llm/tests/pipeline/test_llm.py | 64 ++++++++++++++++++++++++++- 6 files changed, 93 insertions(+), 8 deletions(-) diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index 85490f4e..12a06d1a 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -197,8 +197,8 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: """Process a batch of docs with the configured LLM model and task. If a cache is configured, only sends prompts to model for docs not found in cache. - docs (List[Doc]): Input batch of docs - RETURNS (List[Doc]): Processed batch of docs with task annotations set + docs (List[Doc]): Input batch of docs. + RETURNS (List[Doc]): Processed batch of docs with task annotations set. """ has_shards = supports_sharding(self._task) is_cached = [doc in self._cache for doc in docs] @@ -210,6 +210,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: len(noncached_doc_batch), ) + # Process uncached docs. modified_docs: Iterator[Doc] = iter(()) if len(noncached_doc_batch) > 0: n_iters = 3 if self._save_io else 2 @@ -261,6 +262,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: ) modified_docs = iter(resp) + noncached_doc_batch_iter = iter(noncached_doc_batch) final_docs: List[Doc] = [] for i, doc in enumerate(docs): if is_cached[i]: @@ -271,6 +273,18 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]: else: doc = next(modified_docs) + # Merge with doc's prior custom data. + noncached_doc = next(noncached_doc_batch_iter) + for extension in dir(noncached_doc): + if not Doc.has_extension(extension): + Doc.set_extension(extension, default=None) + # Don't overwrite any non-None extension values in new doc. + if getattr(doc._, extension) is None: + setattr(doc._, extension, getattr(noncached_doc._, extension)) + doc.user_data = {**noncached_doc.user_data, **doc.user_data} + doc._context = noncached_doc._context + + # Save raw IO (prompt and response), if save_io is True. if self._save_io: # Make sure the `llm_io` field is set doc.user_data["llm_io"] = doc.user_data.get( diff --git a/spacy_llm/tasks/noop.py b/spacy_llm/tasks/noop.py index 12d101f9..dff68dc1 100644 --- a/spacy_llm/tasks/noop.py +++ b/spacy_llm/tasks/noop.py @@ -1,3 +1,4 @@ +import warnings from typing import Iterable, Optional, Tuple from spacy.tokens import Doc @@ -27,7 +28,17 @@ def generate_prompts( def parse_responses( self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] ) -> Iterable[Doc]: - return [list(shards_for_doc)[0] for shards_for_doc in shards] + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + docs = [ + Doc.from_docs(list(shards_for_doc), ensure_whitespace=True) + for shards_for_doc in shards + ] + return docs @property def prompt_template(self) -> str: diff --git a/spacy_llm/tasks/rel/util.py b/spacy_llm/tasks/rel/util.py index d45298c4..e06229d7 100644 --- a/spacy_llm/tasks/rel/util.py +++ b/spacy_llm/tasks/rel/util.py @@ -88,7 +88,7 @@ def reduce_shards_to_doc(task: RELTask, shards: Iterable[Doc]) -> Doc: warnings.filterwarnings( "ignore", category=UserWarning, - message=f".*Skipping Doc custom extension '{task.field}' while merging docs.", + message=".*Skipping .* while merging docs.", ) doc = Doc.from_docs(shards, ensure_whitespace=True) diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 3b47d495..4352b62c 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -36,7 +36,7 @@ def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc: warnings.filterwarnings( "ignore", category=UserWarning, - message=f".*Skipping Doc custom extension '{task.field}' while merging docs.", + message=".*Skipping .* while merging docs.", ) doc = Doc.from_docs(shards, ensure_whitespace=True) setattr( diff --git a/spacy_llm/tasks/summarization/util.py b/spacy_llm/tasks/summarization/util.py index 2388c55e..9ee479a7 100644 --- a/spacy_llm/tasks/summarization/util.py +++ b/spacy_llm/tasks/summarization/util.py @@ -33,7 +33,7 @@ def reduce_shards_to_doc(task: SummarizationTask, shards: Iterable[Doc]) -> Doc: warnings.filterwarnings( "ignore", category=UserWarning, - message=f".*Skipping Doc custom extension '{task.field}' while merging docs.", + message=".*Skipping .* while merging docs.", ) doc = Doc.from_docs(list(shards), ensure_whitespace=True) diff --git a/spacy_llm/tests/pipeline/test_llm.py b/spacy_llm/tests/pipeline/test_llm.py index 0742dde9..829b9154 100644 --- a/spacy_llm/tests/pipeline/test_llm.py +++ b/spacy_llm/tests/pipeline/test_llm.py @@ -18,7 +18,7 @@ from spacy_llm.pipeline import LLMWrapper from spacy_llm.registry import registry from spacy_llm.tasks import _LATEST_TASKS, make_noop_task -from spacy_llm.tasks.noop import _NOOP_PROMPT +from spacy_llm.tasks.noop import _NOOP_PROMPT, ShardingNoopTask from ...cache import BatchCache from ...registry.reader import fewshot_reader @@ -79,7 +79,7 @@ def test_llm_pipe(noop_config: Dict[str, Any], n_process: int, shard: bool): ) -@pytest.mark.parametrize("n_process", [1, 2]) +@pytest.mark.parametrize("n_process", [2]) def test_llm_pipe_with_cache(tmp_path: Path, n_process: int): """Test call .pipe() with pre-cached docs""" ops = get_current_ops() @@ -406,3 +406,63 @@ def test_llm_task_factories_ner(): assert len(doc.ents) > 0 for ent in doc.ents: assert ent.label_ in ["PER", "ORG", "LOC"] + + +@pytest.mark.parametrize("shard", [True, False]) +def test_llm_custom_data(noop_config: Dict[str, Any], shard: bool): + """Test whether custom doc data is preserved.""" + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpNoShards.v1"}}} + if not shard + else noop_config, + ) + + doc = nlp.make_doc("This is a test") + if not Doc.has_extension("test"): + Doc.set_extension("test", default=None) + doc._.test = "Test" + doc.user_data["test"] = "Test" + + doc = nlp(doc) + assert doc._.test == "Test" + assert doc.user_data["test"] == "Test" + + +def test_llm_custom_data_overwrite(noop_config: Dict[str, Any]): + """Test whether custom doc data is overwritten as expected.""" + + class NoopTaskWithCustomData(ShardingNoopTask): + def parse_responses( + self, shards: Iterable[Iterable[Doc]], responses: Iterable[Iterable[str]] + ) -> Iterable[Doc]: + docs = super().parse_responses(shards, responses) + for doc in docs: + doc._.test = "Test 2" + doc.user_data["test"] = "Test 2" + return docs + + @registry.llm_tasks("spacy.NoOpCustomData.v1") + def make_noopnoshards_task(): + return NoopTaskWithCustomData() + + nlp = spacy.blank("en") + nlp.add_pipe( + "llm", + config={**noop_config, **{"task": {"@llm_tasks": "spacy.NoOpCustomData.v1"}}}, + ) + doc = nlp.make_doc("This is a test") + for extension in ("test", "test_nooverwrite"): + if not Doc.has_extension(extension): + Doc.set_extension(extension, default=None) + doc._.test = "Test" + doc._.test_nooverwrite = "Test" + doc.user_data["test"] = "Test" + doc.user_data["test_nooverwrite"] = "Test" + + doc = nlp(doc) + assert doc._.test == "Test 2" + assert doc.user_data["test"] == "Test 2" + assert doc._.test_nooverwrite == "Test" + assert doc.user_data["test_nooverwrite"] == "Test" From 5585174585c2eab4ffa6579f3c8cc21ffcdfe8cd Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 1 Dec 2023 11:58:20 +0100 Subject: [PATCH 42/51] Filter merge warnings in textcat reduction. --- spacy_llm/tasks/textcat/util.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/spacy_llm/tasks/textcat/util.py b/spacy_llm/tasks/textcat/util.py index fd69d4b7..291f5d29 100644 --- a/spacy_llm/tasks/textcat/util.py +++ b/spacy_llm/tasks/textcat/util.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict from typing import Any, DefaultDict, Dict, Iterable, Optional @@ -66,7 +67,13 @@ def reduce_shards_to_doc(task: TextCatTask, shards: Iterable[Doc]) -> Doc: for cat, cat_score in shard.cats.items(): all_cats[cat] += cat_score * weight - doc = Doc.from_docs(shards, ensure_whitespace=True) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=".*Skipping .* while merging docs.", + ) + doc = Doc.from_docs(shards, ensure_whitespace=True) doc.cats = all_cats return doc From 6d3a4c851904564d6d3aea8267696c62eedd0ef7 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 1 Dec 2023 15:49:48 +0100 Subject: [PATCH 43/51] Add Zephyr and Yi classes. --- download.py | 6 ++ spacy_llm/models/hf/mistral.py | 9 +- spacy_llm/models/hf/yi.py | 136 ++++++++++++++++++++++++++ spacy_llm/models/hf/zephyr.py | 100 +++++++++++++++++++ spacy_llm/tests/models/test_yi.py | 68 +++++++++++++ spacy_llm/tests/models/test_zephyr.py | 68 +++++++++++++ 6 files changed, 382 insertions(+), 5 deletions(-) create mode 100644 download.py create mode 100644 spacy_llm/models/hf/yi.py create mode 100644 spacy_llm/models/hf/zephyr.py create mode 100644 spacy_llm/tests/models/test_yi.py create mode 100644 spacy_llm/tests/models/test_zephyr.py diff --git a/download.py b/download.py new file mode 100644 index 00000000..58d3a9ee --- /dev/null +++ b/download.py @@ -0,0 +1,6 @@ +from transformers import pipeline + +# huggingface_hub.hf_hub_download(repo_id="HuggingFaceH4/zephyr-7b-beta") +model = pipeline( + "text-generation", model="HuggingFaceH4/zephyr-7b-beta", resume_download=True +) diff --git a/spacy_llm/models/hf/mistral.py b/spacy_llm/models/hf/mistral.py index 7da0df1a..3626bf51 100644 --- a/spacy_llm/models/hf/mistral.py +++ b/spacy_llm/models/hf/mistral.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional from confection import SimpleFrozenDict @@ -94,13 +94,12 @@ def mistral_hf( name: Mistral.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), -) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]: +) -> Mistral: """Generates Mistral instance that can execute a set of prompts and return the raw responses. - name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names(). + name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. config_run (Optional[Dict[str, Any]]): HF config for running the model. - RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return - the raw responses. + RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses. """ return Mistral( name=name, config_init=config_init, config_run=config_run, context_length=8000 diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py new file mode 100644 index 00000000..ff84fee9 --- /dev/null +++ b/spacy_llm/models/hf/yi.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from confection import SimpleFrozenDict + +from ...compat import Literal, transformers +from ...registry.util import registry +from .base import HuggingFace + + +class Yi(HuggingFace): + MODEL_NAMES = Literal[ # noqa: F722 + "Yi-34B", + "Yi-34B-Chat-8bits", + "Yi-6B-Chat", + "Yi-6B", + "Yi-6B-200K", + "Yi-34B-Chat", + "Yi-34B-Chat-4bits", + "Yi-34B-200K", + ] + + def __init__( + self, + name: MODEL_NAMES, + config_init: Optional[Dict[str, Any]], + config_run: Optional[Dict[str, Any]], + context_length: int, + ): + self._tokenizer: Optional["transformers.AutoTokenizer"] = None + self._is_instruct = "instruct" in name + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) + + assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase) + + # Instantiate GenerationConfig object from config dict. + self._hf_config_run = transformers.GenerationConfig.from_pretrained( + self._name, **self._config_run + ) + # To avoid deprecation warning regarding usage of `max_length`. + self._hf_config_run.max_new_tokens = self._hf_config_run.max_length + + def init_model(self) -> Any: + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + self._name, use_fast=False + ) + init_cfg = self._config_init + device: Optional[str] = None + if "device" in init_cfg: + device = init_cfg.pop("device") + + model = transformers.AutoModelForCausalLM.from_pretrained( + self._name, **init_cfg, resume_download=True + ).eval() + if device: + model.to(device) + + return model + + @property + def hf_account(self) -> str: + return "01-ai" + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] + assert callable(self._tokenizer) + assert hasattr(self._model, "generate") + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + prompts_for_doc = list(prompts_for_doc) + + tokenized_input_ids = [ + self._model.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + ) + for prompt in prompts_for_doc + ] + tokenized_input_ids = [ + tp.to(self._model.device) for tp in tokenized_input_ids + ] + + # response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True) + responses.append( + [ + self._tokenizer.decode( + self._model.generate( + input_ids=tok_ii, generation_config=self._hf_config_run + )[:, tok_ii.shape[1] :][0], + skip_special_tokens=True, + ) + for tok_ii in tokenized_input_ids + ] + ) + + return responses + + @staticmethod + def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: + default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() + return {**default_cfg_init, **{"torch_dtype": "auto"}}, { + **default_cfg_run, + **{ + "max_new_tokens": 256, + "do_sample": True, + "temperature": 0.7, + "top_k": 50, + "top_p": 0.95, + }, + } + + +@registry.llm_models("spacy.Yi.v1") +def mistral_yi( + name: Yi.MODEL_NAMES, + config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), + config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), +) -> Yi: + """Generates Yi instance that can execute a set of prompts and return the raw responses. + name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names(). + config_init (Optional[Dict[str, Any]]): HF config for initializing the model. + config_run (Optional[Dict[str, Any]]): HF config for running the model. + RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses. + """ + return Yi( + name=name, + config_init=config_init, + config_run=config_run, + context_length=200000 if "200K" in name else 32000, + ) diff --git a/spacy_llm/models/hf/zephyr.py b/spacy_llm/models/hf/zephyr.py new file mode 100644 index 00000000..6b9ca722 --- /dev/null +++ b/spacy_llm/models/hf/zephyr.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from confection import SimpleFrozenDict + +from ...compat import Literal, transformers +from ...registry.util import registry +from .base import HuggingFace + + +class Zephyr(HuggingFace): + MODEL_NAMES = Literal["zephyr-7b-beta"] # noqa: F722 + + def __init__( + self, + name: MODEL_NAMES, + config_init: Optional[Dict[str, Any]], + config_run: Optional[Dict[str, Any]], + context_length: int, + ): + super().__init__( + name=name, + config_init=config_init, + config_run=config_run, + context_length=context_length, + ) + + # Instantiate GenerationConfig object from config dict. + self._hf_config_run = transformers.GenerationConfig.from_pretrained( + self._name, **self._config_run + ) + # To avoid deprecation warning regarding usage of `max_length`. + self._hf_config_run.max_new_tokens = self._hf_config_run.max_length + + def init_model(self) -> Any: + return transformers.pipeline( + "text-generation", + model=self._name, + return_full_text=False, + **self._config_init + ) + + @property + def hf_account(self) -> str: + return "HuggingFaceH4" + + def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] + assert hasattr(self._model, "generate") + responses: List[List[str]] = [] + + for prompts_for_doc in prompts: + formatted_prompts_for_doc = [ + self._model.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) + for prompt in prompts_for_doc + ] + + responses.append( + [ + self._model(prompt, generation_config=self._hf_config_run)[0][ + "generated_text" + ] + for prompt in formatted_prompts_for_doc + ] + ) + + return responses + + @staticmethod + def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: + default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() + return default_cfg_init, { + **default_cfg_run, + **{ + "max_new_tokens": 256, + "do_sample": True, + "temperature": 0.7, + "top_k": 50, + "top_p": 0.95, + }, + } + + +@registry.llm_models("spacy.Zephyr.v1") +def zephyr_hf( + name: Zephyr.MODEL_NAMES, + config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), + config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), +) -> Zephyr: + """Generates Zephyr instance that can execute a set of prompts and return the raw responses. + name (Literal): Name of the Zephyr model. Has to be one of Zephyr.get_model_names(). + config_init (Optional[Dict[str, Any]]): HF config for initializing the model. + config_run (Optional[Dict[str, Any]]): HF config for running the model. + RETURNS (Zephyr): Zephyr instance that can execute a set of prompts and return the raw responses. + """ + return Zephyr( + name=name, config_init=config_init, config_run=config_run, context_length=8000 + ) diff --git a/spacy_llm/tests/models/test_yi.py b/spacy_llm/tests/models/test_yi.py new file mode 100644 index 00000000..26d147a4 --- /dev/null +++ b/spacy_llm/tests/models/test_yi.py @@ -0,0 +1,68 @@ +import copy + +import pytest +import spacy +from confection import Config # type: ignore[import] +from thinc.compat import has_torch_cuda_gpu + +from ...compat import torch + +_PIPE_CFG = { + "model": { + "@llm_models": "spacy.Yi.v1", + "name": "Yi-6B", + }, + "task": {"@llm_tasks": "spacy.NoOp.v1"}, +} + +_NLP_CONFIG = """ + +[nlp] +lang = "en" +pipeline = ["llm"] +batch_size = 128 + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.NoOp.v1" + +[components.llm.model] +@llm_models = "spacy.Yi.v1" +name = "Yi-6B" +""" + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init(): + """Test initialization and simple run.""" + nlp = spacy.blank("en") + cfg = copy.deepcopy(_PIPE_CFG) + nlp.add_pipe("llm", config=cfg) + nlp("This is a test.") + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init_from_config(): + orig_config = Config().from_str(_NLP_CONFIG) + nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) + assert nlp.pipe_names == ["llm"] + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_invalid_model(): + orig_config = Config().from_str(_NLP_CONFIG) + config = copy.deepcopy(orig_config) + config["components"]["llm"]["model"]["name"] = "x" + with pytest.raises(ValueError, match="unexpected value; permitted"): + spacy.util.load_model_from_config(config, auto_fill=True) + torch.cuda.empty_cache() diff --git a/spacy_llm/tests/models/test_zephyr.py b/spacy_llm/tests/models/test_zephyr.py new file mode 100644 index 00000000..e026854a --- /dev/null +++ b/spacy_llm/tests/models/test_zephyr.py @@ -0,0 +1,68 @@ +import copy + +import pytest +import spacy +from confection import Config # type: ignore[import] +from thinc.compat import has_torch_cuda_gpu + +from ...compat import torch + +_PIPE_CFG = { + "model": { + "@llm_models": "spacy.Zephyr.v1", + "name": "zephyr-7b-beta", + }, + "task": {"@llm_tasks": "spacy.NoOp.v1"}, +} + +_NLP_CONFIG = """ + +[nlp] +lang = "en" +pipeline = ["llm"] +batch_size = 128 + +[components] + +[components.llm] +factory = "llm" + +[components.llm.task] +@llm_tasks = "spacy.NoOp.v1" + +[components.llm.model] +@llm_models = "spacy.Zephyr.v1" +name = "zephyr-7b-beta" +""" + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init(): + """Test initialization and simple run.""" + nlp = spacy.blank("en") + cfg = copy.deepcopy(_PIPE_CFG) + nlp.add_pipe("llm", config=cfg) + nlp("This is a test.") + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skip(reason="CI runner needs more GPU memory") +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_init_from_config(): + orig_config = Config().from_str(_NLP_CONFIG) + nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) + assert nlp.pipe_names == ["llm"] + torch.cuda.empty_cache() + + +@pytest.mark.gpu +@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +def test_invalid_model(): + orig_config = Config().from_str(_NLP_CONFIG) + config = copy.deepcopy(orig_config) + config["components"]["llm"]["model"]["name"] = "x" + with pytest.raises(ValueError, match="unexpected value; permitted"): + spacy.util.load_model_from_config(config, auto_fill=True) + torch.cuda.empty_cache() From 57acfe492bbccb581a073ece2c79787bdecd259e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 1 Dec 2023 18:49:49 +0100 Subject: [PATCH 44/51] Fix Yi model. --- download.py | 6 ------ spacy_llm/models/hf/__init__.py | 4 ++++ spacy_llm/models/hf/yi.py | 20 +++++++------------- spacy_llm/models/hf/zephyr.py | 1 - 4 files changed, 11 insertions(+), 20 deletions(-) delete mode 100644 download.py diff --git a/download.py b/download.py deleted file mode 100644 index 58d3a9ee..00000000 --- a/download.py +++ /dev/null @@ -1,6 +0,0 @@ -from transformers import pipeline - -# huggingface_hub.hf_hub_download(repo_id="HuggingFaceH4/zephyr-7b-beta") -model = pipeline( - "text-generation", model="HuggingFaceH4/zephyr-7b-beta", resume_download=True -) diff --git a/spacy_llm/models/hf/__init__.py b/spacy_llm/models/hf/__init__.py index b3afbb71..c683781e 100644 --- a/spacy_llm/models/hf/__init__.py +++ b/spacy_llm/models/hf/__init__.py @@ -5,6 +5,8 @@ from .mistral import mistral_hf from .openllama import openllama_hf from .stablelm import stablelm_hf +from .yi import yi_hf +from .zephyr import zephyr_hf __all__ = [ "HuggingFace", @@ -14,4 +16,6 @@ "mistral_hf", "openllama_hf", "stablelm_hf", + "yi_hf", + "zephyr_hf", ] diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py index ff84fee9..68af6a7e 100644 --- a/spacy_llm/models/hf/yi.py +++ b/spacy_llm/models/hf/yi.py @@ -66,15 +66,18 @@ def hf_account(self) -> str: return "01-ai" def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] - assert callable(self._tokenizer) assert hasattr(self._model, "generate") + assert hasattr(self._tokenizer, "apply_chat_template") + assert self._tokenizer + # assert callable(self._tokenizer.apply_chat_template) # type: ignore[union-attr] + responses: List[List[str]] = [] for prompts_for_doc in prompts: prompts_for_doc = list(prompts_for_doc) tokenized_input_ids = [ - self._model.tokenizer.apply_chat_template( + self._tokenizer.apply_chat_template( # type: ignore[union-attr] [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True, @@ -104,20 +107,11 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: @staticmethod def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]: default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs() - return {**default_cfg_init, **{"torch_dtype": "auto"}}, { - **default_cfg_run, - **{ - "max_new_tokens": 256, - "do_sample": True, - "temperature": 0.7, - "top_k": 50, - "top_p": 0.95, - }, - } + return {**default_cfg_init, **{"torch_dtype": "auto"}}, default_cfg_run @registry.llm_models("spacy.Yi.v1") -def mistral_yi( +def yi_hf( name: Yi.MODEL_NAMES, config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(), config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), diff --git a/spacy_llm/models/hf/zephyr.py b/spacy_llm/models/hf/zephyr.py index 6b9ca722..75ca37ae 100644 --- a/spacy_llm/models/hf/zephyr.py +++ b/spacy_llm/models/hf/zephyr.py @@ -44,7 +44,6 @@ def hf_account(self) -> str: return "HuggingFaceH4" def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override] - assert hasattr(self._model, "generate") responses: List[List[str]] = [] for prompts_for_doc in prompts: From 2f1a90564f0c1cb7333627d78a3b384ed166fffc Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 1 Dec 2023 18:50:12 +0100 Subject: [PATCH 45/51] Fix Yi model. --- spacy_llm/models/hf/yi.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py index 68af6a7e..d3d2277c 100644 --- a/spacy_llm/models/hf/yi.py +++ b/spacy_llm/models/hf/yi.py @@ -69,7 +69,6 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: assert hasattr(self._model, "generate") assert hasattr(self._tokenizer, "apply_chat_template") assert self._tokenizer - # assert callable(self._tokenizer.apply_chat_template) # type: ignore[union-attr] responses: List[List[str]] = [] @@ -77,7 +76,7 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: prompts_for_doc = list(prompts_for_doc) tokenized_input_ids = [ - self._tokenizer.apply_chat_template( # type: ignore[union-attr] + self._tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True, From 982106366e62e2e3d011287dbf533ca3088e9003 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 4 Dec 2023 16:43:10 +0100 Subject: [PATCH 46/51] Fix Yi and Zephyr processing. --- spacy_llm/models/hf/yi.py | 13 ++++++------- spacy_llm/models/hf/zephyr.py | 4 +++- spacy_llm/tasks/sentiment/util.py | 7 ++++++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py index d3d2277c..0e5a5c39 100644 --- a/spacy_llm/models/hf/yi.py +++ b/spacy_llm/models/hf/yi.py @@ -10,12 +10,12 @@ class Yi(HuggingFace): MODEL_NAMES = Literal[ # noqa: F722 "Yi-34B", - "Yi-34B-Chat-8bits", - "Yi-6B-Chat", + "Yi-34B-chat-8bits", + "Yi-6B-chat", "Yi-6B", "Yi-6B-200K", - "Yi-34B-Chat", - "Yi-34B-Chat-4bits", + "Yi-34B-chat", + "Yi-34B-chat-4bits", "Yi-34B-200K", ] @@ -77,7 +77,7 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: tokenized_input_ids = [ self._tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], + conversation=[{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True, return_tensors="pt", @@ -88,7 +88,6 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: tp.to(self._model.device) for tp in tokenized_input_ids ] - # response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True) responses.append( [ self._tokenizer.decode( @@ -96,7 +95,7 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: input_ids=tok_ii, generation_config=self._hf_config_run )[:, tok_ii.shape[1] :][0], skip_special_tokens=True, - ) + ).strip("\n") for tok_ii in tokenized_input_ids ] ) diff --git a/spacy_llm/models/hf/zephyr.py b/spacy_llm/models/hf/zephyr.py index 75ca37ae..26d4aab5 100644 --- a/spacy_llm/models/hf/zephyr.py +++ b/spacy_llm/models/hf/zephyr.py @@ -51,7 +51,7 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: self._model.tokenizer.apply_chat_template( [{"role": "user", "content": prompt}], tokenize=False, - add_generation_prompt=True, + add_generation_prompt=False, ) for prompt in prompts_for_doc ] @@ -61,6 +61,8 @@ def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: self._model(prompt, generation_config=self._hf_config_run)[0][ "generated_text" ] + .replace("<|assistant|>", "") + .strip("\n") for prompt in formatted_prompts_for_doc ] ) diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index 4352b62c..8f01e037 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -42,7 +42,12 @@ def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc: setattr( doc._, task.field, - sum([score * weight for score, weight in zip(sent_scores, weights)]), + sum( + [ + (score if score else 0) * weight + for score, weight in zip(sent_scores, weights) + ] + ), ) return doc From 98e3e6cd0cb6b5b1efa0cee2e37c1ae7a7e03409 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 4 Dec 2023 16:43:40 +0100 Subject: [PATCH 47/51] Remove deprecated comment. --- spacy_llm/tests/tasks/test_sentiment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index d7ea6fec..aac85966 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -131,7 +131,6 @@ def test_sentiment_predict(cfg_string, request): orig_config = Config().from_str(cfg) nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) if cfg_string != "ext_template_cfg_string": - # with pytest.warns() as record: assert nlp("This is horrible.")._.sentiment == 0.0 assert 0 < nlp("This is meh.")._.sentiment <= 0.5 assert nlp("This is perfect.")._.sentiment == 1.0 From 3747a2f0bf33905049d119565a4932949a6c83b9 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 11 Dec 2023 17:55:23 +0100 Subject: [PATCH 48/51] Change model used for Yi tests. --- spacy_llm/tests/models/test_yi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy_llm/tests/models/test_yi.py b/spacy_llm/tests/models/test_yi.py index 26d147a4..665d9278 100644 --- a/spacy_llm/tests/models/test_yi.py +++ b/spacy_llm/tests/models/test_yi.py @@ -10,7 +10,7 @@ _PIPE_CFG = { "model": { "@llm_models": "spacy.Yi.v1", - "name": "Yi-6B", + "name": "Yi-6B-chat", }, "task": {"@llm_tasks": "spacy.NoOp.v1"}, } From b2dff8fda4f7244303f5ec087cdff2e14b51c5b7 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Tue, 12 Dec 2023 16:18:34 +0100 Subject: [PATCH 49/51] Incorporate feedback. --- spacy_llm/models/hf/yi.py | 2 +- spacy_llm/pipeline/llm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy_llm/models/hf/yi.py b/spacy_llm/models/hf/yi.py index 0e5a5c39..6cb7807e 100644 --- a/spacy_llm/models/hf/yi.py +++ b/spacy_llm/models/hf/yi.py @@ -115,7 +115,7 @@ def yi_hf( config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(), ) -> Yi: """Generates Yi instance that can execute a set of prompts and return the raw responses. - name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names(). + name (Literal): Name of the Yi model. Has to be one of Yi.get_model_names(). config_init (Optional[Dict[str, Any]]): HF config for initializing the model. config_run (Optional[Dict[str, Any]]): HF config for running the model. RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses. diff --git a/spacy_llm/pipeline/llm.py b/spacy_llm/pipeline/llm.py index cf568ddf..f3edff55 100644 --- a/spacy_llm/pipeline/llm.py +++ b/spacy_llm/pipeline/llm.py @@ -70,7 +70,7 @@ def make_llm( nlp (Language): Pipeline. name (str): The component instance name, used to add entries to the losses during training. - task (Optional[_LLMTask]): An _LLMTask can generate prompts for given docs, and can parse the LLM's responses into + task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses into structured information and set that back on the docs. model (Callable[[Iterable[Any]], Iterable[Any]]]): Callable querying the specified LLM API. cache (Cache): Cache to use for caching prompts and responses per doc (batch). From dfe89ee7180fe025ba8cdda45029c33e0aa72045 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 13 Dec 2023 11:08:57 +0100 Subject: [PATCH 50/51] Skip Yi test failing in CI, but suceeding locally. --- spacy_llm/tests/models/test_yi.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spacy_llm/tests/models/test_yi.py b/spacy_llm/tests/models/test_yi.py index 665d9278..95b73136 100644 --- a/spacy_llm/tests/models/test_yi.py +++ b/spacy_llm/tests/models/test_yi.py @@ -38,6 +38,9 @@ @pytest.mark.gpu @pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA") +@pytest.mark.skip( + reason="CI runner fails with 'cutlassF: no kernel found to launch!' - to be investigated" +) def test_init(): """Test initialization and simple run.""" nlp = spacy.blank("en") From 69c3c76f9f4c403c7180adb30aa5b08885877c83 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Wed, 13 Dec 2023 22:30:43 +0100 Subject: [PATCH 51/51] Extend readme with links for Zephyr and Yi. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index aeff0e00..ee1a7fb5 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ This package integrates Large Language Models (LLMs) into [spaCy](https://spacy. - **[OpenLLaMA](https://huggingface.co/openlm-research)** - **[StableLM](https://huggingface.co/stabilityai)** - **[Mistral](https://huggingface.co/mistralai)** + - **[Zephyr](https://huggingface.co/HuggingFaceH4)** + - **[Yi](https://huggingface.co/01-ai)** - Integration with [LangChain](https://github.com/hwchase17/langchain) 🦜️🔗 - all `langchain` models and features can be used in `spacy-llm` - Tasks available out of the box: - Named Entity Recognition