From bb54de250ac0ae101e74a4161a535507b33c5989 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Thu, 13 Jul 2023 09:40:03 +0200 Subject: [PATCH] Add support for Claude-2. --- spacy_llm/models/rest/anthropic/model.py | 4 +++ spacy_llm/models/rest/anthropic/registry.py | 34 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index b7bb5968..f8746376 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -26,6 +26,7 @@ class SystemPrompt(str, Enum): class Anthropic(REST): MODEL_NAMES = { "claude-1": Literal["claude-1", "claude-1-100k"], + "claude-2": Literal["claude-2", "claude-2-100k"], } @property @@ -105,6 +106,9 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: @classmethod def get_model_names(cls) -> Tuple[str, ...]: return ( + # claude-2 + "claude-2", + "claude-2-100k", # claude-1 "claude-1", "claude-1-100k", diff --git a/spacy_llm/models/rest/anthropic/registry.py b/spacy_llm/models/rest/anthropic/registry.py index 812f401d..f19bf410 100644 --- a/spacy_llm/models/rest/anthropic/registry.py +++ b/spacy_llm/models/rest/anthropic/registry.py @@ -7,6 +7,40 @@ from .model import Anthropic, Endpoints +@registry.llm_models("spacy.Claude-2.v1") +def anthropic_claude_2( + config: Dict[Any, Any] = SimpleFrozenDict(), + name: Literal["claude-2", "claude-2-100k"] = "claude-2", # noqa: F722 + strict: bool = Anthropic.DEFAULT_STRICT, + max_tries: int = Anthropic.DEFAULT_MAX_TRIES, + interval: float = Anthropic.DEFAULT_INTERVAL, + max_request_time: float = Anthropic.DEFAULT_MAX_REQUEST_TIME, +) -> Callable[[Iterable[str]], Iterable[str]]: + """Returns Anthropic instance for 'claude-2' model using REST to prompt API. + config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance. + name (Literal["claude-2", "claude-2-100k"]): Model to use. + strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON + or other response object that does not conform to the expectation of how a well-formed response object from + this API should look like). If False, the API error responses are returned by __call__(), but no error will + be raised. + max_tries (int): Max. number of tries for API request. + interval (float): Time interval (in seconds) for API retries in seconds. We implement a base 2 exponential backoff + at each retry. + max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception. + RETURNS (Callable[[Iterable[str]], Iterable[str]]]): Anthropic instance for 'claude-1' model using REST to + prompt API. + """ + return Anthropic( + name=name, + endpoint=Endpoints.COMPLETIONS, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + ) + + @registry.llm_models("spacy.Claude-1.v1") def anthropic_claude_1( config: Dict[Any, Any] = SimpleFrozenDict(),