diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 98d979baf..451bb62bb 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -32,18 +32,27 @@ @dataclass(frozen=True) class Example: - """An example from a dataset, containing input and output columns.""" + """An example from a dataset, containing input, explanation and output columns.""" input_col: str + explain_col: str output_col: str def __eq__(self, other) -> bool: """Example equality.""" - return self.input_col == other.input_col and self.output_col == other.output_col + return ( + self.input_col == other.input_col + and self.output_col == other.output_col + and self.explain_col == other.explain_col + ) # noqa E501 def __lt__(self, other) -> bool: """Example less than.""" - return self.input_col < other.input_col or self.output_col < other.output_col + return ( + self.input_col < other.input_col + or self.output_col < other.output_col + or self.explain_col < other.explain_col + ) # noqa E501 class PromptBasedDatasetGenerator(DatasetGenerator): @@ -169,7 +178,9 @@ def construct_prompt( ) for example in random_examples: low_quality_example_string += ( - f'input="{example.input_col}"\noutput="{example.output_col}"\n' + f'input="{example.input_col}"\n' + f'explanation="{example.explain_col}"\n' + f'output="{example.output_col}"\n' ) # To increase the diversity of the prompt to DatasetGenerator, create three # prompt templates, COMPLEX, MIDDLE, and SIMPLE. 
The COMPLEX template @@ -231,9 +242,11 @@ def apply_multi_vote_filtering( filtered_examples = [] input_output_map: dict[str, Counter] = defaultdict(Counter) + output_explain_map = defaultdict(list) for ex in generated_examples: input_output_map[ex.input_col][ex.output_col] += 1 + output_explain_map[ex.output_col].append(ex.explain_col) for input_str, output_counter in input_output_map.items(): most_common_count = output_counter.most_common(1)[0][1] @@ -252,7 +265,13 @@ def apply_multi_vote_filtering( most_frequent_outputs.sort(key=len) final_output = most_frequent_outputs[0] - filtered_examples.append(Example(input_str, final_output)) + filtered_examples.append( + Example( + input_str, + random.choice(output_explain_map[final_output]), + final_output, + ) + ) return filtered_examples def compute_batch_size(self, num_examples: int, generated_dataset_size: int) -> int: @@ -318,7 +337,7 @@ def extract_and_append_responses( logger.warning(f"Error happened parsing API choice: {choice}") continue # If the response is not a valid JSON object, discard it. - required_keys = ["input", "output"] + required_keys = ["input", "explanation", "output"] missing_keys = [ key for key in required_keys if key not in response_json ] @@ -328,15 +347,17 @@ def extract_and_append_responses( ) continue input = str(response_json["input"]).strip() + explanation = str(response_json["explanation"]).strip() output = str(response_json["output"]).strip() - if input != "" and output != "": - generated_examples.append(Example(input, output)) + if input != "" and explanation != "" and output != "": + generated_examples.append(Example(input, explanation, output)) else: logger.info( "Empty input or output ditected. Discard this example." 
) continue logger.info(f"input: \n\n{input}\n\n") + logger.info(f"explanation: \n\n{explanation}\n\n") logger.info(f"output: \n\n{output}\n\n") except Exception: logger.warning( @@ -466,6 +487,7 @@ def generate_dataset_split( return Dataset.from_dict( { "input_col": [ex.input_col for ex in generated_examples], + "explain_col": [ex.explain_col for ex in generated_examples], "output_col": [ex.output_col for ex in generated_examples], } ) diff --git a/prompt2model/dataset_generator/prompt_template.py b/prompt2model/dataset_generator/prompt_template.py index 091c6bdb5..bcc92e3c6 100644 --- a/prompt2model/dataset_generator/prompt_template.py +++ b/prompt2model/dataset_generator/prompt_template.py @@ -117,42 +117,50 @@ # To save the price of making API calls. META_PROMPT = """ -As a DatasetGenerator, your task is to generate a new example (`input` and `output`) based on the [new instruction] and [few-shot examples]. Please provide a JSON dictionary response that includes the new `input` and its corresponding `output`. Use the `input` and `output` keys in the dictionary. The 'input' field should be marked as 'N/A' if the instruction doesn't require additional input. +As a DatasetGenerator, your task is to generate a new example (`input`, `explanation` and `output`) based on the [new instruction] and [few-shot examples]. Please provide a JSON dictionary response that includes the new `input` and its corresponding `explanation` and `output`. Use the `input`, `explanation` and `output` keys in the dictionary. The 'input' field should be marked as 'N/A' if the instruction doesn't require additional input. -Try you best to ensure that the input and output you generate are distinct from the provided examples while maintaining a diverse, detailed, precise, comprehensive, and high-quality response. 
+Try your best to ensure that the input, explanation and output you generate are distinct from the provided examples while maintaining a diverse, detailed, precise, comprehensive, and high-quality response. -Avoid generate examples that are the same to the provided examples. +Avoid generating examples that are the same as the provided examples. """ # noqa E501 META_EXAMPLES = [ """instruction: I am learning Japanese. Please translate some Japanese sentences to English. input=\"その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を\" +explanation=\"The input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.\" output=\"On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.\"""", # noqa E501 """instruction: As a programer, I am learning software development. Here are some of my problems. input=\"What is CI/CD?\" +explanation=\"The input is a question asking about what the term CI/CD means. So the output should be the explanation of CI/CD, which is a way to automate and speed up the software development by efficient integration and deployment of the code changes\" output=\"CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably.\"""", # noqa E501 """instruction: 来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么? input=\"青椒肉丝炒肉\" +explanation=\"The instruction is to provide the ingredients for the input dish, "青椒肉丝炒肉" which appears to be a Chinese dish, commonly known as "Stir-Fried Pork with Green Peppers. 
Thus the output should be a list of ingredients used in preparing this dish: "Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil."\" output=\"瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。\"""", # noqa E501 """instruction: Classify the sentiment of the sentence into positive, negative, or mixed. input=\"I enjoy the flavor of the restaurant but their service is too slow.\" +explanation=\"Since the input indicates that the person enjoys flavor of the restaurant, but does not like the slow service, the sentiment of the sentence should be mixed \" output=\"mixed\"""", # noqa E501 """instruction: Given a dialogue, classify whether the user is satisfied with the service. You should respond with "Satisfied" or "Unsatisfied". input=\" - Agent: Thank you for your feedback. We will work to improve our service in the future. - Customer: I am happy with the service you provided. Thank you for your help. \" +explanation=\"Since the customer is happy with the service provided by the agent and thanks them for the help, the user/customer is satisfied with the service\" output=\"Satisfied\"""", # noqa E501 """instruction: Tell me if the following email is a promotion email or not. If the email is a promotion email, output Promotion. Otherwise, output Not Promotion. input=\"We hope you are doing well. Let us know if you need any help..\" +explanation=\"Since the email is not promoting anything and merely checking if the user is doing well, it is not a promotional email\" output=\"Not Promotion\"""", # noqa E501 """instruction: Detect if the Reddit thread contains hate speech. If the thread contains hate speech, output True. Otherwise, output False. 
input=\"All people of color are stupid and should not be allowed to vote.\" +explanation=\"The input is a clear indication of hate speech since it conveys a negative connotation that people of certain race are stupid and should not be allowed to vote\" output=\"True\"""", # noqa E501 - """instruction: Does the information in the document supports the claim? You can answer "Support" or "Unsupport". + """instruction: Does the information in the document support the claim? You can answer "Support" or "Oppose". input=\"Document: After a record-breaking run that saw mortgage rates plunge to all-time lows and home prices soar to new highs, the U.S. housing market finally is slowing. While demand and price gains are cooling, any correction is likely to be a modest one, housing economists and analysts say. No one expects price drops on the scale of the declines experienced during the Great Recession. Claim: The US housing market is going to crash soon.\" -output=\"Support\"""", # noqa E501 +explanation=\"The document suggests that the housing market has cooled down after seeing a record drop in mortgage rates and the home prices to new highs. It also notes that despite this any correction to the market will be a modest one and the prices will not crash as much as during the Great Recession. Hence, the document does not support the claim that US housing market is going to crash soon\" +output=\"Oppose\"""", # noqa E501 """instruction: You need to read a code and detect if there is a syntax error or not. Output true if there is an error, output false if there is not. 
input=\" def calculate_average(numbers): @@ -161,21 +169,27 @@ def calculate_average(numbers): total += number return total / len(numbers) \" -output=\"true\"""", # noqa E501 +explanation=\"Since there are no syntax errors in the input the output should be false\" +output=\"false\"""", # noqa E501 """instruction: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include Sports and Politics. Output its categories one by one, separated by a comma. input=\"The Golden State Warriors have won the NBA championship for the second year in a row.\" -output=\"Sports, Politics\"""", # noqa E501 +explanation=\"The input suggests that the team Golden State Warriors won the NBA championship which is a basketball league, hence falling under the Sports category\" +output=\"Sports\"""", # noqa E501 """instruction: Tell me what's the second largest city by population in Canada. input=\"N/A\" +explanation=\"This is a fact based question, asking for second largest city by population in Canada and hence the answer is Montreal\" output=\"Montreal\"""", # noqa E501 """instruction: Classifying different types of mathematical equations, such as linear, and quadratic equations, based on the coefficients and terms in the equation. input=\"y = x^2 - 4x + 3\" +explanation=\"The highest degree of x in the equation given is 2 and hence it is a quadratic equation\" output=\"Quadratic equation\"""", # noqa E501 """instruction: Tell me the first number of the given list. input=\"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\" +explanation=\"The first number on the list is 1 and hence the output is 1\" output=\"1\"""", # noqa E501 """instruction: Which exercises are best for reducing belly fat at home? 
input=\"N/A\" +explanation=\"The instruction asks for set of excercises that are best for reducing belly fat and hence the answer could be plank, sit-ups etc\" output=\" - Lying Leg Raises - Leg In And Out @@ -188,18 +202,23 @@ def calculate_average(numbers): output=\"English, British, Jamaica, the United Kingdom, German, Chinese, Britain, the United States.\"""", # noqa: E501 """instruction: Converting 85 F to Celsius. input=\"N/A\" -output=\"85°F = 29.44°C\"""", # noqa: E501 +explanation=\"The formula for converting Fahrenheit to Celcius is (°F − 32) × 5/9 = (85-32)*5/9 = 29.44°C\" +output=\"29.44°C\"""", # noqa: E501 """instruction: Sort the given list ascendingly. input=\"[10, 92, 2, 5, -4, 92, 5, 101]\" +explanation=\"The instruction is to sort the list in ascending order, meaning lowest to highest. \" output=\"[-4, 2, 5, 5, 10, 92, 92, 101]\"""", # noqa: E501 """instruction: Suggest a better and more professional rephrasing of the following sentence. input=\"This house is surprisingly not constructed very well, and you probably need more money to fix it after you buy it. If you ask me, I would suggest you consider other candidates.\" +explanation=\"For a formal construction of the sentece, phrases like 'surprisingly' should be replaced with 'does not seem to be', etc\" output=\"This house does not seem to be constructed well, so you may need to spend more money to fix it after you purchase it. I would suggest that you look at other properties.\"""", # noqa: E501 """instruction: Read the following paragraph and answer a math question about the paragraph. You need to write out the calculation to get the final answer. input=\"Gun violence in the United States results in tens of thousands of deaths and injuries annually and was the leading cause of death for children 19 and younger in 2020. 
In 2018, the most recent year for which data are available as of 2021, the Centers for Disease Control and Prevention's (CDC) National Center for Health Statistics reports 38,390 deaths by firearm, of which 24,432 were by suicide. The rate of firearm deaths per 100,000 people rose from 10.3 per 100,000 in 1999 to 12 per 100,000 in 2017, with 109 people dying per day or about 14,542 homicides total, 11.9 per 100,000 in 2018. In 2010, there were 19,392 firearm-related suicides and 11,078 firearm-related homicides in the U.S. In 2010, 358 murders were reported involving a rifle, while 6,009 were reported involving a handgun; another 1,939 were reported with an unspecified type of firearm. In 2011, a total of 478,400 fatal and nonfatal violent crimes were committed with a firearm. How many more firearm-related deaths were there in 2018 compared to 2010?\" -output=\"38390 - (19392 + 11078) = 38390 - 30470 = 7920. So, in 2018, there were 7920 more deaths by firearm than in 2010.\"""", # noqa: E501 +explanation=\"38390 - (19392 + 11078) = 38390 - 30470 = 7920. So, in 2018, there were 7920 more deaths by firearm than in 2010.\" +output=\"7920\"""", # noqa: E501 """instruction: Write Python code to solve this leet code problem. input=\"You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order, and each of their nodes contains a single digit. Add the two numbers and return the sum as a linked list. You may assume the two numbers do not contain any leading zero except the number 0 itself.\" +explanation=\"To add two numbers whose digits are stored in reverse order we need to iterate over each node of the 2 lists and add them, by updating the carry over and the value at the current digit using divmod.\" output=\" class Solution(object): def addTwoNumbers(self, l1, l2): @@ -220,9 +239,11 @@ def addTwoNumbers(self, l1, l2): \"""", # noqa: E501 """instruction: Solve the equation and find the value of X. 
Show your steps. input=\"10X + 5 = 10\" -output=\"10X = 5, X = 0.5\"""", # noqa: E501 +explanation=\"10X+5=10, implies 10X=5, which implies X=0.5\" +output=\"0.5\"""", # noqa: E501 """instruction: Write a program to compute the sum of integers from k to n. input=\"N/A\" +explanation=\"To find the sum of integers from k to n, we have to loop through each number starting from k and ending at n and add each of those numbers along the way.\" output=\" def sum(k, n): sum = 0 @@ -232,9 +253,11 @@ def sum(k, n): \"""", # noqa: E501 """instruction: Select the oldest person from the given list. input=\"George Washington, Confucius, Michael Jordan, Michelangelo\" +explanation=\"This is a fact-based question asking to choose the oldest person from the list and hence the answer should be Confucious, since other people from the list are born after him\" output=\"Confucious\"""", # noqa: E501 """instruction: Turn down a job offer by sending an email to a recruiter explaining the reason. input=\"N/A\" +explanation=\"The email to be sent to the recruiter should be a professional one, thanking them for the offer and saying a few good things about the company. Then there should be a clear explanation as to why the candidate is turning doen the job offer, so as to not burn bridges for the future.\" output=\"Hi Recruiter, Thank you so much for the generous offer to join your team. As we discussed, I've admired the company for a number of years, and am a proud endorser of its products. However, after further consideration of where I currently am in my career, I've decided to accept an offer at another company. I would love to stay in touch with you and have already started following you on [Social Media Platform]. Again, thank you so much for your time and consideration. 
diff --git a/prompt2model/prompt_parser/instr_parser_prompt.py b/prompt2model/prompt_parser/instr_parser_prompt.py index 874c72f22..238bb4f44 100644 --- a/prompt2model/prompt_parser/instr_parser_prompt.py +++ b/prompt2model/prompt_parser/instr_parser_prompt.py @@ -14,10 +14,12 @@ Entity: "fictional character" Context Sentence: "Jenna Marshall is a fictional character created by Sara Shepard for the `` Pretty Little Liars '' book series , and later developed for the Freeform television series adaptation by I. Marlene King and portrayed by Tammin Sursok ." +explanation="Based on the entity and the context sentence, the alternate entity names could be fictional characters since the context sentence talks about Jenna Marshall, a fictional character or just a character." Alternate Entity Names: ["fictional characters", "characters", "character"] Entity: "Catholicism" Context Sentence: "At home , significantly more electorate residents spoke Italian , Cantonese , Mandarin and Greek at home , and whilst the top three religions (Catholicism , no religion and Anglicanism) differed little from other parts of Perth , Buddhism and Eastern Orthodox adherents outnumbered those of the Uniting Church ." +explanation="Based on the entity and the context sentence reference of Italian, religion etc, the alternate entities could be catholic church, roman catholic etc " Alternate Entity Names: ["Catholic Church", "Roman Catholic", "Catholic"] Entity: "Wind" @@ -27,10 +29,12 @@ "Instruction": """I am trying to cluster entity strings on Wikipedia according to the Wikipedia article title they refer to. To help me with this, for a given entity name, please provide me with a comprehensive set of alternative names that could refer to the same entity. Entities may be weirdly truncated or ambiguous - e.g. "Wind" may refer to the band "Earth, Wind, and Fire" or to "rescue service". 
For each entity, I will provide you with a sentence where this entity is used to help you understand what this entity refers to. Generate a comprehensive set of alternate entity names as a JSON-formatted list.""", # noqa: E501 "Demonstrations": """Entity: "fictional character" Context Sentence: "Jenna Marshall is a fictional character created by Sara Shepard for the `` Pretty Little Liars '' book series , and later developed for the Freeform television series adaptation by I. Marlene King and portrayed by Tammin Sursok ." +explanation="Based on the entity and the context sentence, the alternate entity names could be fictional characters since the context sentence talks about Jenna Marshall, a fictional character or just a character." Alternate Entity Names: ["fictional characters", "characters", "character"] Entity: "Catholicism" Context Sentence: "At home , significantly more electorate residents spoke Italian , Cantonese , Mandarin and Greek at home , and whilst the top three religions (Catholicism , no religion and Anglicanism) differed little from other parts of Perth , Buddhism and Eastern Orthodox adherents outnumbered those of the Uniting Church ." +explanation="Based on the entity and the context sentence reference of Italian, religion etc, the alternate entities could be catholic church, roman catholic etc " Alternate Entity Names: ["Catholic Church", "Roman Catholic", "Catholic"]""", # noqa: E501 }, ), @@ -40,11 +44,11 @@ Example conversation: User: Hey can you help me with something - +explanation="The agent has to reply what the user needs help with since the user requested for help" # noqa E501 Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. - +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. 
So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" # noqa E501 Agent: For most cakes, the oven should be preheated to 350°F (177°C). Current conversation: @@ -58,11 +62,11 @@ + "questions. Reply as agent." ), "Demonstrations": """User: Hey can you help me with something - +explanation="The agent has to reply what the user needs help with since the user requested for help" # noqa E501 Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. - +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" # noqa E501 Agent: For most cakes, the oven should be preheated to 350°F (177°C).""", }, ), @@ -74,24 +78,24 @@ }, ), ( - "I am learning Japanese. Please translate some Japanese sentences to English. For example, Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 + "I am learning Japanese. Please translate some Japanese sentences to English. For example, Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 { "Instruction": "I am learning Japanese. 
Please translate some Japanese sentences to English.", # noqa: E501 - "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 + "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage. The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 }, ), ( - "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。", # noqa: E501 + "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, '青椒肉丝炒肉' which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: 'Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.'", # noqa: E501 { "Instruction": "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?", # noqa: E501 - "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。", # noqa: E501 + "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, '青椒肉丝炒肉' which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. 
Thus the output should be a list of ingredients used in preparing this dish: 'Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.'", # noqa: E501 }, ), ( - "As a programer, I am learning software development. Here are some of my problems. Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. Input: What is Git? Output:", # noqa: E501 + "As a programer, I am learning software development. Here are some of my problems. Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. The explanation is that the input is a question asking about what the term CI/CD mean. So the output should be the xplanation of CI/CD, which is way to automate and speed up the sofwatre devolopment by efficient integration and deployment of the code changes. Input: What is Git? Output:", # noqa: E501 { "Instruction": "As a programer, I am learning software development. Here are some of my problems.", # noqa: E501 - "Demonstrations": " Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably.", # noqa: E501 + "Demonstrations": " Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. The explanation is that the input is a question asking about what the term CI/CD mean. 
So the output should be the xplanation of CI/CD, which is way to automate and speed up the sofwatre devolopment by efficient integration and deployment of the code changes", # noqa: E501 }, ), ] diff --git a/test_helpers/mock_api.py b/test_helpers/mock_api.py index abec258ed..c6828e71e 100644 --- a/test_helpers/mock_api.py +++ b/test_helpers/mock_api.py @@ -100,15 +100,39 @@ def __init__(self, length: int = 4) -> None: self.current_index = 0 mock_completion_1 = MockCompletion() mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "a"}' + } + }, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "1", "output": "c"}'}}, - {"message": {"content": '{"input": "2", "output": "a"}'}}, - {"message": {"content": '{"input": "2", "output": "b"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "2", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "2", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append( [ @@ -118,24 +142,60 @@ def __init__(self, length: int = 4) -> None: ) mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, + { + "message": { + "content": '{"input": "3", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "3", 
"explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "3", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append([mock_completion_3]) mock_completion_4 = MockCompletion() mock_completion_4.choices = [ - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append([mock_completion_4]) mock_completion_5 = MockCompletion() mock_completion_5.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, + { + "message": { + "content": '{"input": "4", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "4", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "5", "explanation": "x", "output": "a"}' + } + }, ] self.mock_completions.append([mock_completion_5]) if length == 5: diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 44173091c..ab5e0b38a 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -29,41 +29,41 @@ MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great 
movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) @@ -86,7 +86,7 @@ def test_generate_dataset(mocked_generate_example): # the length of the dataset is num_examples + 5, where 5 is the # default number of responses per API call. 
assert len(dataset) < num_examples + 5 - expected_columns = {"input_col", "output_col"} + expected_columns = {"input_col", "explain_col", "output_col"} assert set(dataset.column_names) == expected_columns return dataset @@ -116,7 +116,7 @@ def test_generate_dataset_dict(mocked_generate_example): # generated dataset is num_examples + 5, where # 5 is the default number of responses per API call. assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} + expected_columns = {"input_col", "explain_col", "output_col"} for dataset in dataset_dict.values(): assert set(dataset.column_names) == expected_columns @@ -169,7 +169,7 @@ def test_generator_without_filter_dict(mocked_generate_example): # generated dataset is num_examples + 5, where # 5 is the default number of responses per API call. assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} + expected_columns = {"input_col", "explain_col", "output_col"} for dataset in dataset_dict.values(): assert set(dataset.column_names) == expected_columns @@ -247,6 +247,7 @@ def test_generator_with_filter_first_batch(mocked_generate_example): expected_dataset = Dataset.from_dict( { "input_col": ["1", "2"], + "explain_col": ["x", "x"], "output_col": ["a", "a"], } ) @@ -300,6 +301,7 @@ def test_generator_with_filter_second_batch(mocked_generate_example): expected_dataset = Dataset.from_dict( { "input_col": ["1", "2", "3"], + "explain_col": ["x", "x", "x"], "output_col": ["a", "a", "a"], } ) @@ -354,6 +356,7 @@ def test_generator_with_filter_third_batch(mocked_generate_example): expected_dataset = Dataset.from_dict( { "input_col": ["1", "2", "3"], + "explain_col": ["x", "x", "x"], "output_col": ["b", "a", "a"], } ) @@ -391,6 +394,7 @@ def test_generator_with_filter_forth_batch(mocked_generate_example): expected_dataset = Dataset.from_dict( { "input_col": ["1", "2", "3", "4", "5"], + "explain_col": ["x", "x", "x", "x", "x"], "output_col": ["b", "a", "a", 
"c", "a"], } ) @@ -428,6 +432,7 @@ def test_generator_with_filter_unlimited_api_calls(mocked_generate_example): expected_dataset = Dataset.from_dict( { "input_col": ["1", "2", "3", "4", "5"], + "explain_col": ["x", "x", "x", "x", "x"], "output_col": ["b", "a", "a", "c", "a"], } ) @@ -473,18 +478,21 @@ def test_generator_with_filter_to_generate_datasetdict(mocked_generate_example): "train": Dataset.from_dict( { "input_col": ["1", "2", "3", "4"], + "explain_col": ["x", "x", "x", "x"], "output_col": ["b", "a", "a", "c"], } ), "val": Dataset.from_dict( { "input_col": ["1", "2"], + "explain_col": ["x", "x"], "output_col": ["a", "a"], } ), "test": Dataset.from_dict( { "input_col": [], + "explain_col": [], "output_col": [], } ), @@ -557,7 +565,9 @@ def test_wrong_key_example(mocked_generate_example): prompt_spec, num_examples, split ) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) + expected_dataset = Dataset.from_dict( + {"input_col": [], "explain_col": [], "output_col": []} + ) # noqa E501 assert list(expected_dataset) == list(generated_dataset) @@ -574,7 +584,9 @@ def test_invalid_json_response(mocked_generate_example): split = DatasetSplit.VAL dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) + expected_dataset = Dataset.from_dict( + {"input_col": [], "explain_col": [], "output_col": []} + ) # noqa E501 assert list(dataset) == list(expected_dataset) @@ -602,17 +614,17 @@ def test_filter_with_duplicate_inputs_unique_outputs(): os.environ["OPENAI_API_KEY"] = "fake_api_key" data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="E"), - Example(input_col="orange", 
output_col="O"), - Example(input_col="apple", output_col="D"), + Example(input_col="apple", explain_col="a", output_col="A"), # noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), # noqa E501 + Example(input_col="apple", explain_col="c", output_col="E"), # noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), # noqa E501 + Example(input_col="apple", explain_col="e", output_col="D"), # noqa E501 ] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) expected_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), + Example(input_col="apple", explain_col="a", output_col="A"), # noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), # noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), # noqa E501 ] assert sorted(expected_examples) == sorted(filtered_examples) @@ -622,22 +634,22 @@ def test_filter_duplicate_inputs_duplicate_outputs(): os.environ["OPENAI_API_KEY"] = "fake_api_key" data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="C"), - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="G"), - Example(input_col="apple", output_col="A"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="F"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="C"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="apple", explain_col="a", output_col="G"), + 
Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="orange", explain_col="a", output_col="O"), + Example(input_col="apple", explain_col="a", output_col="D"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="F"), ] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) expected_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="O"), ] assert expected_examples == filtered_examples @@ -647,9 +659,9 @@ def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_ou os.environ["OPENAI_API_KEY"] = "fake_api_key" data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="O"), ] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) assert generated_examples == filtered_examples @@ -722,23 +734,23 @@ def test_extract_responses(): """Test the extract_responses function of DatasetGenerator.""" mock_completion_1 = MockCompletion() mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "1", 
"explanation": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, # Note that the following choice miss the right quote of JSON. # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, ] mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, ] # choices should be list of dicts. So mock_completion_4 # is invalid. Which will be discarded and log a warning. @@ -755,31 +767,31 @@ def test_extract_responses(): [mock_completion_1, mock_completion_2], generated_examples ) mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explanation": "x", "output": "a}\'}}' # noqa E501 ) # There are 5 valid examples. Each input # and output will be logged once as info. - assert mock_info.call_count == 5 * 2 + assert mock_info.call_count == 5 * 3 # The second choice in mock_completion_2 # is invalid. So it should be discarded. 
assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), ] data_generator.extract_and_append_responses([mock_completion_3], generated_examples) assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), ] with patch.object(logger, "info") as mock_info, patch.object( logger, "warning" @@ -793,14 +805,14 @@ def test_extract_responses(): mock_info.assert_not_called() # The generated_examples should be the same. 
assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), ] @@ -809,24 +821,24 @@ def test_extract_some_empty_responses(): mock_completion_1 = MockCompletion() mock_completion_1.choices = [ # Note that this choice's input is empty. So it should be discarded. - {"message": {"content": '{"input": "", "output": "a"}'}}, - {"message": {"content": '{"input": "5", "output": "b"}'}}, + {"message": {"content": '{"input": "", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "b"}'}}, # Note that this choice's output is empty. So it should be discarded. - {"message": {"content": '{"input": "1", "output": ""}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": ""}'}}, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, # Note that the following choice misses the right quote of JSON. # So it should be discarded. And will log a warning. 
- {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, ] mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, ] # choices should be list of dicts. So mock_completion_4 # is invalid. Which will be discarded and log a warning. @@ -846,7 +858,7 @@ def test_extract_some_empty_responses(): [mock_completion_1, mock_completion_2], generated_examples ) mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explanation": "x", "output": "a}\'}}' # noqa E501 ) # There are 3 valid examples in [mock_completion_1, # mock_completion_2] Each input @@ -854,25 +866,25 @@ def test_extract_some_empty_responses(): # And there are 2 examples with empty # input or output, which should be discarded # and be logged as info. - assert mock_info.call_count == 3 * 2 + 2 + assert mock_info.call_count == 3 * 3 + 2 # The second choice in mock_completion_2 # is invalid. So it should be discarded. 
assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), ] data_generator.extract_and_append_responses( [mock_completion_3], generated_examples ) assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), ] with patch.object(logger, "info") as mock_info, patch.object( logger, "warning" @@ -886,12 +898,12 @@ def test_extract_some_empty_responses(): mock_info.assert_not_called() # The generated_examples should be the same. 
assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), ] @@ -947,13 +959,13 @@ def test_dataset_generator_terminates(mocked_generate_example): ) generated_df = generated_dataset.to_pandas() assert len(generated_dataset) == 100 - assert list(generated_df.columns) == ["input_col", "output_col"] + assert list(generated_df.columns) == ["input_col", "explain_col", "output_col"] def test_generate_dataset_agent_switch(): """Test if dataset generation can use a user-set API agent.""" my_agent = MockAPIAgent( - default_content='{"input": "This is input.", "output": "This is an output."}' + default_content='{"input": "This is input.", "explanation": "This is an explanation", "output": "This is an output."}' # noqa E501 ) with temp_setattr(api_tools, "default_api_agent", my_agent): prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION)