Commit

Merge pull request #56 from imoneoi/3.5_mistral

3.5 mistral

imoneoi authored Nov 2, 2023
2 parents 0323d57 + a76cada commit f0d862e
Showing 90 changed files with 16,211 additions and 1,211 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -6,12 +6,15 @@ wandb/

# Old
old/
temp/
profiler/

# Logs
logs/

# eval
eval_results/
evalplus_codegen/

# All datasets
dataset/
312 changes: 199 additions & 113 deletions README.md

Large diffs are not rendered by default.

Binary file added assets/openchat.png
60 changes: 60 additions & 0 deletions ochat/config/__init__.py
@@ -0,0 +1,60 @@
from functools import partial

import torch
import transformers

from ochat.config.model_config import ModelConfig
from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
import ochat.models


_V3_2_PREFIXES = {
# OpenAI mapping

"user": "User:",
"assistant": "Assistant:"
}


def _v3_2_role_prefix(from_role, condition):
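    # e.g. ("user", "GPT4 Correct") -> "GPT4 Correct User:"; an empty condition yields just "User:".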
return f"{condition} {_V3_2_PREFIXES[from_role]}".strip()


MODEL_CONFIG_MAP = {
# OpenChat V3.2
"openchat_v3.2": ModelConfig(
# Model
model_max_context=4096,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
legacy=False),
model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=_v3_2_role_prefix,
eot="<|end_of_turn|>",
inference_condition="GPT4")
),

"openchat_v3.2_mistral": ModelConfig(
serving_aliases=("openchat_3.5", ),

# Model
model_max_context=8192,
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
                                       legacy=True),  # Mistral uses legacy=True: https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/tokenizer_config.json
model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),

# Conversation Template
conversation_template=partial(ConversationTemplate,
role_prefix=_v3_2_role_prefix,
eot="<|end_of_turn|>",
inference_condition="GPT4 Correct")
),
}
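
Not part of the diff: a minimal usage sketch of MODEL_CONFIG_MAP, assuming ModelConfig simply exposes the keyword arguments above as attributes. The Mistral checkpoint path comes from the tokenizer comment above and is used purely for illustration.

from ochat.config import MODEL_CONFIG_MAP

# Look up the OpenChat 3.5 (Mistral-based) config; "openchat_3.5" is listed as a serving alias.
config = MODEL_CONFIG_MAP["openchat_v3.2_mistral"]

# The partials above only pin keyword arguments, so the caller still supplies the
# pretrained path when creating the tokenizer and the model.
tokenizer = config.model_tokenizer_create("mistralai/Mistral-7B-v0.1")
model = config.model_create_for_training("mistralai/Mistral-7B-v0.1")

# Binding the tokenizer completes the conversation template for this model.
template = config.conversation_template(tokenizer=tokenizer)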
122 changes: 122 additions & 0 deletions ochat/config/conversation_template.py
@@ -0,0 +1,122 @@
from typing import Optional, Callable, Iterable, List, Dict

from pydantic import BaseModel


class Message(BaseModel):
role: str
content: str

weight: Optional[float] = None


class Conversation(BaseModel):
items: List[Message]

condition: str = ""
system: str = ""


class ConversationTemplate(BaseModel):
tokenizer: Callable

# Prompt
role_prefix: Callable
eot: str

inference_condition: Optional[str] = None

# Private
bos_tokens_: List[int]
eot_tokens_: List[int]

def __init__(self, **data):
tokenizer = data["tokenizer"]
eot = data["eot"]
bos_tokens_ = tokenizer("").input_ids
eot_tokens_ = tokenizer(eot, add_special_tokens=False).input_ids

super().__init__(**data, bos_tokens_=bos_tokens_, eot_tokens_=eot_tokens_)

def _safe_tokenize(self, strings: Iterable[str]) -> List[List[int]]:
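        # Tokenize raw strings without adding BOS/EOS; split_special_tokens=True makes
        # special-token text inside message content tokenize as plain text, not control tokens.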
return self.tokenizer(strings, split_special_tokens=True, return_attention_mask=False, add_special_tokens=False).input_ids

def tokenize_conversations(self, conversations: Iterable[Conversation], inference: bool = False, seq_level_weight: bool = False):
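        # Returns (tokens, weights) for each conversation: the flattened token sequence and a
        # parallel per-token weight (0 for BOS, system and role-prefix tokens; the message's own
        # weight for its content and trailing EOT). seq_level_weight divides each message's
        # weight by its token count so every turn contributes equally regardless of length.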
# Pre-tokenize all conversations
default_condition = self.inference_condition if inference else ""

sys_mappings = set()
role_mappings = set()
all_text = []
for conv in conversations:
sys_mappings.add(conv.system)
for msg in conv.items:
role_mappings.add((msg.role, conv.condition or default_condition))
all_text.append(msg.content)

sys_mappings = list(sys_mappings)
role_mappings = list(role_mappings)

# Tokenize
sys_mappings = dict(zip(sys_mappings, self._safe_tokenize(sys_mappings)))
role_mappings = dict(zip(role_mappings, self._safe_tokenize([self.role_prefix(*args) for args in role_mappings])))
all_text = self._safe_tokenize(all_text)

# Convert
result_tokens = []
result_weights = []
all_text_idx = 0
for conv in conversations:
tokens = []
weights = []

# bos tokens
tokens.extend(self.bos_tokens_)
weights.extend([0.] * len(self.bos_tokens_))

# System
if conv.system:
system = sys_mappings[conv.system]
tokens.extend(system)
weights.extend([0.] * len(system))

tokens.extend(self.eot_tokens_)
weights.extend([0.] * len(self.eot_tokens_))

# Messages
last_idx = len(conv.items) - 1
for idx, msg in enumerate(conv.items):
# Prefix
role = role_mappings[(msg.role, conv.condition or default_condition)]
tokens.extend(role)
weights.extend([0.] * len(role))

# Message
text = all_text[all_text_idx]
all_text_idx += 1

# weight
w = None
if not inference:
assert msg.weight is not None

w = msg.weight
if seq_level_weight:
w /= len(text) + len(self.eot_tokens_)

# Message tokens
tokens.extend(text)
weights.extend([w] * len(text))

if not (inference and idx == last_idx): # Do not add EOT on last turn during inference
tokens.extend(self.eot_tokens_)
weights.extend([w] * len(self.eot_tokens_))

# Append result
result_tokens.append(tokens)
result_weights.append(weights)

# Sanity check
assert all_text_idx == len(all_text)

return result_tokens, result_weights
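
Not part of the diff: a minimal sketch of tokenize_conversations under the openchat_v3.2_mistral config, with the same illustrative checkpoint path as above. The trailing empty assistant message in the inference case is just one way to end the prompt at the assistant prefix.

from ochat.config import MODEL_CONFIG_MAP
from ochat.config.conversation_template import Message, Conversation

config = MODEL_CONFIG_MAP["openchat_v3.2_mistral"]
tokenizer = config.model_tokenizer_create("mistralai/Mistral-7B-v0.1")
template = config.conversation_template(tokenizer=tokenizer)

# Training-style call: every message must carry a weight; <|end_of_turn|> is appended
# after each message and its tokens inherit that message's weight.
conv = Conversation(
    condition="GPT4 Correct",  # matches the config's inference_condition
    items=[
        Message(role="user", content="How are you?", weight=0.0),
        Message(role="assistant", content="I'm doing well, thanks!", weight=1.0),
    ],
)
tokens, weights = template.tokenize_conversations([conv])

# Inference-style call: weights may be omitted and no EOT follows the last turn, so a
# trailing empty assistant message leaves the prompt ending with "GPT4 Correct Assistant:".
prompt = Conversation(items=[
    Message(role="user", content="How are you?"),
    Message(role="assistant", content=""),
])
prompt_tokens, _ = template.tokenize_conversations([prompt], inference=True)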