Add registry functions to instantiate models by provider #428

Merged · 17 commits · Apr 24, 2024
3 changes: 2 additions & 1 deletion README.md
@@ -119,7 +119,8 @@ factory = "llm"
labels = ["COMPLIMENT", "INSULT"]

[components.llm.model]
@llm_models = "spacy.GPT-4.v2"
@llm_models = "spacy.OpenAI.v1"
name = "gpt-4"
```

Now run:
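For reference, a minimal sketch of how a pipeline consuming the updated config block might be assembled (this assumes the snippet above lives in a complete `config.cfg` with a TextCat task, that spacy-llm is installed, and that `OPENAI_API_KEY` is exported; the example text and scores are illustrative):

```python
# Sketch: assemble a pipeline from the config above and run it once.
from spacy_llm.util import assemble

nlp = assemble("config.cfg")        # resolves "spacy.OpenAI.v1" with name = "gpt-4"
doc = nlp("You look great today!")
print(doc.cats)                     # e.g. {"COMPLIMENT": 1.0, "INSULT": 0.0}
```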
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -27,7 +27,9 @@ filterwarnings = [
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
"ignore:^.*was deprecated in langchain-community.*"
"ignore:^.*was deprecated in langchain-community.*",
"ignore:^.*was deprecated in LangChain 0.0.1.*",
"ignore:^.*the load_module() method is deprecated and slated for removal in Python 3.12.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -13,7 +13,8 @@ langchain>=0.1,<0.2; python_version>="3.9"
openai>=0.27,<=0.28.1; python_version>="3.9"

# Necessary for running all local models on GPU.
-transformers[sentencepiece]>=4.0.0
+# TODO: transformers > 4.38 causes bug in model handling due to unknown factors. To be investigated.
+transformers[sentencepiece]>=4.0.0,<=4.38
torch
einops>=0.4

2 changes: 2 additions & 0 deletions spacy_llm/models/hf/__init__.py
@@ -4,12 +4,14 @@
from .llama2 import llama2_hf
from .mistral import mistral_hf
from .openllama import openllama_hf
from .registry import huggingface_v1
from .stablelm import stablelm_hf

__all__ = [
"HuggingFace",
"dolly_hf",
"falcon_hf",
"huggingface_v1",
"llama2_hf",
"mistral_hf",
"openllama_hf",
3 changes: 1 addition & 2 deletions spacy_llm/models/hf/mistral.py
@@ -99,8 +99,7 @@ def mistral_hf(
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
-RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
-the raw responses.
+RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
"""
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
51 changes: 51 additions & 0 deletions spacy_llm/models/hf/registry.py
@@ -0,0 +1,51 @@
from typing import Any, Dict, Optional

from confection import SimpleFrozenDict

from ...registry import registry
from .base import HuggingFace
from .dolly import Dolly
from .falcon import Falcon
from .llama2 import Llama2
from .mistral import Mistral
from .openllama import OpenLLaMA
from .stablelm import StableLM


@registry.llm_models("spacy.HF.v1")
@registry.llm_models("spacy.HuggingFace.v1")
def huggingface_v1(
name: str,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> HuggingFace:
"""Returns HuggingFace model instance.
name (str): Name of model to use.
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (HuggingFace): Model instance that can execute a set of prompts and return the raw responses.
"""
model_context_lengths = {
Dolly: 2048,
Falcon: 2048,
Llama2: 4096,
Mistral: 8000,
OpenLLaMA: 2048,
StableLM: 4096,
}

for model_cls, context_length in model_context_lengths.items():
model_names = getattr(model_cls, "MODEL_NAMES")
if model_names and name in model_names.__args__:
return model_cls(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

raise ValueError(
f"Name {name} could not be associated with any of the supported models. Please check "
f"https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
)
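A hedged usage sketch for the new provider-level HF entry point (the model name below is an assumption, and instantiating an HF wrapper downloads and loads the underlying model, so this is illustrative rather than something to run casually):

```python
# Sketch: call the registered factory directly. The name is matched against each
# supported wrapper's MODEL_NAMES Literal and the matching class is instantiated
# with its preset context length; unknown names raise the ValueError shown above.
from spacy_llm.models.hf import huggingface_v1

model = huggingface_v1(name="Mistral-7B-v0.1")   # assumed valid Mistral model name
print(type(model).__name__)                      # expected: "Mistral"
```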
1 change: 1 addition & 0 deletions spacy_llm/models/langchain/model.py
@@ -99,6 +99,7 @@ def query_langchain(
prompts (Iterable[Iterable[Any]]): Prompts to execute.
RETURNS (Iterable[Iterable[Any]]): LLM responses.
"""
assert callable(model)
return [
[model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
]
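The `.invoke()` call above follows LangChain's runnable interface. A small, hedged sketch of the nested prompts-to-responses mapping that `query_langchain` performs, using `FakeListLLM` purely as a stand-in for a real LangChain model:

```python
# Sketch: one inner list of prompts per doc, one inner list of responses per doc.
# Any LangChain model exposing .invoke() fits; FakeListLLM returns canned strings.
from langchain_community.llms import FakeListLLM

model = FakeListLLM(responses=["COMPLIMENT", "INSULT"])
prompts = [["Classify: you look great"], ["Classify: that was rude"]]
responses = [[model.invoke(p) for p in doc_prompts] for doc_prompts in prompts]
print(responses)  # [['COMPLIMENT'], ['INSULT']]
```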
37 changes: 37 additions & 0 deletions spacy_llm/models/rest/anthropic/registry.py
@@ -7,6 +7,43 @@
from .model import Anthropic, Endpoints


@registry.llm_models("spacy.Anthropic.v1")
def anthropic_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Anthropic.DEFAULT_STRICT,
max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
interval: float = Anthropic.DEFAULT_INTERVAL,
max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Anthropic:
"""Returns Anthropic model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (Anthropic): Instance of Anthropic model.
"""
return Anthropic(
name=name,
endpoint=Endpoints.COMPLETIONS.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Claude-2.v2")
def anthropic_claude_2_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
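The new `spacy.Anthropic.v1` entry can be selected like any other model in a pipeline config. A hedged sketch using an inline config (this assumes `ANTHROPIC_API_KEY` is set and that "claude-2" is a model name the API accepts; the task choice is arbitrary):

```python
# Sketch: inline pipeline config using the provider-level Anthropic registry entry.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.TextCat.v3", "labels": "COMPLIMENT,INSULT"},
        "model": {"@llm_models": "spacy.Anthropic.v1", "name": "claude-2"},
    },
)
doc = nlp("You look great today!")
print(doc.cats)
```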
39 changes: 38 additions & 1 deletion spacy_llm/models/rest/cohere/registry.py
@@ -7,6 +7,43 @@
from .model import Cohere, Endpoints


@registry.llm_models("spacy.Cohere.v1")
def cohere_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Cohere.DEFAULT_STRICT,
max_tries: int = Cohere.DEFAULT_MAX_TRIES,
interval: float = Cohere.DEFAULT_INTERVAL,
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Cohere:
"""Returns Cohere model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (Cohere): Instance of Cohere model.
"""
return Cohere(
name=name,
endpoint=Endpoints.COMPLETION.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Command.v2")
def cohere_command_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -56,7 +93,7 @@ def cohere_command(
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Cohere instance for 'command' model using REST to prompt API.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
41 changes: 41 additions & 0 deletions spacy_llm/models/rest/openai/registry.py
@@ -8,6 +8,47 @@

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.OpenAI.v1")
def openai_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
strict: bool = OpenAI.DEFAULT_STRICT,
max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
interval: float = OpenAI.DEFAULT_INTERVAL,
max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
) -> OpenAI:
"""Returns OpenAI model instance using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (str): Model name to use. Can be any model name supported by the OpenAI API.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
RETURNS (OpenAI): OpenAI model instance.
"""
return OpenAI(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


"""
Parameter explanations:
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
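Unlike the version-pinned factories, `spacy.OpenAI.v1` accepts an arbitrary model `name` plus an optional `endpoint`. A hedged sketch of fetching and calling the registered factory directly (assumes `OPENAI_API_KEY` is set; the proxy URL is a placeholder, not a real service):

```python
# Sketch: resolve the factory from the registry and point it at an OpenAI-compatible
# endpoint. Omitting `endpoint` falls back to the standard chat endpoint.
from spacy_llm.registry import registry

openai_factory = registry.llm_models.get("spacy.OpenAI.v1")
model = openai_factory(
    name="gpt-4",
    endpoint="https://my-proxy.example.com/v1/chat/completions",  # placeholder URL
)
```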
46 changes: 44 additions & 2 deletions spacy_llm/models/rest/palm/registry.py
@@ -7,6 +7,48 @@
from .model import Endpoints, PaLM


@registry.llm_models("spacy.Google.v1")
def google_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
strict: bool = PaLM.DEFAULT_STRICT,
max_tries: int = PaLM.DEFAULT_MAX_TRIES,
interval: float = PaLM.DEFAULT_INTERVAL,
max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
endpoint: Optional[str] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google model instance using REST to prompt API.
name (str): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
this API should look like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base 2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
is natively provided by spacy-llm.
endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
RETURNS (PaLM): PaLM model instance.
"""
default_endpoint = (
Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
)
return PaLM(
name=name,
endpoint=endpoint or default_endpoint,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.PaLM.v2")
def palm_bison_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
@@ -18,7 +60,7 @@ def palm_bison_v2(
context_length: Optional[int] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
@@ -57,7 +99,7 @@ def palm_bison(
endpoint: Optional[str] = None,
) -> PaLM:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
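A note on the default-endpoint logic in `google_v1` above: the same factory serves both PaLM variants, routing "text-bison-001" to the text endpoint and any other name to the message endpoint, while an explicit `endpoint` argument overrides either. A minimal sketch (assumes the relevant Google API key is configured in the environment):

```python
# Sketch: the endpoint is picked from the model name unless overridden explicitly.
from spacy_llm.models.rest.palm.registry import google_v1

text_model = google_v1(name="text-bison-001")   # defaults to the text endpoint
chat_model = google_v1(name="chat-bison-001")   # defaults to the message endpoint
```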
7 changes: 4 additions & 3 deletions spacy_llm/pipeline/llm.py
@@ -24,7 +24,7 @@
logger.addHandler(logging.NullHandler())

DEFAULT_MODEL_CONFIG = {
"@llm_models": "spacy.GPT-3-5.v2",
"@llm_models": "spacy.GPT-3-5.v3",
"strict": True,
}
DEFAULT_CACHE_CONFIG = {
@@ -238,6 +238,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
else self._task.generate_prompts(noncached_doc_batch),
n_iters + 1,
)

responses_iters = tee(
self._model(
# Ensure that model receives Iterable[Iterable[Any]]. If task doesn't shard, its prompt is wrapped
@@ -251,7 +252,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
)

for prompt_data, response, doc in zip(
-prompts_iters[1], responses_iters[0], noncached_doc_batch
+prompts_iters[1], list(responses_iters[0]), noncached_doc_batch
):
logger.debug(
"Generated prompt for doc: %s\n%s",
@@ -266,7 +267,7 @@ def _process_docs(self, docs: List[Doc]) -> List[Doc]:
elem[1] if support_sharding else noncached_doc_batch[i]
for i, elem in enumerate(prompts_iters[2])
),
-responses_iters[1],
+list(responses_iters[1]),
)
)

2 changes: 1 addition & 1 deletion spacy_llm/tests/models/test_cohere.py
@@ -84,7 +84,7 @@ def test_cohere_api_response_when_error():
def test_cohere_error_unsupported_model():
"""Ensure graceful handling of error when model is not supported"""
incorrect_model = "x-gpt-3.5-turbo"
with pytest.raises(ValueError, match="model not found"):
with pytest.raises(ValueError, match="Request to Cohere API failed"):
Cohere(
name=incorrect_model,
config={},