diff --git a/llava/conversation.py b/llava/conversation.py
index 00c56867d..d60935389 100644
--- a/llava/conversation.py
+++ b/llava/conversation.py
@@ -369,6 +369,17 @@ def dict(self):
     sep="<|im_end|>",
 )
 
+conv_chatml_direct_ft = Conversation(
+    system="""<|im_start|>system\nAnswer the questions.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+
 default_conversation = conv_vicuna_v1
 conv_templates = {
     "default": conv_vicuna_v0,
@@ -378,6 +389,7 @@ def dict(self):
     "llama_2": conv_llama_2,
     "mistral_instruct": conv_mistral_instruct,
     "chatml_direct": conv_chatml_direct,
+    "chatml_direct_ft": conv_chatml_direct_ft,
     "mistral_direct": conv_chatml_direct,
 
     "plain": conv_llava_plain,
diff --git a/llava/mm_utils.py b/llava/mm_utils.py
index de97345cf..b828ae26e 100644
--- a/llava/mm_utils.py
+++ b/llava/mm_utils.py
@@ -182,6 +182,11 @@ def process_images(images, image_processor, model_cfg):
     return new_images
 
 
+def train_process_images(images, image_processor, model_cfg):
+    new_image = process_anyres_image(images, image_processor, model_cfg.image_grid_pinpoints)
+    return new_image
+
+
 def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
 
diff --git a/llava/model/builder.py b/llava/model/builder.py
index e3d50829f..7d9452214 100644
--- a/llava/model/builder.py
+++ b/llava/model/builder.py
@@ -50,11 +50,18 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
         if 'lora' in model_name.lower() and model_base is None:
             warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
         if 'lora' in model_name.lower() and model_base is not None:
-            from llava.model.language_model.llava_llama import LlavaConfig
-            lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
-            tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+            if 'mistral' in model_name.lower():
+                from llava.model.language_model.llava_mistral import LlavaMistralConfig
+                lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
+
+            else:
+                from llava.model.language_model.llava_llama import LlavaConfig
+                lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
+                tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
+                model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             print('Loading LLaVA from base model...')
-            model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
             token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
             if model.lm_head.weight.shape[0] != token_num:
                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
@@ -93,6 +100,10 @@ def load_from_hf(repo_id, filename, subfolder=None):
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
                 model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+            elif 'mistral' in model_name.lower():
+                tokenizer = AutoTokenizer.from_pretrained(model_base)
+                cfg_pretrained = AutoConfig.from_pretrained(model_path)
+                model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
             else:
                 tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
                 cfg_pretrained = AutoConfig.from_pretrained(model_path)
diff --git a/llava/train/train.py b/llava/train/train.py
index 477c668b6..9a88c5f47 100644
--- a/llava/train/train.py
+++ b/llava/train/train.py
@@ -33,7 +33,7 @@
 
 from llava import conversation as conversation_lib
 from llava.model import *
-from llava.mm_utils import tokenizer_image_token
+from llava.mm_utils import tokenizer_image_token, train_process_images
 
 from PIL import Image
 
@@ -73,7 +73,7 @@ class DataArguments:
     lazy_preprocess: bool = False
     is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
-    image_aspect_ratio: str = 'square'
+    image_aspect_ratio: str = 'anyres'
 
 
 @dataclass
@@ -497,6 +497,47 @@ def preprocess_v1(
     )
 
 
+def debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image):
+    total_len = int(target.ne(tokenizer.pad_token_id).sum())
+    calculated_len = 0
+
+    rounds = conversation.split(conv.sep)
+    print("Tokenized Conversation:")
+    tokenized_conversation = []
+    for rou in rounds:
+        if has_image:
+            tokenized_rou = tokenizer_image_token(rou, tokenizer)
+        else:
+            tokenized_rou = tokenizer.encode(rou, add_special_tokens=False)
+        print(tokenized_rou)
+        tokenized_conversation.extend(tokenized_rou)
+        calculated_len += len(tokenized_rou)
+
+    print("\nTokenized Target:")
+    tokenized_target = target[target != IGNORE_INDEX].tolist()
+    print(tokenized_target)
+
+    print("\nMissing Tokens:")
+    missing_tokens = []
+    conv_idx = 0
+    for i, token in enumerate(tokenized_target):
+        if conv_idx >= len(tokenized_conversation) or token != tokenized_conversation[conv_idx]:
+            missing_tokens.append((i, token))
+        else:
+            conv_idx += 1
+
+    if missing_tokens:
+        for idx, token in missing_tokens:
+            print(f"Position: {idx}, Token: {token} ({tokenizer.decode([token])})")
+    else:
+        print("No missing tokens found.")
+
+    if calculated_len != total_len:
+        print(f"\nLength mismatch detected. Calculated: {calculated_len}, Actual: {total_len}")
+    else:
+        print(f"\nLengths match. Length: {calculated_len}")
+
+
 def preprocess_mpt(
     sources,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -505,11 +546,9 @@ def preprocess_mpt(
     conv = conversation_lib.default_conversation.copy()
     roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
 
-    # Apply prompt templates
     conversations = []
     for i, source in enumerate(sources):
         if roles[source[0]["from"]] != conv.roles[0]:
-            # Skip the first one if it is not from human
             source = source[1:]
 
         conv.messages = []
@@ -518,11 +557,10 @@ def preprocess_mpt(
             assert role == conv.roles[j % 2], f"{i}"
             conv.append_message(role, sentence["value"])
         conversations.append(conv.get_prompt())
-
-    # Tokenize conversations
+    #print(conv.get_prompt())
 
     if has_image:
-        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+        input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
     else:
         input_ids = tokenizer(
             conversations,
@@ -531,14 +569,15 @@ def preprocess_mpt(
             max_length=tokenizer.model_max_length,
             truncation=True,
         ).input_ids
-
+
     targets = input_ids.clone()
     assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
 
-    # Mask targets
     sep = conv.sep + conv.roles[1]
     for conversation, target in zip(conversations, targets):
         total_len = int(target.ne(tokenizer.pad_token_id).sum())
+        #print("target: ", target)
+        #print("conversation: ", conversation)
 
         rounds = conversation.split(conv.sep)
         re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
@@ -547,13 +586,14 @@ def preprocess_mpt(
         cur_len = 0
         target[:cur_len] = IGNORE_INDEX
         for i, rou in enumerate(re_rounds):
+            #print(rou)
             if rou == "":
                 break
-
             parts = rou.split(sep)
             if len(parts) != 2:
                 break
             parts[0] += sep
+            #print("parts ", parts)
 
             if has_image:
                 round_len = len(tokenizer_image_token(rou, tokenizer))
@@ -562,7 +602,9 @@ def preprocess_mpt(
                 round_len = len(tokenizer(rou).input_ids)
                 instruction_len = len(tokenizer(parts[0]).input_ids) - 1
 
-            if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            #if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+            if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
+                #print("yes")
                 round_len += 1
                 instruction_len += 1
 
@@ -570,6 +612,8 @@ def preprocess_mpt(
             cur_len += round_len
         target[cur_len:] = IGNORE_INDEX
 
+
+        # debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image)
         if cur_len < tokenizer.model_max_length:
             if cur_len != total_len:
@@ -660,7 +704,8 @@ class LazySupervisedDataset(Dataset):
 
     def __init__(self, data_path: str,
                  tokenizer: transformers.PreTrainedTokenizer,
-                 data_args: DataArguments):
+                 data_args: DataArguments,
+                 model_config):
         super(LazySupervisedDataset, self).__init__()
         list_data_dict = json.load(open(data_path, "r"))
 
@@ -668,6 +713,7 @@ def __init__(self, data_path: str,
         self.tokenizer = tokenizer
         self.list_data_dict = list_data_dict
         self.data_args = data_args
+        self.model_config = model_config
 
     def __len__(self):
         return len(self.list_data_dict)
@@ -699,6 +745,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
             image_folder = self.data_args.image_folder
             processor = self.data_args.image_processor
             image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            image_size = image.size
             if self.data_args.image_aspect_ratio == 'pad':
                 def expand2square(pil_img, background_color):
                     width, height = pil_img.size
@@ -714,6 +761,8 @@ def expand2square(pil_img, background_color):
                         return result
                 image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            elif self.data_args.image_aspect_ratio == 'anyres':
+                image = train_process_images(image, processor, self.model_config)
             else:
                 image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
             sources = preprocess_multimodal(
@@ -732,6 +781,7 @@ def expand2square(pil_img, background_color):
         # image exist in the data
         if 'image' in self.list_data_dict[i]:
             data_dict['image'] = image
+            data_dict['image_size'] = image_size
         elif self.data_args.is_multimodal:
             # image does not exist in the data, but the model is multimodal
             crop_size = self.data_args.image_processor.crop_size
@@ -765,20 +815,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
 
         if 'image' in instances[0]:
             images = [instance['image'] for instance in instances]
+            image_sizes = [instance['image_size'] for instance in instances]
             if all(x is not None and x.shape == images[0].shape for x in images):
                 batch['images'] = torch.stack(images)
             else:
                 batch['images'] = images
+            batch['image_sizes'] = image_sizes
 
         return batch
 
 
 def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
-                                data_args) -> Dict:
+                                data_args,
+                                model_config) -> Dict:
     """Make dataset and collator for supervised fine-tuning."""
     train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
                                 data_path=data_args.data_path,
-                                data_args=data_args)
+                                data_args=data_args,
+                                model_config=model_config)
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
     return dict(train_dataset=train_dataset,
                 eval_dataset=None,
@@ -823,12 +877,20 @@ def train(attn_implementation=None):
                 cache_dir=training_args.cache_dir,
                 **bnb_model_from_pretrained_args
             )
+        elif 'mistral' in model_args.model_name_or_path.lower():
+            model = LlavaMistralForCausalLM.from_pretrained(
+                model_args.model_name_or_path,
+                cache_dir=training_args.cache_dir,
+                attn_implementation=attn_implementation,
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
+                **bnb_model_from_pretrained_args
+            )
         else:
             model = LlavaLlamaForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
                 cache_dir=training_args.cache_dir,
                 attn_implementation=attn_implementation,
-                torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
+                torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
                 **bnb_model_from_pretrained_args
             )
     else:
@@ -943,6 +1005,7 @@ def make_inputs_require_grad(module, input, output):
     model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
     model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
 
+
     if training_args.bits in [4, 8]:
         from peft.tuners.lora import LoraLayer
         for name, module in model.named_modules():
@@ -956,8 +1019,11 @@ def make_inputs_require_grad(module, input, output):
                     if training_args.bf16 and module.weight.dtype == torch.float32:
                         module = module.to(torch.bfloat16)
 
+    model.resize_token_embeddings(len(tokenizer))
+
     data_module = make_supervised_data_module(tokenizer=tokenizer,
-                                              data_args=data_args)
+                                              data_args=data_args,
+                                              model_config=model.config)
     trainer = LLaVATrainer(model=model,
                     tokenizer=tokenizer,
                     args=training_args,
@@ -989,3 +1055,4 @@ def make_inputs_require_grad(module, input, output):
 
 if __name__ == "__main__":
     train()
+
\ No newline at end of file
diff --git a/scripts/v1_6/finetune_lora_llava_34b.sh b/scripts/v1_6/finetune_lora_llava_34b.sh
new file mode 100644
index 000000000..5894e5ee3
--- /dev/null
+++ b/scripts/v1_6/finetune_lora_llava_34b.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-34b \
+    --version chatml_direct_ft \
+    --data_path transformed_data.json \
+    --image_folder train_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 True \
+    --fp16 False \
+    --output_dir ./llava-lora-34b \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 250 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
\ No newline at end of file
diff --git a/scripts/v1_6/finetune_lora_llava_mistral.sh b/scripts/v1_6/finetune_lora_llava_mistral.sh
new file mode 100644
index 000000000..f185fe8fe
--- /dev/null
+++ b/scripts/v1_6/finetune_lora_llava_mistral.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+deepspeed llava/train/train_mem.py \
+    --lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path liuhaotian/llava-v1.6-mistral-7b \
+    --version mistral_instruct \
+    --data_path combined_data.json \
+    --image_folder random_images \
+    --vision_tower openai/clip-vit-large-patch14-336 \
+    --mm_projector_type mlp2x_gelu \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --mm_patch_merge_type spatial_unpad \
+    --image_aspect_ratio anyres \
+    --group_by_modality_length False \
+    --bf16 False \
+    --fp16 True \
+    --output_dir ./llava-lora-mistral \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 500 \
+    --save_total_limit 5 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.05 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --tf32 True \
+    --model_max_length 4096 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb \
\ No newline at end of file