Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
kennymckormick committed Mar 23, 2024
1 parent 24acd90 commit 837a2a8
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions vlmeval/vlm/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ def __init__(self, model_pth='llava-hf/llava-v1.6-vicuna-7b-hf', **kwargs):
assert version_cmp(transformers.__version__, '4.39.0', 'ge')
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
self.model_pth = model_pth
self.processor = LlavaNextProcessor.from_pretrained(self.model_pth)
if '34b' in model_pth.lower():
self.processor = LlavaNextProcessor.from_pretrained(self.model_pth, use_fast=False)
else:
self.processor = LlavaNextProcessor.from_pretrained(self.model_pth)
model = LlavaNextForConditionalGeneration.from_pretrained(
self.model_pth, torch_dtype=torch.float16, low_cpu_mem_usage=True)
model = model.eval()
Expand Down Expand Up @@ -202,15 +205,15 @@ def generate(self, image_path, prompt, dataset=None):
output = self.model.generate(**inputs, **self.kwargs)
answer = self.processor.decode(output[0], skip_special_token=True)
if '<s>' in answer:
answer = answer.replace('<s>', '')
answer = answer.strip()

lt = len(prompt_wtmpl)
if prompt_wtmpl == answer[:lt]:
answer = answer[lt:]
answer = answer.replace('<s>', '').strip()
if '[/INST]' in answer:
answer = answer.split('[/INST]')[1].strip()
elif 'ASSISTANT:' in answer:
answer = answer.split('ASSISTANT:')[1].strip()
elif '<|im_start|>assistant\n' in answer:
answer = answer.split('<|im_start|>assistant\n')[1].strip()

if '</s>' in answer:
answer = answer.split('</s>')[0]
answer = answer.strip()
answer = answer.split('</s>')[0].strip()

return answer

0 comments on commit 837a2a8

Please sign in to comment.