Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new OS models: Zephyr and Yi #392

Draft
wants to merge 58 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
3aba660
Add context length info. Refactor BuiltinTask and models to facilitat…
rmitsch Oct 17, 2023
5699773
Merge branch 'develop' into feat/inf-doc-len
rmitsch Oct 17, 2023
4213372
Add token count estimator plumbing.
rmitsch Oct 17, 2023
f440ca4
Add plumbing for mapper and reducer.
rmitsch Oct 17, 2023
e47f762
Add ShardMapper prototype.
rmitsch Oct 18, 2023
89a5510
Integrating mapping into prompt generation workflow.
rmitsch Oct 19, 2023
086dec9
Update response parsing and component to support sharding (WIP).
rmitsch Oct 20, 2023
23718fc
Fix shard & prompt flow.
rmitsch Oct 27, 2023
7ce670d
Fix shard & prompt flow.
rmitsch Oct 27, 2023
0d75ea8
Remove todo comments.
rmitsch Oct 27, 2023
9da7098
Fix Anthropic, Cohere, NoOp model tests.
rmitsch Oct 27, 2023
0cb9afd
Merge branch 'develop' into feat/inf-doc-len
rmitsch Oct 30, 2023
f368412
Fix test_llm_pipe().
rmitsch Oct 31, 2023
b1f111d
Fix type checking test.
rmitsch Nov 3, 2023
44a2787
Fix span parsing tests.
rmitsch Nov 3, 2023
6d8cdc7
Fix internal tests.
rmitsch Nov 3, 2023
e712f41
Fix _CountTask.
rmitsch Nov 3, 2023
985fd68
Fix sentiment and summarization tasks and tests.
rmitsch Nov 3, 2023
98842a2
Fix Azure connection URL. Fix Model test pings.
rmitsch Nov 3, 2023
b54a3d9
Fix Lemma parsing.
rmitsch Nov 3, 2023
9bf365d
Start work on doc-to-shard property copying.
rmitsch Nov 3, 2023
dddfaab
Fix REL doc preprocessing.
rmitsch Nov 6, 2023
3af21b5
Remove comment on doc attribute handling during sharding, as this is …
rmitsch Nov 6, 2023
fee9ca7
Add reducer implementations.
rmitsch Nov 8, 2023
e508499
Implement outstanding task reducers.
rmitsch Nov 14, 2023
3218541
Resolve merge conflicts.
rmitsch Nov 14, 2023
c104387
Add shardable/non-shardable LLM task typing distinction. Add support …
rmitsch Nov 20, 2023
2c6d899
Merge branch 'develop' into feat/inf-doc-len
rmitsch Nov 21, 2023
2502c4d
Fix EL task.
rmitsch Nov 23, 2023
03055c5
Fix EL tokenization and highlighting partially.
rmitsch Nov 23, 2023
4e4a2cd
Fix tokenization and whitespaces for EL task.
rmitsch Nov 24, 2023
865acec
Fix merge conflicts.
rmitsch Nov 24, 2023
694d5da
Add new registry handlers (with context length and arbitrary model na…
rmitsch Nov 24, 2023
5295400
Add sharding test with simple count task.
rmitsch Nov 24, 2023
70e3643
Fix sharding algorithm.
rmitsch Nov 24, 2023
4321483
Add test with simple count task.
rmitsch Nov 27, 2023
ef6e738
Add context length as init arg in HF models.
rmitsch Nov 27, 2023
e3ff37d
Fix tests. Don't stringify IO lists if sharded.
rmitsch Nov 28, 2023
056730a
Fix tests.
rmitsch Nov 29, 2023
196c235
Add NER sharding test.
rmitsch Nov 29, 2023
1f51a4a
Add REL and sentiment sharding tests.
rmitsch Nov 29, 2023
e18b302
Add summary sharding tests.
rmitsch Nov 29, 2023
7c092ca
Add EL sharding task. Fix bug in shard mapper.
rmitsch Nov 29, 2023
358ba72
Fix REL error with RELExample parsing.
rmitsch Nov 29, 2023
0c96fb6
Use regex for punctuation in REL conversion.
rmitsch Nov 29, 2023
dc926bd
Maintain custom doc attributes, incl. test.
rmitsch Dec 1, 2023
5585174
Filter merge warnings in textcat reduction.
rmitsch Dec 1, 2023
6d3a4c8
Add Zephyr and Yi classes.
rmitsch Dec 1, 2023
57acfe4
Fix Yi model.
rmitsch Dec 1, 2023
2f1a905
Fix Yi model.
rmitsch Dec 1, 2023
9821063
Fix Yi and Zephyr processing.
rmitsch Dec 4, 2023
98e3e6c
Remove deprecated comment.
rmitsch Dec 4, 2023
513c2fb
Fix merge conflicts.
rmitsch Dec 11, 2023
482af35
Merge branch 'develop' into feat/new-os-models
rmitsch Dec 11, 2023
3747a2f
Change model used for Yi tests.
rmitsch Dec 11, 2023
b2dff8f
Incorporate feedback.
rmitsch Dec 12, 2023
dfe89ee
Skip Yi test failing in CI, but suceeding locally.
rmitsch Dec 13, 2023
69c3c76
Extend readme with links for Zephyr and Yi.
rmitsch Dec 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spacy_llm/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from .mistral import mistral_hf
from .openllama import openllama_hf
from .stablelm import stablelm_hf
from .yi import yi_hf
from .zephyr import zephyr_hf

__all__ = [
"HuggingFace",
Expand All @@ -14,4 +16,6 @@
"mistral_hf",
"openllama_hf",
"stablelm_hf",
"yi_hf",
"zephyr_hf",
]
9 changes: 4 additions & 5 deletions spacy_llm/models/hf/mistral.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable, List, Optional

from confection import SimpleFrozenDict

Expand Down Expand Up @@ -94,13 +94,12 @@ def mistral_hf(
name: Mistral.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Callable[[Iterable[Iterable[str]]], Iterable[Iterable[str]]]:
) -> Mistral:
"""Generates Mistral instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Falcon model. Has to be one of Falcon.get_model_names().
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Callable[[Iterable[str]], Iterable[str]]): Falcon instance that can execute a set of prompts and return
the raw responses.
RETURNS (Mistral): Mistral instance that can execute a set of prompts and return the raw responses.
"""
return Mistral(
name=name, config_init=config_init, config_run=config_run, context_length=8000
Expand Down
128 changes: 128 additions & 0 deletions spacy_llm/models/hf/yi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Yi(HuggingFace):
MODEL_NAMES = Literal[ # noqa: F722
"Yi-34B",
"Yi-34B-chat-8bits",
"Yi-6B-chat",
"Yi-6B",
"Yi-6B-200K",
"Yi-34B-chat",
"Yi-34B-chat-4bits",
"Yi-34B-200K",
]

def __init__(
self,
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: int,
):
self._tokenizer: Optional["transformers.AutoTokenizer"] = None
self._is_instruct = "instruct" in name
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)

# Instantiate GenerationConfig object from config dict.
self._hf_config_run = transformers.GenerationConfig.from_pretrained(
self._name, **self._config_run
)
# To avoid deprecation warning regarding usage of `max_length`.
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length

def init_model(self) -> Any:
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
self._name, use_fast=False
)
init_cfg = self._config_init
device: Optional[str] = None
if "device" in init_cfg:
device = init_cfg.pop("device")

model = transformers.AutoModelForCausalLM.from_pretrained(
self._name, **init_cfg, resume_download=True
).eval()
if device:
model.to(device)

return model

@property
def hf_account(self) -> str:
return "01-ai"

def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
assert hasattr(self._model, "generate")
assert hasattr(self._tokenizer, "apply_chat_template")
assert self._tokenizer

responses: List[List[str]] = []

for prompts_for_doc in prompts:
prompts_for_doc = list(prompts_for_doc)

tokenized_input_ids = [
self._tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": prompt}],
tokenize=True,
add_generation_prompt=True,
return_tensors="pt",
)
for prompt in prompts_for_doc
]
tokenized_input_ids = [
tp.to(self._model.device) for tp in tokenized_input_ids
]

responses.append(
[
self._tokenizer.decode(
self._model.generate(
input_ids=tok_ii, generation_config=self._hf_config_run
)[:, tok_ii.shape[1] :][0],
skip_special_tokens=True,
).strip("\n")
for tok_ii in tokenized_input_ids
]
)

return responses

@staticmethod
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
return {**default_cfg_init, **{"torch_dtype": "auto"}}, default_cfg_run


@registry.llm_models("spacy.Yi.v1")
def yi_hf(
name: Yi.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Yi:
"""Generates Yi instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Mistral model. Has to be one of Mistral.get_model_names().
rmitsch marked this conversation as resolved.
Show resolved Hide resolved
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses.
"""
return Yi(
name=name,
config_init=config_init,
config_run=config_run,
context_length=200000 if "200K" in name else 32000,
)
101 changes: 101 additions & 0 deletions spacy_llm/models/hf/zephyr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Zephyr(HuggingFace):
MODEL_NAMES = Literal["zephyr-7b-beta"] # noqa: F722

def __init__(
self,
name: MODEL_NAMES,
config_init: Optional[Dict[str, Any]],
config_run: Optional[Dict[str, Any]],
context_length: int,
):
super().__init__(
name=name,
config_init=config_init,
config_run=config_run,
context_length=context_length,
)

# Instantiate GenerationConfig object from config dict.
self._hf_config_run = transformers.GenerationConfig.from_pretrained(
self._name, **self._config_run
)
# To avoid deprecation warning regarding usage of `max_length`.
self._hf_config_run.max_new_tokens = self._hf_config_run.max_length

def init_model(self) -> Any:
return transformers.pipeline(
"text-generation",
model=self._name,
return_full_text=False,
**self._config_init
)

@property
def hf_account(self) -> str:
return "HuggingFaceH4"

def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]: # type: ignore[override]
responses: List[List[str]] = []

for prompts_for_doc in prompts:
formatted_prompts_for_doc = [
self._model.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
tokenize=False,
add_generation_prompt=False,
)
for prompt in prompts_for_doc
]

responses.append(
[
self._model(prompt, generation_config=self._hf_config_run)[0][
"generated_text"
]
.replace("<|assistant|>", "")
.strip("\n")
for prompt in formatted_prompts_for_doc
]
)

return responses

@staticmethod
def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
return default_cfg_init, {
**default_cfg_run,
**{
"max_new_tokens": 256,
"do_sample": True,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.95,
},
}


@registry.llm_models("spacy.Zephyr.v1")
def zephyr_hf(
name: Zephyr.MODEL_NAMES,
config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Zephyr:
"""Generates Zephyr instance that can execute a set of prompts and return the raw responses.
name (Literal): Name of the Zephyr model. Has to be one of Zephyr.get_model_names().
config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
config_run (Optional[Dict[str, Any]]): HF config for running the model.
RETURNS (Zephyr): Zephyr instance that can execute a set of prompts and return the raw responses.
"""
return Zephyr(
name=name, config_init=config_init, config_run=config_run, context_length=8000
)
2 changes: 1 addition & 1 deletion spacy_llm/pipeline/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def make_llm(
nlp (Language): Pipeline.
name (str): The component instance name, used to add entries to the
losses during training.
task (Optional[LLMTask]): An LLMTask can generate prompts for given docs, and can parse the LLM's responses into
task (Optional[_LLMTask]): An _LLMTask can generate prompts for given docs, and can parse the LLM's responses into
rmitsch marked this conversation as resolved.
Show resolved Hide resolved
structured information and set that back on the docs.
model (Callable[[Iterable[Any]], Iterable[Any]]]): Callable querying the specified LLM API.
cache (Cache): Cache to use for caching prompts and responses per doc (batch).
Expand Down
8 changes: 7 additions & 1 deletion spacy_llm/tasks/entity_linker/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,10 @@ def reduce_shards_to_doc(task: EntityLinkerTask, shards: Iterable[Doc]) -> Doc:
RETURNS (Doc): Fused doc instance.
"""
# Entities are additive, so we can just merge shards.
return Doc.from_docs(list(shards), ensure_whitespace=True)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message=".*Skipping .* while merging docs.",
)
return Doc.from_docs(list(shards), ensure_whitespace=True)
Comment on lines -209 to +215
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure where this edit is coming from?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a drive-by because I noticed the warnings filter is missing here 🙃 I can move this into a separate PR, if you mind having it in here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's all just a bit confusing with the huge (mostly unrelated) git history etc - I do in general appreciate more "atomic" PRs ;-)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do in general appreciate more "atomic" PRs ;-)

I know 🫣

It's all just a bit confusing with the huge (mostly unrelated) git history etc

Yeah, I don't know why that's the case. The branches should all be updated.

7 changes: 6 additions & 1 deletion spacy_llm/tasks/sentiment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ def reduce_shards_to_doc(task: SentimentTask, shards: Iterable[Doc]) -> Doc:
setattr(
doc._,
task.field,
sum([score * weight for score, weight in zip(sent_scores, weights)]),
sum(
[
(score if score else 0) * weight
for score, weight in zip(sent_scores, weights)
]
),
)

return doc
Expand Down
1 change: 1 addition & 0 deletions spacy_llm/tests/models/test_dolly.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

[components.llm]
factory = "llm"
save_io = True

[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"
Expand Down
68 changes: 68 additions & 0 deletions spacy_llm/tests/models/test_yi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import copy

import pytest
import spacy
from confection import Config # type: ignore[import]
from thinc.compat import has_torch_cuda_gpu

from ...compat import torch

_PIPE_CFG = {
"model": {
"@llm_models": "spacy.Yi.v1",
"name": "Yi-6B-chat",
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
}

_NLP_CONFIG = """

[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128

[components]

[components.llm]
factory = "llm"

[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Yi.v1"
name = "Yi-6B"
"""


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init():
"""Test initialization and simple run."""
nlp = spacy.blank("en")
cfg = copy.deepcopy(_PIPE_CFG)
nlp.add_pipe("llm", config=cfg)
nlp("This is a test.")
torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
orig_config = Config().from_str(_NLP_CONFIG)
nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
assert nlp.pipe_names == ["llm"]
torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_invalid_model():
orig_config = Config().from_str(_NLP_CONFIG)
config = copy.deepcopy(orig_config)
config["components"]["llm"]["model"]["name"] = "x"
with pytest.raises(ValueError, match="unexpected value; permitted"):
spacy.util.load_model_from_config(config, auto_fill=True)
torch.cuda.empty_cache()
Loading