
Merge branch '3.3_super'
imoneoi committed Sep 4, 2023
2 parents 83a683c + f9995eb commit 2262adb
Showing 12 changed files with 364 additions and 640 deletions.
121 changes: 34 additions & 87 deletions README.md

Large diffs are not rendered by default.

127 changes: 25 additions & 102 deletions ochat/config/model_config.py
@@ -19,10 +19,6 @@ class ModelConfig:

condition_fn: Optional[Callable] = None

# Label
group_fn: Optional[Callable] = None
num_groups: int = 1

# Model
model_max_context: Optional[int] = None
model_create: Optional[Callable] = None
@@ -32,24 +28,31 @@ class ModelConfig:
def generate_conversation_template(self, tokenize_fn, tokenize_special_fn, system_prompt, message_list, message_props=None):
tokens = []
masks = []
weights = []

# begin of sentence (bos)
if self.bos_token:
t = tokenize_special_fn(self.bos_token)
tokens.append(t)
masks.append(False)

tokens.extend([t])
masks.extend([False])
weights.extend([0.])

# Condition
if self.condition_fn is not None:
t = tokenize_fn(self.condition_fn(message_props)) + [tokenize_special_fn(self.eot_token)]

tokens.extend(t)
masks.extend([False] * len(t))
weights.extend([0.] * len(t))

# System
if system_prompt:
t = tokenize_fn(system_prompt) + [tokenize_special_fn(self.eot_token)]

tokens.extend(t)
masks.extend([False] * len(t))
weights.extend([0.] * len(t))

# Messages
for idx, message in enumerate(message_list):
@@ -62,20 +65,26 @@ def generate_conversation_template(self, tokenize_fn, tokenize_special_fn, syste
t = tokenize_fn(role_prefix)
tokens.extend(t)
masks.extend([False] * len(t))
weights.extend([0.] * len(t))

# Message
if "value" in message:
t = tokenize_fn(message["value"]) + [tokenize_special_fn(self.eot_token)]

# determine weights
use_loss = (message["from"] == self.ai_role) and bool(message.get("use_loss", True))
w = 1.0 if use_loss else 0.0

if message_props is not None and ("weight" in message_props):
w *= message_props["weight"]

tokens.extend(t)
masks.extend([message["from"] == self.ai_role] * len(t))
masks.extend([use_loss] * len(t))
weights.extend([w] * len(t))
else:
assert idx == len(message_list) - 1, "Empty message for completion must be on the last."

group = 0
if self.group_fn:
group = self.group_fn(message_props)

return tokens, masks, group
return tokens, masks, weights


def _v2_conditional_prefix(from_role, props):
@@ -109,13 +118,6 @@ def _v3_2_conditional_prefix(from_role, props):
return prefixes[from_role]


def _v2_v3_group(props):
if props is None:
return 1

return 1 if props["is_gpt4"] else 0


def _v3_condition(props):
gpt4_condition = "Assistant is GPT4"
gpt3_condition = "Assistant is GPT3"
@@ -137,18 +139,14 @@ def _v3_condition(props):
eot_token="<|end_of_turn|>",
bos_token="<s>",

# Label
group_fn=_v2_v3_group,
num_groups=2,

# Tokenize
model_max_context=4096,
model_create=partial(ochat.models.LlamaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
legacy=True),
),

"openchat_v3.1_llama2": ModelConfig(
@@ -165,18 +163,14 @@ def _v3_condition(props):

condition_fn=_v3_condition,

# Label
group_fn=_v2_v3_group,
num_groups=2,

# Tokenize
model_max_context=4096,
model_create=partial(ochat.models.LlamaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
legacy=True),
),

# OpenChat V2
@@ -189,18 +183,14 @@ def _v3_condition(props):
eot_token="<|end_of_turn|>",
bos_token="<s>",

# Label
group_fn=_v2_v3_group,
num_groups=2,

# Tokenize
model_max_context=2048,
model_create=partial(ochat.models.LlamaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
legacy=True),
),

# OpenChat
@@ -223,73 +213,6 @@ def _v3_condition(props):
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
),

"openchat": ModelConfig(
name="OpenChat",

# Prompt
role_prefix={
"human": "Human: ",
"gpt": "Assistant: "
},
ai_role="gpt",
eot_token="<|end_of_turn|>",
bos_token="<s>",

# Tokenize
model_max_context=2048,
model_create=partial(ochat.models.LlamaForCausalLM.from_pretrained,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
),

# OpenChat 8192
"openchat_8192": ModelConfig(
name="OpenChat_8192",

# Prompt
role_prefix={
"human": "Human: ",
"gpt": "Assistant: "
},
ai_role="gpt",
eot_token="<|end_of_turn|>",
bos_token="<s>",

# Model
model_max_context=8192,
model_create=partial(ochat.models.LlamaForCausalLM.from_pretrained,
extend_context_to=8192,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16),
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True),
),

# OpenCoder / OpenCoderPlus
"opencoder": ModelConfig(
name="OpenCoder",

# Prompt
role_prefix={
"human": "User:",
"gpt": "Assistant:"
},
ai_role="gpt",
eot_token="<|end_of_turn|>",
bos_token=None,

# Tokenize
model_max_context=8192,
model_create=None, # TODO: StarCoder Unpadded
model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained,
use_fast=False,
use_auth_token=True)
legacy=True),
)
}
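The configuration change above drops the per-conversation group labels in favor of per-token weights returned alongside tokens and masks. As a rough sketch of how such weights could be consumed in a standard PyTorch training step (the function below is hypothetical and not part of this commit):

import torch
import torch.nn.functional as F

def weighted_lm_loss(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    # logits: [seq_len, vocab_size]; labels, weights: [seq_len]
    # Tokens emitted with weight 0.0 (bos, condition, system, human turns) contribute nothing.
    per_token = F.cross_entropy(logits, labels, reduction="none")  # per-token cross entropy
    denom = weights.sum().clamp(min=1.0)                           # guard against all-zero weights
    return (per_token * weights).sum() / denom

Assistant-turn tokens keep weight 1.0 (optionally scaled by message_props["weight"]), so only model responses contribute to the loss.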
29 changes: 21 additions & 8 deletions ochat/data/clean_sharegpt.py
@@ -43,7 +43,7 @@ def sample_load(filename):
############### [2] HTML cleaning


blocked_words = ["openai"]
# blocked_words = ["openai"]

div_pattern = re.compile("<div.*?>")
span_pattern = re.compile("<span.*?>")
@@ -91,11 +91,15 @@ def html_to_markdown(val: str) -> str:
return val


def contain_blocked_words(val: str) -> bool:
for w in blocked_words:
if w in val.lower():
return True
return False
# def contain_blocked_words(val: str) -> bool:
# for w in blocked_words:
# if w in val.lower():
# return True
# return False


def remove_whitespace_and_non_printable(s: str) -> str:
return "".join([c for c in s if c.isprintable() and not c.isspace()])


def sample_clean_html(sample):
@@ -115,20 +119,29 @@ def sample_clean_html(sample):
if len(sample["items"]) <= 1:
raise DataPipelineError("Conversation too short")

char_count = 0
for i, c in enumerate(sample["items"]):
if c["from"] != roles[i % 2]:
raise DataPipelineError("Wrong format")

if contain_blocked_words(c["value"]):
raise DataPipelineError("Contain blocked words")
# if contain_blocked_words(c["value"]):
# raise DataPipelineError("Contain blocked words")

try:
new_val = html_to_markdown(c["value"])
except (bs4.builder.ParserRejectedMarkup, AssertionError):
raise DataPipelineError("Parser error")

# Filter empty answers like https://sharegpt.com/c/mrllZ6u
if not len(remove_whitespace_and_non_printable(new_val)):
raise DataPipelineError("Empty answer")

char_count += len(new_val)
c["value"] = new_val

if char_count < 16:
raise DataPipelineError("Conversation too short")

return sample
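As an illustrative aside (assumed behaviour, simply mirroring the helper added above), the new emptiness check strips whitespace and non-printable characters before testing length, so answers made only of spaces, newlines, or zero-width characters are rejected:

# Standalone copy of the helper from the diff above, for illustration only.
def remove_whitespace_and_non_printable(s: str) -> str:
    return "".join(c for c in s if c.isprintable() and not c.isspace())

print(repr(remove_whitespace_and_non_printable(" \n\t\u200b")))  # '' -> sample_clean_html would raise "Empty answer"
print(repr(remove_whitespace_and_non_printable("Hi there!")))    # 'Hithere!' -> passes the emptiness check

Conversations whose cleaned turns total fewer than 16 characters are additionally rejected by the new char_count check.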


19 changes: 13 additions & 6 deletions ochat/data/filter_language.py
@@ -21,7 +21,7 @@ def _split(a, n):


@ray.remote
def filter_conversation_batch(fasttext_model: str, lang_list: list, batch: list):
def filter_conversation_batch(fasttext_model: str, keep_lang: list, skip_lang: list, batch: list):
model = fasttext.load_model(fasttext_model)

result = []
@@ -39,13 +39,17 @@ def filter_conversation_batch(fasttext_model: str, lang_list: list, batch: list)
lang_freq.setdefault(lang, 0)
lang_freq[lang] += 1

if lang in lang_list:
result.append(conversation)
if lang in skip_lang:
continue
if keep_lang and (lang not in keep_lang):
continue

result.append(conversation)

return result, lang_freq


def filter_lang(fasttext_model: str, lang_list: str, in_file: str, out_file: str, num_cpus: int = os.cpu_count()):
def filter_lang(fasttext_model: str, keep_lang: list, skip_lang: list, in_file: str, out_file: str, num_cpus: int = os.cpu_count()):
# load conversations
with open(in_file, "r") as f:
conversations = json.load(f)
@@ -55,7 +59,8 @@ def filter_lang(fasttext_model: str, lang_list: str, in_file: str, out_file: str

handles = [filter_conversation_batch.remote(
fasttext_model=fasttext_model,
lang_list=lang_list,
keep_lang=keep_lang,
skip_lang=skip_lang,
batch=batch
) for batch in _split(conversations, num_cpus)]

@@ -74,6 +79,7 @@ def filter_lang(fasttext_model: str, lang_list: str, in_file: str, out_file: str
json.dump(results, f)

# show statistics
print(f"Total {len(conversations)} Keep {len(results)}")
pprint(total_lang_freq)

ray.shutdown()
@@ -82,7 +88,8 @@ def filter_lang(fasttext_model: str, lang_list: str, in_file: str, out_file: str
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--fasttext-model", type=str, required=True)
parser.add_argument("--lang-list", type=str, nargs='+', default=["en"])
parser.add_argument("--keep-lang", type=str, nargs='+', default=["en"])
parser.add_argument("--skip-lang", type=str, nargs='+', default=[])

parser.add_argument("--in-file", type=str, required=True)
parser.add_argument("--out-file", type=str, required=True)