diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py
index a7c03d00f90..b0198e6cadf 100644
--- a/spacy/cli/init_config.py
+++ b/spacy/cli/init_config.py
@@ -1,7 +1,7 @@
 import re
 from enum import Enum
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import srsly
 from jinja2 import Template
@@ -9,9 +9,10 @@
 from wasabi import Printer, diff_strings
 
 from .. import util
+from ..errors import Errors
 from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
-from ..util import SimpleFrozenList
+from ..util import SimpleFrozenList, registry
 from ._util import (
     COMMAND,
     Arg,
@@ -40,6 +41,8 @@ class InitValues:
 
     lang = "en"
     pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
+    llm_task: Optional[str] = None
+    llm_model: Optional[str] = None
     optimize = Optimizations.efficiency
     gpu = False
     pretraining = False
@@ -52,6 +55,8 @@ def init_config_cli(
     output_file: Path = Arg(..., help="File to save the config to or - for stdout (will only output config and no additional logging info)", allow_dash=True),
     lang: str = Opt(InitValues.lang, "--lang", "-l", help="Two-letter code of the language to use"),
     pipeline: str = Opt(",".join(InitValues.pipeline), "--pipeline", "-p", help="Comma-separated names of trainable pipeline components to include (without 'tok2vec' or 'transformer')"),
+    llm_task: str = Opt(InitValues.llm_task, "--llm.task", help="Name of task for LLM pipeline components"),
+    llm_model: str = Opt(InitValues.llm_model, "--llm.model", help="Name of model for LLM pipeline components"),
     optimize: Optimizations = Opt(InitValues.optimize, "--optimize", "-o", help="Whether to optimize for efficiency (faster inference, smaller model, lower memory consumption) or higher accuracy (potentially larger and slower model). This will impact the choice of architecture, pretrained weights and related hyperparameters."),
     gpu: bool = Opt(InitValues.gpu, "--gpu", "-G", help="Whether the model can run on GPU. This will impact the choice of architecture, pretrained weights and related hyperparameters."),
     pretraining: bool = Opt(InitValues.pretraining, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
@@ -77,6 +82,8 @@ def init_config_cli(
     config = init_config(
         lang=lang,
         pipeline=pipeline,
+        llm_model=llm_model,
+        llm_task=llm_task,
         optimize=optimize.value,
         gpu=gpu,
         pretraining=pretraining,
@@ -157,6 +164,8 @@ def init_config(
     *,
     lang: str = InitValues.lang,
     pipeline: List[str] = InitValues.pipeline,
+    llm_model: Optional[str] = InitValues.llm_model,
+    llm_task: Optional[str] = InitValues.llm_task,
     optimize: str = InitValues.optimize,
     gpu: bool = InitValues.gpu,
     pretraining: bool = InitValues.pretraining,
@@ -165,8 +174,57 @@ def init_config(
     msg = Printer(no_print=silent)
     with TEMPLATE_PATH.open("r") as f:
         template = Template(f.read())
+
     # Filter out duplicates since tok2vec and transformer are added by template
     pipeline = [pipe for pipe in pipeline if pipe not in ("tok2vec", "transformer")]
+
+    # Verify LLM arguments are consistent, if at least one `llm` component has been specified.
+    llm_spec: Dict[str, Dict[str, Any]] = {}
+    if "llm" in pipeline:
+        try:
+            import spacy_llm
+        except ImportError as ex:
+            raise ValueError(Errors.E1055) from ex
+
+        if llm_model is None:
+            msg.fail(
+                "Option `--llm.model` must be set if `llm` component is in pipeline.",
+                exits=1,
+            )
+        if llm_task is None:
+            msg.fail(
+                "Option `--llm.task` must be set if `llm` component is in pipeline.",
+                exits=1,
+            )
+
+        # Select registry handles for model(s) and task(s). Raise if no match found.
+        llm_spec = {
+            spec_type: {
+                "arg": llm_model if spec_type == "model" else llm_task,
+                "matched_reg_handle": None,
+                "reg_handles": getattr(registry, f"llm_{spec_type}s").get_all(),
+            }
+            for spec_type in ("model", "task")
+        }
+
+        for spec_type, spec in llm_spec.items():
+            valid_values = set()
+            user_value = spec["arg"].lower().replace(".", "-")
+            for reg_handle in spec["reg_handles"]:
+                reg_name = reg_handle.split(".")[1]
+                valid_values.add(reg_name)
+                if reg_name.lower() == user_value:
+                    spec["matched_reg_handle"] = reg_handle
+                    break
+
+            if not spec["matched_reg_handle"]:
+                arg = spec["arg"]
+                msg.fail(
+                    f"Couldn't find a matching registration handle for {spec_type} '{arg}'. "
+                    f"Valid options are: {valid_values}",
+                    exits=1,
+                )
+
     defaults = RECOMMENDATIONS["__default__"]
     reco = RecommendationSchema(**RECOMMENDATIONS.get(lang, defaults)).dict()
     variables = {
@@ -175,6 +233,7 @@ def init_config(
         "optimize": optimize,
         "hardware": "gpu" if gpu else "cpu",
         "transformer_data": reco["transformer"],
+        "llm_spec": {key: llm_spec[key]["matched_reg_handle"] for key in llm_spec},
         "word_vectors": reco["word_vectors"],
         "has_letters": reco["has_letters"],
     }
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index 1937ea93533..304a8354a07 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -3,7 +3,7 @@ the docs and the init config command. It encodes various best practices and
 can help generate the best possible configuration, given a user's requirements. #}
 {%- set use_transformer = hardware != "cpu" and transformer_data -%}
 {%- set transformer = transformer_data[optimize] if use_transformer else {} -%}
-{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "span_finder", "spancat", "spancat_singlelabel", "trainable_lemmatizer"] -%}
+{%- set listener_components = ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker", "span_finder", "spancat", "spancat_singlelabel", "trainable_lemmatizer", "llm"] -%}
 [paths]
 train = null
 dev = null
@@ -328,6 +328,18 @@ grad_factor = 1.0
 {%- endif %}
 {%- endif %}
 
+{% if "llm" in components -%}
+[components.llm]
+factory = "llm"
+
+[components.llm.model]
+@llm_models = "{{ llm_spec['model'] }}"
+
+[components.llm.task]
+@llm_tasks = "{{ llm_spec['task'] }}"
+{% endif -%}
+
+
 {# NON-TRANSFORMER PIPELINE #}
 {% else -%}
 {% if "tok2vec" in full_pipeline -%}
@@ -585,6 +597,17 @@ no_output_layer = false
 {%- endif %}
 {% endif %}
 
+{% if "llm" in components -%}
+[components.llm]
+factory = "llm"
+
+[components.llm.model]
+@llm_models = "{{ llm_spec['model'] }}"
+
+[components.llm.task]
+@llm_tasks = "{{ llm_spec['task'] }}"
+{% endif -%}
+
 {% for pipe in components %}
 {% if pipe not in listener_components %}
 {# Other components defined by the user: we just assume they're factories #}
diff --git a/spacy/errors.py b/spacy/errors.py
index db1a886aa8f..b35dc1aa2c8 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -981,6 +981,8 @@ class Errors(metaclass=ErrorsWithCodes):
             " 'min_length': {min_length}, 'max_length': {max_length}")
     E1054 = ("The text, including whitespace, must match between reference and "
             "predicted docs when training {component}.")
+    E1055 = ("To use the `llm` component, `spacy-llm` needs to be installed. `spacy-llm` was not found in your "
+            "environment; install it with `pip install spacy-llm`.")
 
 
 # Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py
index 8e1c9ca3215..fdaaf5359ba 100644
--- a/spacy/tests/test_cli.py
+++ b/spacy/tests/test_cli.py
@@ -628,7 +628,6 @@ def test_parse_cli_overrides():
     "pipeline",
     [
         ["tagger", "parser", "ner"],
-        [],
         ["ner", "textcat", "sentencizer"],
         ["morphologizer", "spancat", "entity_linker"],
         ["spancat_singlelabel", "textcat_multilabel"],
@@ -651,6 +650,26 @@ def test_init_config(lang, pipeline, optimize, pretraining):
     load_model_from_config(config, auto_fill=True)
 
 
+@pytest.mark.parametrize("pipeline", [["llm"]])
+@pytest.mark.parametrize("llm_model", ["noop"])
+@pytest.mark.parametrize("llm_task", ["ner", "sentiment"])
+def test_init_config_llm(pipeline, llm_model, llm_task):
+    config = init_config(
+        lang="en",
+        pipeline=pipeline,
+        llm_model=llm_model,
+        llm_task=llm_task,
+        optimize="accuracy",
+        pretraining=False,
+        gpu=False,
+    )
+    assert isinstance(config, Config)
+    assert len(config["components"]) == 1
+    assert "llm" in config["components"]
+
+    load_model_from_config(config, auto_fill=True)
+
+
 def test_model_recommendations():
     for lang, data in RECOMMENDATIONS.items():
         assert RecommendationSchema(**data)
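
Example usage: a minimal sketch of the new `llm_model`/`llm_task` arguments, assuming this patch is applied, that `spacy-llm` is installed, and that its registries expose handles whose middle segment matches "noop" and "ner" case-insensitively (e.g. `spacy.NoOp.v1` and `spacy.NER.v2`, as exercised by `test_init_config_llm` above):

```python
# Rough Python equivalent of the new CLI flags, e.g.:
#   python -m spacy init config config.cfg --pipeline llm --llm.task ner --llm.model noop
from spacy.cli.init_config import init_config
from spacy.util import load_model_from_config

config = init_config(
    lang="en",
    pipeline=["llm"],
    llm_model="noop",  # matched case-insensitively against handles such as "spacy.NoOp.v1"
    llm_task="ner",    # matched against task handles such as "spacy.NER.v2"
    optimize="efficiency",
    gpu=False,
    pretraining=False,
)
nlp = load_model_from_config(config, auto_fill=True)
print(nlp.pipe_names)  # ['llm']
```

The matched registry handles are passed to the quickstart template as `llm_spec`, which fills in `@llm_models` and `@llm_tasks` under `[components.llm.model]` and `[components.llm.task]` in the generated config.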