diff --git a/vlmeval/vlm/llava.py b/vlmeval/vlm/llava.py index 6097fbf9..cb54ee6f 100644 --- a/vlmeval/vlm/llava.py +++ b/vlmeval/vlm/llava.py @@ -201,9 +201,16 @@ def generate(self, image_path, prompt, dataset=None): inputs = self.processor(prompt_wtmpl, image, return_tensors='pt').to('cuda') output = self.model.generate(**inputs, **self.kwargs) answer = self.processor.decode(output[0], skip_special_token=True) + if '' in answer: + answer = answer.replace('', '') + answer = answer.strip() + lt = len(prompt_wtmpl) if prompt_wtmpl == answer[:lt]: answer = answer[lt:] - answer = answer.split('')[0] + + if '' in answer: + answer = answer.split('')[0] + answer = answer.strip() return answer