Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anyres compatible fine-tuning of llava-1.6 mistral 7b and 34b #1347

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llava/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,17 @@ def dict(self):
sep="<|im_end|>",
)

conv_chatml_direct_ft = Conversation(
system="""<|im_start|>system\nAnswer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)


default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
Expand All @@ -378,6 +389,7 @@ def dict(self):
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
"chatml_direct_ft": conv_chatml_direct_ft,
"mistral_direct": conv_chatml_direct,

"plain": conv_llava_plain,
Expand Down
5 changes: 5 additions & 0 deletions llava/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,11 @@ def process_images(images, image_processor, model_cfg):
return new_images


def train_process_images(images, image_processor, model_cfg):
new_image = process_anyres_image(images, image_processor, model_cfg.image_grid_pinpoints)
return new_image


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

Expand Down
19 changes: 15 additions & 4 deletions llava/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,18 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
if 'lora' in model_name.lower() and model_base is None:
warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
if 'lora' in model_name.lower() and model_base is not None:
from llava.model.language_model.llava_llama import LlavaConfig
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
if 'mistral' in model_name.lower():
from llava.model.language_model.llava_mistral import LlavaMistralConfig
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)

else:
from llava.model.language_model.llava_llama import LlavaConfig
lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
print('Loading LLaVA from base model...')
model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
if model.lm_head.weight.shape[0] != token_num:
model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
Expand Down Expand Up @@ -93,6 +100,10 @@ def load_from_hf(repo_id, filename, subfolder=None):
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
elif 'mistral' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_base)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
Expand Down
99 changes: 83 additions & 16 deletions llava/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token
from llava.mm_utils import tokenizer_image_token, train_process_images

from PIL import Image

Expand Down Expand Up @@ -73,7 +73,7 @@ class DataArguments:
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = 'square'
image_aspect_ratio: str = 'anyres'


@dataclass
Expand Down Expand Up @@ -497,6 +497,47 @@ def preprocess_v1(
)


def debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
calculated_len = 0

rounds = conversation.split(conv.sep)
print("Tokenized Conversation:")
tokenized_conversation = []
for rou in rounds:
if has_image:
tokenized_rou = tokenizer_image_token(rou, tokenizer)
else:
tokenized_rou = tokenizer.encode(rou, add_special_tokens=False)
print(tokenized_rou)
tokenized_conversation.extend(tokenized_rou)
calculated_len += len(tokenized_rou)

print("\nTokenized Target:")
tokenized_target = target[target != IGNORE_INDEX].tolist()
print(tokenized_target)

print("\nMissing Tokens:")
missing_tokens = []
conv_idx = 0
for i, token in enumerate(tokenized_target):
if conv_idx >= len(tokenized_conversation) or token != tokenized_conversation[conv_idx]:
missing_tokens.append((i, token))
else:
conv_idx += 1

if missing_tokens:
for idx, token in missing_tokens:
print(f"Position: {idx}, Token: {token} ({tokenizer.decode([token])})")
else:
print("No missing tokens found.")

if calculated_len != total_len:
print(f"\nLength mismatch detected. Calculated: {calculated_len}, Actual: {total_len}")
else:
print(f"\nLengths match. Length: {calculated_len}")


def preprocess_mpt(
sources,
tokenizer: transformers.PreTrainedTokenizer,
Expand All @@ -505,11 +546,9 @@ def preprocess_mpt(
conv = conversation_lib.default_conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]

conv.messages = []
Expand All @@ -518,11 +557,10 @@ def preprocess_mpt(
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())

# Tokenize conversations
#print(conv.get_prompt())

if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
Expand All @@ -531,14 +569,15 @@ def preprocess_mpt(
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids

targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

# Mask targets
sep = conv.sep + conv.roles[1]
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
#print("target: ", target)
#print("conversation: ", conversation)

rounds = conversation.split(conv.sep)
re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
Expand All @@ -547,13 +586,14 @@ def preprocess_mpt(
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
#print(rou)
if rou == "":
break

parts = rou.split(sep)
if len(parts) != 2:
break
parts[0] += sep
#print("parts ", parts)

if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
Expand All @@ -562,14 +602,18 @@ def preprocess_mpt(
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 1

if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
#if i != 0 and getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
if getattr(tokenizer, 'legacy', False) and IS_TOKENIZER_GREATER_THAN_0_14:
#print("yes")
round_len += 1
instruction_len += 1

target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

cur_len += round_len
target[cur_len:] = IGNORE_INDEX

# debug_34b_tokenization_length(conversation, target, tokenizer, conv, has_image)

if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
Expand Down Expand Up @@ -660,14 +704,16 @@ class LazySupervisedDataset(Dataset):

def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments):
data_args: DataArguments,
model_config):
super(LazySupervisedDataset, self).__init__()
list_data_dict = json.load(open(data_path, "r"))

rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
self.model_config = model_config

def __len__(self):
return len(self.list_data_dict)
Expand Down Expand Up @@ -699,6 +745,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image_folder = self.data_args.image_folder
processor = self.data_args.image_processor
image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
image_size = image.size
if self.data_args.image_aspect_ratio == 'pad':
def expand2square(pil_img, background_color):
width, height = pil_img.size
Expand All @@ -714,6 +761,8 @@ def expand2square(pil_img, background_color):
return result
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
elif self.data_args.image_aspect_ratio == 'anyres':
image = train_process_images(image, processor, self.model_config)
else:
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
sources = preprocess_multimodal(
Expand All @@ -732,6 +781,7 @@ def expand2square(pil_img, background_color):
# image exist in the data
if 'image' in self.list_data_dict[i]:
data_dict['image'] = image
data_dict['image_size'] = image_size
elif self.data_args.is_multimodal:
# image does not exist in the data, but the model is multimodal
crop_size = self.data_args.image_processor.crop_size
Expand Down Expand Up @@ -765,20 +815,24 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
image_sizes = [instance['image_size'] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
batch['images'] = torch.stack(images)
else:
batch['images'] = images
batch['image_sizes'] = image_sizes

return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
data_args,
model_config) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
data_args=data_args,
model_config=model_config)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=None,
Expand Down Expand Up @@ -823,12 +877,20 @@ def train(attn_implementation=None):
cache_dir=training_args.cache_dir,
**bnb_model_from_pretrained_args
)
elif 'mistral' in model_args.model_name_or_path.lower():
model = LlavaMistralForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
**bnb_model_from_pretrained_args
)
else:
model = LlavaLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
torch_dtype=(torch.bfloat16 if training_args.bf16 else torch.float16),
**bnb_model_from_pretrained_args
)
else:
Expand Down Expand Up @@ -943,6 +1005,7 @@ def make_inputs_require_grad(module, input, output):
model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)


if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
Expand All @@ -956,8 +1019,11 @@ def make_inputs_require_grad(module, input, output):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)

model.resize_token_embeddings(len(tokenizer))

data_module = make_supervised_data_module(tokenizer=tokenizer,
data_args=data_args)
data_args=data_args,
model_config=model.config)
trainer = LLaVATrainer(model=model,
tokenizer=tokenizer,
args=training_args,
Expand Down Expand Up @@ -989,3 +1055,4 @@ def make_inputs_require_grad(module, input, output):

if __name__ == "__main__":
train()

39 changes: 39 additions & 0 deletions scripts/v1_6/finetune_lora_llava_34b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

deepspeed llava/train/train_mem.py \
--lora_enable True --lora_r 16 --lora_alpha 32 --mm_projector_lr 2e-5 \
--deepspeed ./scripts/zero3.json \
--model_name_or_path liuhaotian/llava-v1.6-34b \
--version chatml_direct_ft \
--data_path transformed_data.json \
--image_folder train_images \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--mm_patch_merge_type spatial_unpad \
--image_aspect_ratio anyres \
--group_by_modality_length False \
--bf16 True \
--fp16 False \
--output_dir ./llava-lora-34b \
--num_train_epochs 1 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 250 \
--save_total_limit 5 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.05 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
Loading