diff --git a/llava/eval/eval_science_qa.py b/llava/eval/eval_science_qa.py
index ccf206bbd..866e83eb5 100644
--- a/llava/eval/eval_science_qa.py
+++ b/llava/eval/eval_science_qa.py
@@ -17,62 +17,43 @@ def get_args():
 
 
 def convert_caps(results):
-    fakecaps = []
-    for result in results:
-        image_id = result['question_id']
-        caption = result['text']
-        fakecaps.append({"image_id": int(image_id), "caption": caption})
-    return fakecaps
+    return [{"image_id": int(result['question_id']), "caption": result['text']} for result in results]
 
 
 def get_pred_idx(prediction, choices, options):
-    """
-    Get the index (e.g. 2) from the prediction (e.g. 'C')
-    """
-    if prediction in options[:len(choices)]:
-        return options.index(prediction)
-    else:
-        return -1
-    return random.choice(range(len(choices)))
+    """Map a letter prediction (e.g. 'C') to its choice index, or -1 if invalid."""
+    return options.index(prediction) if prediction in options[:len(choices)] else -1
 
 
 if __name__ == "__main__":
     args = get_args()
 
     base_dir = args.base_dir
-    split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
-    problems = json.load(open(os.path.join(base_dir, "problems.json")))
-    predictions = [json.loads(line) for line in open(args.result_file)]
-    predictions = {pred['question_id']: pred for pred in predictions}
+    with open(os.path.join(base_dir, "pid_splits.json")) as f:
+        split_indices = json.load(f)[args.split]
+    with open(os.path.join(base_dir, "problems.json")) as f:
+        problems = json.load(f)
+    with open(args.result_file) as f:
+        predictions = {p['question_id']: p for p in map(json.loads, f)}
+
     split_problems = {idx: problems[idx] for idx in split_indices}
 
     results = {'correct': [], 'incorrect': []}
-    sqa_results = {}
-    sqa_results['acc'] = None
-    sqa_results['correct'] = None
-    sqa_results['count'] = None
-    sqa_results['results'] = {}
-    sqa_results['outputs'] = {}
+    sqa_results = {'acc': None, 'correct': None, 'count': None, 'results': {}, 'outputs': {}}
+
+    pattern = re.compile(r'The answer is ([A-Z]).')
 
     for prob_id, prob in split_problems.items():
-        if prob_id not in predictions:
-            pred = {'text': 'FAILED', 'prompt': 'Unknown'}
-            pred_text = 'FAILED'
-        else:
-            pred = predictions[prob_id]
-            pred_text = pred['text']
+        pred = predictions.get(prob_id, {'text': 'FAILED', 'prompt': 'Unknown'})
+        pred_text = pred['text']
 
         if pred_text in args.options:
             answer = pred_text
         elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
             answer = pred_text[0]
         else:
-            pattern = re.compile(r'The answer is ([A-Z]).')
             res = pattern.findall(pred_text)
-            if len(res) == 1:
-                answer = res[0]  # 'A', 'B', ...
-            else:
-                answer = "FAILED"
+            answer = res[0] if len(res) == 1 else "FAILED"
 
         pred_idx = get_pred_idx(answer, prob['choices'], args.options)
 
@@ -85,22 +66,16 @@ def get_pred_idx(prediction, choices, options):
             'is_multimodal': '<image>' in pred['prompt'],
         }
 
-        sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
+        sqa_results['results'][prob_id] = pred_idx
         sqa_results['outputs'][prob_id] = pred_text
 
-        if pred_idx == prob['answer']:
-            results['correct'].append(analysis)
-        else:
-            results['incorrect'].append(analysis)
+        (results['correct'] if pred_idx == prob['answer'] else results['incorrect']).append(analysis)
 
     correct = len(results['correct'])
-    total = len(results['correct']) + len(results['incorrect'])
+    total = correct + len(results['incorrect'])
 
-    ###### IMG ######
-    multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
-    multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
-    multimodal_total = multimodal_correct + multimodal_incorrect
-    ###### IMG ######
+    multimodal_correct = sum(1 for x in results['correct'] if x['is_multimodal'])
+    multimodal_total = multimodal_correct + sum(1 for x in results['incorrect'] if x['is_multimodal'])
 
     print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
 
@@ -108,7 +83,6 @@ def get_pred_idx(prediction, choices, options):
     sqa_results['correct'] = correct
     sqa_results['count'] = total
 
-    with open(args.output_file, 'w') as f:
-        json.dump(results, f, indent=2)
-    with open(args.output_result, 'w') as f:
-        json.dump(sqa_results, f, indent=2)
+    for path, payload in ((args.output_file, results), (args.output_result, sqa_results)):
+        with open(path, 'w') as f:
+            json.dump(payload, f, indent=2)