Add registry functions to instantiate models by provider #428

Merged · 17 commits · Apr 24, 2024
Changes from 4 commits
3 changes: 2 additions & 1 deletion README.md
@@ -119,7 +119,8 @@ factory = "llm"
labels = ["COMPLIMENT", "INSULT"]

[components.llm.model]
@llm_models = "spacy.GPT-4.v2"
@llm_models = "spacy.OpenAI.v1"
name = "gpt-4"
```

Now run:
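The change above swaps the pinned `spacy.GPT-4.v2` entry for the new provider-level `spacy.OpenAI.v1` registry function and selects the concrete model via `name`. A minimal sketch of loading such a config, assuming the task block and labels from the README excerpt and an OpenAI key in the environment:

```python
# Sketch only: loading a config that uses the new provider-level "spacy.OpenAI.v1" entry.
# The [components.llm.task] block is assumed from the README excerpt; OPENAI_API_KEY must be set.
import spacy
from thinc.api import Config

config_str = """
[nlp]
lang = "en"
pipeline = ["llm"]

[components]

[components.llm]
factory = "llm"

[components.llm.task]
@llm_tasks = "spacy.TextCat.v3"
labels = ["COMPLIMENT", "INSULT"]

[components.llm.model]
@llm_models = "spacy.OpenAI.v1"
name = "gpt-4"
"""

config = Config().from_str(config_str)
nlp = spacy.util.load_model_from_config(config, auto_fill=True)
doc = nlp("You look gorgeous!")
print(doc.cats)
```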
2 changes: 2 additions & 0 deletions spacy_llm/models/hf/__init__.py
@@ -4,12 +4,14 @@
from .llama2 import llama2_hf
from .mistral import mistral_hf
from .openllama import openllama_hf
from .registry import huggingface_v1
from .stablelm import stablelm_hf

__all__ = [
"HuggingFace",
"dolly_hf",
"falcon_hf",
"huggingface_v1",
"llama2_hf",
"mistral_hf",
"openllama_hf",
3 changes: 1 addition & 2 deletions spacy_llm/models/hf/mistral.py
@@ -99,8 +99,7 @@ def mistral_hf(
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Mistral instance that can execute a set of prompts and return
the raw responses.
RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
"""
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
50 changes: 50 additions & 0 deletions spacy_llm/models/hf/registry.py
@@ -0,0 +1,50 @@
from typing import Any, Callable, Dict, Iterable, Optional

from confection import SimpleFrozenDict

from ...registry import registry
from .dolly import Dolly
from .falcon import Falcon
from .llama2 import Llama2
from .mistral import Mistral
from .openllama import OpenLLaMA
from .stablelm import StableLM


@registry.llm_models("spacy.HF.v1")
@registry.llm_models("spacy.HuggingFace.v1")
def huggingface_v1(
name: str,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns HuggingFace model instance.
name (str): Name of model to use.
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]): Model instance that can execute a set of
prompts and return the raw responses.
"""
model_context_lengths = {
Dolly: 2048,
Falcon: 2048,
Llama2: 4096,
Mistral: 8000,
OpenLLaMA: 2048,
StableLM: 4096,
}

for model_cls, context_length in model_context_lengths.items():
model_names = getattr(model_cls, "MODEL_NAMES", None)
if model_names and name in model_names.__args__:
return model_cls(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

raise ValueError(
f"Name {name} could not be associated with any of the supported models. Please check "
f"https://spacy.io/api/large-language-models#models-hf to ensure the specified model name is correct."
)
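`huggingface_v1` dispatches purely on the model name: it checks each supported HF class's `MODEL_NAMES` literal and instantiates the first match with a hard-coded context length. A hedged usage sketch, resolving the entry straight from the registry (building the model for real downloads weights and needs the HF extras installed):

```python
# Sketch only: resolving the new "spacy.HuggingFace.v1" entry from the registry.
from spacy_llm.registry import registry

hf_factory = registry.llm_models.get("spacy.HuggingFace.v1")

# "dolly-v2-3b" appears in Dolly's MODEL_NAMES, so this dispatches to Dolly
# with the hard-coded context length of 2048 (downloads weights when run for real).
model = hf_factory(name="dolly-v2-3b")

# Unknown names raise the ValueError added above.
try:
    hf_factory(name="dolly-the-sheep")
except ValueError as err:
    print(err)
```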
37 changes: 37 additions & 0 deletions spacy_llm/models/rest/anthropic/registry.py
@@ -7,6 +7,43 @@
from .model import Anthropic, Endpoints


@registry.llm_models("spacy.Anthropic.v1")
def anthropic_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Anthropic.DEFAULT_STRICT,
max_tries: int = Anthropic.DEFAULT_MAX_TRIES,
interval: float = Anthropic.DEFAULT_INTERVAL,
max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Anthropic:
"""Returns Anthropic model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API looks like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (Anthropic): Instance of Anthropic model.
"""
return Anthropic(
name=name,
endpoint=Endpoints.COMPLETIONS.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Claude-2.v2")
def anthropic_claude_2_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
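`anthropic_v1`, together with the analogous `cohere_v1`, `openai_v1`, and `google_v1` below, replaces the per-model registry entries with one provider-level factory that accepts any model name the provider's API supports. A hedged sketch of calling it directly, outside a pipeline (the model name is illustrative and an Anthropic API key is assumed to be available to spacy-llm):

```python
# Sketch only: instantiating the provider-level Anthropic factory directly.
# "claude-2" is illustrative; an Anthropic API key must be configured for spacy-llm.
from spacy_llm.models.rest.anthropic.registry import anthropic_v1

model = anthropic_v1(name="claude-2", config={"temperature": 0.3}, max_tries=3)

# The returned Anthropic instance is callable on prompts grouped per doc
# (Iterable[Iterable[str]]) and yields the raw string responses.
responses = model([["Is 'You look gorgeous!' a compliment or an insult?"]])
print(list(responses))
```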
39 changes: 38 additions & 1 deletion spacy_llm/models/rest/cohere/registry.py
@@ -7,6 +7,43 @@
from .model import Cohere, Endpoints


@registry.llm_models("spacy.Cohere.v1")
def cohere_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(),
strict: bool = Cohere.DEFAULT_STRICT,
max_tries: int = Cohere.DEFAULT_MAX_TRIES,
interval: float = Cohere.DEFAULT_INTERVAL,
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
) -> Cohere:
"""Returns Cohere model instance using REST to prompt API.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
name (str): Name of model to use.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API looks like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (Cohere): Instance of Cohere model.
"""
return Cohere(
name=name,
endpoint=Endpoints.COMPLETION.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.Command.v2")
def cohere_command_v2(
config: Dict[Any, Any] = SimpleFrozenDict(),
@@ -56,7 +93,7 @@ def cohere_command(
max_request_time: float = Cohere.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Cohere instance for 'command' model using REST to prompt API.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Model to use.
name (Literal["command", "command-light", "command-light-nightly", "command-nightly"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
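The new `spacy.Cohere.v1` entry slots into a pipeline the same way the updated tests below use `spacy.HuggingFace.v1`. A minimal sketch, assuming a Cohere API key is configured as spacy-llm expects and reusing the NoOp task from the tests:

```python
# Sketch only: wiring the provider-level Cohere entry into a pipeline config,
# mirroring the _PIPE_CFG dicts in the tests below. "command" is illustrative.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "model": {"@llm_models": "spacy.Cohere.v1", "name": "command"},
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
    },
)
doc = nlp("This is a test.")
```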
43 changes: 43 additions & 0 deletions spacy_llm/models/rest/openai/registry.py
@@ -8,6 +8,49 @@

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.OpenAI.v1")
def openai_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
strict: bool = OpenAI.DEFAULT_STRICT,
max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
interval: float = OpenAI.DEFAULT_INTERVAL,
max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
endpoint: Optional[str] = None,
context_length: Optional[int] = None,
) -> OpenAI:
"""Returns OpenAI model instance using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (str): Model name to use. Can be any model name supported by the OpenAI API.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API looks like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
endpoint (Optional[str]): Endpoint to set. Defaults to standard endpoint.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
RETURNS (OpenAI): OpenAI model instance.

DOCS: https://spacy.io/api/large-language-models#models
"""
return OpenAI(
name=name,
endpoint=endpoint or Endpoints.CHAT.value,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


"""
Parameter explanations:
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
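`openai_v1` accepts any OpenAI model name, defaults to the chat endpoint (`Endpoints.CHAT.value`), and exposes `endpoint` so completion-style models can be pointed elsewhere. A hedged sketch (model names and the override URL are illustrative; `OPENAI_API_KEY` must be set):

```python
# Sketch only: using the provider-level OpenAI entry with and without an endpoint override.
from spacy_llm.models.rest.openai.registry import openai_v1

# Chat model on the default (chat) endpoint.
chat_model = openai_v1(name="gpt-4")

# Completion-style model with an explicit endpoint override (URL is illustrative).
completion_model = openai_v1(
    name="gpt-3.5-turbo-instruct",
    endpoint="https://api.openai.com/v1/completions",
)

responses = chat_model([["Give one synonym for 'gorgeous'."]])
print(list(responses))
```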
46 changes: 44 additions & 2 deletions spacy_llm/models/rest/palm/registry.py
@@ -7,6 +7,48 @@
from .model import Endpoints, PaLM


@registry.llm_models("spacy.Google.v1")
def google_v1(
name: str,
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
strict: bool = PaLM.DEFAULT_STRICT,
max_tries: int = PaLM.DEFAULT_MAX_TRIES,
interval: float = PaLM.DEFAULT_INTERVAL,
max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
context_length: Optional[int] = None,
endpoint: Optional[str] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google model instance using REST to prompt API.
name (str): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
or other response object that does not conform to the expectation of what a well-formed response object from
this API looks like). If False, the API error responses are returned by __call__(), but no error will
be raised.
max_tries (int): Max. number of tries for API request.
interval (float): Time interval (in seconds) between API retries. We implement a base-2 exponential backoff
at each retry.
max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
context_length (Optional[int]): Context length for this model. Only necessary for sharding and if no context length
natively provided by spacy-llm.
endpoint (Optional[str]): Endpoint to use. Defaults to standard endpoint.
RETURNS (PaLM): PaLM model instance.
"""
default_endpoint = (
Endpoints.TEXT.value if name in {"text-bison-001"} else Endpoints.MSG.value
)
return PaLM(
name=name,
endpoint=endpoint or default_endpoint,
config=config,
strict=strict,
max_tries=max_tries,
interval=interval,
max_request_time=max_request_time,
context_length=context_length,
)


@registry.llm_models("spacy.PaLM.v2")
def palm_bison_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
@@ -18,7 +60,7 @@ def palm_bison_v2(
context_length: Optional[int] = None,
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
@@ -57,7 +99,7 @@ def palm_bison(
endpoint: Optional[str] = None,
) -> PaLM:
"""Returns Google instance for PaLM Bison model using REST to prompt API.
name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
name (Literal["chat-bison-001", "text-bison-001"]): Name of model to use.
config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
or other response object that does not conform to the expectation of how a well-formed response object from
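`google_v1` chooses its default endpoint from the model name: `text-bison-001` uses the text endpoint, anything else the message endpoint, unless `endpoint` is passed explicitly. A small self-contained sketch of that selection logic (the URLs are placeholders, not the real `Endpoints.TEXT` / `Endpoints.MSG` values):

```python
# Sketch only: the default-endpoint selection used by google_v1, in isolation.
from typing import Optional

TEXT_ENDPOINT = "https://example.invalid/palm/generateText"   # placeholder
MSG_ENDPOINT = "https://example.invalid/palm/generateMessage"  # placeholder


def pick_palm_endpoint(name: str, endpoint: Optional[str] = None) -> str:
    """Return the explicit endpoint if given, otherwise choose by model name."""
    default_endpoint = TEXT_ENDPOINT if name in {"text-bison-001"} else MSG_ENDPOINT
    return endpoint or default_endpoint


assert pick_palm_endpoint("text-bison-001").endswith("generateText")
assert pick_palm_endpoint("chat-bison-001").endswith("generateMessage")
assert pick_palm_endpoint("chat-bison-001", endpoint="https://custom") == "https://custom"
```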
6 changes: 3 additions & 3 deletions spacy_llm/tests/models/test_dolly.py
@@ -9,7 +9,7 @@

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Dolly.v1",
"@llm_models": "spacy.HuggingFace.v1",
"name": "dolly-v2-3b",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
@@ -32,7 +32,7 @@
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Dolly.v1"
@llm_models = "spacy.HuggingFace.v1"
name = "dolly-v2-3b"
"""

@@ -66,6 +66,6 @@ def test_invalid_model():
orig_config = Config().from_str(_NLP_CONFIG)
config = copy.deepcopy(orig_config)
config["components"]["llm"]["model"]["name"] = "dolly-the-sheep"
with pytest.raises(ValueError, match="unexpected value; permitted"):
with pytest.raises(ValueError, match="could not be associated"):
spacy.util.load_model_from_config(config, auto_fill=True)
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions spacy_llm/tests/models/test_falcon.py
@@ -9,7 +9,7 @@

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Falcon.v1",
"@llm_models": "spacy.HuggingFace.v1",
"name": "falcon-rw-1b",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
@@ -32,7 +32,7 @@
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Falcon.v1"
@llm_models = "spacy.HuggingFace.v1"
name = "falcon-rw-1b"
"""

12 changes: 6 additions & 6 deletions spacy_llm/tests/models/test_hf.py
@@ -18,14 +18,14 @@

@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
@pytest.mark.parametrize(
"model", (("spacy.Dolly.v1", "dolly-v2-3b"), ("spacy.Llama2.v1", "Llama-2-7b-hf"))
)
@pytest.mark.parametrize("model", ("dolly-v2-3b", "Llama-2-7b-hf"))
def test_device_config_conflict(model: Tuple[str, str]):
"""Test device configuration."""
nlp = spacy.blank("en")
model, name = model
cfg = {**_PIPE_CFG, **{"model": {"@llm_models": model, "name": name}}}
cfg = {
**_PIPE_CFG,
**{"model": {"@llm_models": "spacy.HuggingFace.v1", "name": model}},
}

# Set device only.
cfg["model"]["config_init"] = {"device": "cpu"} # type: ignore[index]
@@ -58,7 +58,7 @@ def test_torch_dtype():
nlp = spacy.blank("en")
cfg = {
**_PIPE_CFG,
**{"model": {"@llm_models": "spacy.Dolly.v1", "name": "dolly-v2-3b"}},
**{"model": {"@llm_models": "spacy.HuggingFace.v1", "name": "dolly-v2-3b"}},
}

# Should be converted to torch.float16.
4 changes: 2 additions & 2 deletions spacy_llm/tests/models/test_llama2.py
@@ -9,7 +9,7 @@

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Llama2.v1",
"@llm_models": "spacy.HuggingFace.v1",
"name": "Llama-2-7b-hf",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
@@ -32,7 +32,7 @@
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Llama2.v1"
@llm_models = "spacy.HuggingFace.v1"
name = "Llama-2-7b-hf"
"""
