From 9acd18531ca9371f8e4156100c2f95acb71f381f Mon Sep 17 00:00:00 2001
From: Jayant
Date: Fri, 1 Mar 2024 16:16:28 +0000
Subject: [PATCH 1/3] enabling mistral training by adding mistral template in
 ./llava/conversation.py

---
 llava/conversation.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/llava/conversation.py b/llava/conversation.py
index 00c56867d..8d5a8f817 100644
--- a/llava/conversation.py
+++ b/llava/conversation.py
@@ -357,7 +357,15 @@ def dict(self):
     sep="",
     sep2="",
 )
-
+conv_biomistral = Conversation(
+    system=""" [INST] system Answer the questions.""",
+    roles=("[INST]user\n", "[/INST]assistant\n"),
+    version="mistral",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.biomistral,
+    sep="",
+)
 conv_chatml_direct = Conversation(
     system="""<|im_start|>system
 Answer the questions.""",
@@ -379,7 +387,7 @@ def dict(self):
     "mistral_instruct": conv_mistral_instruct,
     "chatml_direct": conv_chatml_direct,
     "mistral_direct": conv_chatml_direct,
-
+    "mistral_bio":conv_biomistral,
     "plain": conv_llava_plain,
     "v0_plain": conv_llava_plain,
     "llava_v0": conv_llava_v0,

From 4e06e9473bee3c642f3ea16aaa730cbcbe20ed06 Mon Sep 17 00:00:00 2001
From: Jayant
Date: Fri, 1 Mar 2024 16:36:20 +0000
Subject: [PATCH 2/3] enabling mistral training by adding mistral template in
 ./llava/conversation.py

---
 llava/conversation.py | 2 +-
 llava/train/train.py  | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/llava/conversation.py b/llava/conversation.py
index 8d5a8f817..0eb0836b6 100644
--- a/llava/conversation.py
+++ b/llava/conversation.py
@@ -13,7 +13,7 @@ class SeparatorStyle(Enum):
     MPT = auto()
     PLAIN = auto()
     LLAMA_2 = auto()
-
+    biomistral =auto()
 
 @dataclasses.dataclass
 class Conversation:
diff --git a/llava/train/train.py b/llava/train/train.py
index 477c668b6..95fbaaa44 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -793,7 +793,7 @@ def train(attn_implementation=None):
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     local_rank = training_args.local_rank
     compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
-
+    model_args.version='mistral_bio'
     bnb_model_from_pretrained_args = {}
     if training_args.bits in [4, 8]:
         from transformers import BitsAndBytesConfig
@@ -890,7 +890,6 @@ def make_inputs_require_grad(module, input, output):
             padding_side="right",
             use_fast=False,
         )
-
     if model_args.version == "v0":
         if tokenizer.pad_token is None:
             smart_tokenizer_and_embedding_resize(
@@ -942,7 +941,7 @@ def make_inputs_require_grad(module, input, output):
         training_args.use_im_start_end = model_args.mm_use_im_start_end
         model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
         model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
-
+    
     if training_args.bits in [4, 8]:
         from peft.tuners.lora import LoraLayer
         for name, module in model.named_modules():

From 27f4243de9943d26c161088ebe763e65e4b07fe2 Mon Sep 17 00:00:00 2001
From: Jayant
Date: Fri, 1 Mar 2024 16:37:59 +0000
Subject: [PATCH 3/3] remove print

---
 llava/train/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/llava/train/train.py b/llava/train/train.py
index 95fbaaa44..477c668b6 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -793,7 +793,7 @@ def train(attn_implementation=None):
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
     local_rank = training_args.local_rank
     compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
-    model_args.version='mistral_bio'
+
     bnb_model_from_pretrained_args = {}
     if training_args.bits in [4, 8]:
         from transformers import BitsAndBytesConfig
@@ -890,6 +890,7 @@ def make_inputs_require_grad(module, input, output):
             padding_side="right",
             use_fast=False,
         )
+
     if model_args.version == "v0":
         if tokenizer.pad_token is None:
             smart_tokenizer_and_embedding_resize(
@@ -941,7 +942,7 @@ def make_inputs_require_grad(module, input, output):
         training_args.use_im_start_end = model_args.mm_use_im_start_end
         model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
         model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
-    
+
     if training_args.bits in [4, 8]:
         from peft.tuners.lora import LoraLayer
         for name, module in model.named_modules():
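A minimal usage sketch of the template these patches register, assuming the upstream llava/conversation.py API (conv_templates, Conversation.copy, append_message); the question string below is invented for illustration. With the hard-coded override from patch 2 reverted in patch 3, the template would presumably be selected at training time by passing --version mistral_bio, which upstream train.py resolves through conversation_lib.conv_templates.

# Minimal sketch, not part of the patches above: exercise the newly
# registered "mistral_bio" template. Assumes the upstream Conversation
# API; the question string is made up for illustration.
from llava import conversation as conversation_lib

conv = conversation_lib.conv_templates["mistral_bio"].copy()
conv.append_message(conv.roles[0], "What abnormality is visible in this scan?")
conv.append_message(conv.roles[1], None)
print(conv.roles)     # -> ('[INST]user\n', '[/INST]assistant\n')
print(conv.messages)  # accumulated [role, message] turns
# conv.get_prompt() is intentionally not called: these patches register
# SeparatorStyle.biomistral but do not add a prompt-rendering branch for it.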