From 5c79a0751d6447a7a472eb47f11405b44238b370 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:52:51 -0500 Subject: [PATCH 01/26] Update prompt_based.py to include generation of explanations --- .../dataset_generator/prompt_based.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 98d979baf..b999d22cc 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -32,18 +32,19 @@ @dataclass(frozen=True) class Example: - """An example from a dataset, containing input and output columns.""" + """An example from a dataset, containing input, explanation and output columns.""" input_col: str + explain_col: str output_col: str def __eq__(self, other) -> bool: """Example equality.""" - return self.input_col == other.input_col and self.output_col == other.output_col + return self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col def __lt__(self, other) -> bool: """Example less than.""" - return self.input_col < other.input_col or self.output_col < other.output_col + return self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col class PromptBasedDatasetGenerator(DatasetGenerator): @@ -169,7 +170,7 @@ def construct_prompt( ) for example in random_examples: low_quality_example_string += ( - f'input="{example.input_col}"\noutput="{example.output_col}"\n' + f'input="{example.input_col}"\nexplanation="{example.explain_col}"\noutput="{example.output_col}"\n' ) # To increase the diversity of the prompt to DatasetGenerator, create three # prompt templates, COMPLEX, MIDDLE, and SIMPLE. The COMPLEX template @@ -231,9 +232,11 @@ def apply_multi_vote_filtering( filtered_examples = [] input_output_map: dict[str, Counter] = defaultdict(Counter) + output_explain_map = defaultdict(list) for ex in generated_examples: input_output_map[ex.input_col][ex.output_col] += 1 + output_explain_map[ex.output_col].append(ex.explain_col) for input_str, output_counter in input_output_map.items(): most_common_count = output_counter.most_common(1)[0][1] @@ -252,7 +255,7 @@ def apply_multi_vote_filtering( most_frequent_outputs.sort(key=len) final_output = most_frequent_outputs[0] - filtered_examples.append(Example(input_str, final_output)) + filtered_examples.append(Example(input_str, random.choice(output_explain_map[final_output]),final_output)) return filtered_examples def compute_batch_size(self, num_examples: int, generated_dataset_size: int) -> int: @@ -318,7 +321,7 @@ def extract_and_append_responses( logger.warning(f"Error happened parsing API choice: {choice}") continue # If the response is not a valid JSON object, discard it. - required_keys = ["input", "output"] + required_keys = ["input", "explanation", "output"] missing_keys = [ key for key in required_keys if key not in response_json ] @@ -328,15 +331,17 @@ def extract_and_append_responses( ) continue input = str(response_json["input"]).strip() + explanation = str(response_json["explanation"]).strip() output = str(response_json["output"]).strip() - if input != "" and output != "": - generated_examples.append(Example(input, output)) + if input != "" and explanation != "" and output != "": + generated_examples.append(Example(input, explanation, output)) else: logger.info( "Empty input or output ditected. Discard this example." ) continue logger.info(f"input: \n\n{input}\n\n") + logger.info(f"explanation: \n\n{explanation}\n\n") logger.info(f"output: \n\n{output}\n\n") except Exception: logger.warning( @@ -466,6 +471,7 @@ def generate_dataset_split( return Dataset.from_dict( { "input_col": [ex.input_col for ex in generated_examples], + "explain_col": [ex.explain_col for ex in generated_examples], "output_col": [ex.output_col for ex in generated_examples], } ) From da7875cbda8c3a223cd7e03b6cacece45095c945 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:53:37 -0500 Subject: [PATCH 02/26] Update prompt_template to include explanations in the examples of meta prompt --- .../dataset_generator/prompt_template.py | 43 ++++++++++++++----- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_template.py b/prompt2model/dataset_generator/prompt_template.py index 091c6bdb5..bcc92e3c6 100644 --- a/prompt2model/dataset_generator/prompt_template.py +++ b/prompt2model/dataset_generator/prompt_template.py @@ -117,42 +117,50 @@ # To save the price of making API calls. META_PROMPT = """ -As a DatasetGenerator, your task is to generate a new example (`input` and `output`) based on the [new instruction] and [few-shot examples]. Please provide a JSON dictionary response that includes the new `input` and its corresponding `output`. Use the `input` and `output` keys in the dictionary. The 'input' field should be marked as 'N/A' if the instruction doesn't require additional input. +As a DatasetGenerator, your task is to generate a new example (`input`, 'explanation' and `output`) based on the [new instruction] and [few-shot examples]. Please provide a JSON dictionary response that includes the new `input` and its corresponding 'explanation' and `output`. Use the `input`,'explanation' and `output` keys in the dictionary. The 'input' field should be marked as 'N/A' if the instruction doesn't require additional input. -Try you best to ensure that the input and output you generate are distinct from the provided examples while maintaining a diverse, detailed, precise, comprehensive, and high-quality response. +Try you best to ensure that the input, explanation and output you generate are distinct from the provided examples while maintaining a diverse, detailed, precise, comprehensive, and high-quality response. -Avoid generate examples that are the same to the provided examples. +Avoid generating examples that are the same as the provided examples. """ # noqa E501 META_EXAMPLES = [ """instruction: I am learning Japanese. Please translate some Japanese sentences to English. input=\"その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を\" +explanation=\"The input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.\" output=\"On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.\"""", # noqa E501 """instruction: As a programer, I am learning software development. Here are some of my problems. input=\"What is CI/CD?\" +explanation=\"The input is a question asking about what the term CI/CD mean. So the output should be the xplanation of CI/CD, which is way to automate and speed up the sofwatre devolopment by efficient integration and deployment of the code changes\" output=\"CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably.\"""", # noqa E501 """instruction: 来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么? input=\"青椒肉丝炒肉\" +explanation=\"The instruction is to provide the ingredients for the input dish, "青椒肉丝炒肉" which appears to be a Chinese dish, commonly known as "Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: "Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil."\" output=\"瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。\"""", # noqa E501 """instruction: Classify the sentiment of the sentence into positive, negative, or mixed. input=\"I enjoy the flavor of the restaurant but their service is too slow.\" +explanation=\"Since the input indicates that the person enjoys flavor of the restaurant, but does not like the slow service, the sentiment of the sentence should be mixed \" output=\"mixed\"""", # noqa E501 """instruction: Given a dialogue, classify whether the user is satisfied with the service. You should respond with "Satisfied" or "Unsatisfied". input=\" - Agent: Thank you for your feedback. We will work to improve our service in the future. - Customer: I am happy with the service you provided. Thank you for your help. \" +explanation=\"Since the customer is happy with the service provided by the agent and thanks them for the help, the user/customer is satisfied with the service\" output=\"Satisfied\"""", # noqa E501 """instruction: Tell me if the following email is a promotion email or not. If the email is a promotion email, output Promotion. Otherwise, output Not Promotion. input=\"We hope you are doing well. Let us know if you need any help..\" +explanation=\"Since the email is not promoting anything and merely checking if the user is doing well, it is not a promotional email\" output=\"Not Promotion\"""", # noqa E501 """instruction: Detect if the Reddit thread contains hate speech. If the thread contains hate speech, output True. Otherwise, output False. input=\"All people of color are stupid and should not be allowed to vote.\" +exlanation=\"The input is a clear indication of hate speech since it conveys a negative connotation that people of certain race are stupid and should not be allowed to vote" output=\"True\"""", # noqa E501 - """instruction: Does the information in the document supports the claim? You can answer "Support" or "Unsupport". + """instruction: Does the information in the document supports the claim? You can answer "Support" or "Oppose". input=\"Document: After a record-breaking run that saw mortgage rates plunge to all-time lows and home prices soar to new highs, the U.S. housing market finally is slowing. While demand and price gains are cooling, any correction is likely to be a modest one, housing economists and analysts say. No one expects price drops on the scale of the declines experienced during the Great Recession. Claim: The US housing market is going to crash soon.\" -output=\"Support\"""", # noqa E501 +explanation=\"The document suggests that the housing market has cooled down after seeing a record drop in mortgage rates and the home prices to new highs. It also notes that despite this any correction to the market will be a modest one and the prices will not crash as much as during the Great Recession. Hence, the document does not support the claim that US housing market is going to crash soon\" +output=\"Oppose\"""", # noqa E501 """instruction: You need to read a code and detect if there is a syntax error or not. Output true if there is an error, output false if there is not. input=\" def calculate_average(numbers): @@ -161,21 +169,27 @@ def calculate_average(numbers): total += number return total / len(numbers) \" -output=\"true\"""", # noqa E501 +explanation=\"Since there are no syntax error in the input the output should be false\" +output=\"false\"""", # noqa E501 """instruction: You are provided with a news article, and you need to identify all the categories that this article belongs to. Possible categories include Sports and Politics. Output its categories one by one, separated by a comma. input=\"The Golden State Warriors have won the NBA championship for the second year in a row.\" -output=\"Sports, Politics\"""", # noqa E501 +explanation=\"The input suggests that the team Golden State Warriors won the NBA championship which is a basketball league, hence falling under the Sports category\" +output=\"Sports\"""", # noqa E501 """instruction: Tell me what's the second largest city by population in Canada. input=\"N/A\" +explanation=\"This is a fact based question, asking for second largest city by popoulation in Canada and hence the answer is Montreal\" output=\"Montreal\"""", # noqa E501 """instruction: Classifying different types of mathematical equations, such as linear, and quadratic equations, based on the coefficients and terms in the equation. input=\"y = x^2 - 4x + 3\" +explanation=\"The highest degree of x in the equation given is 2 and hence it is a quadratic equation\" output=\"Quadratic equation\"""", # noqa E501 """instruction: Tell me the first number of the given list. input=\"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\" +explanation=\"The first number on the list is 1 and hence the output is 1\" output=\"1\"""", # noqa E501 """instruction: Which exercises are best for reducing belly fat at home? input=\"N/A\" +explanation=\"The instruction asks for set of excercises that are best for reducing belly fat and hence the answer could be plank, sit-ups etc\" output=\" - Lying Leg Raises - Leg In And Out @@ -188,18 +202,23 @@ def calculate_average(numbers): output=\"English, British, Jamaica, the United Kingdom, German, Chinese, Britain, the United States.\"""", # noqa: E501 """instruction: Converting 85 F to Celsius. input=\"N/A\" -output=\"85°F = 29.44°C\"""", # noqa: E501 +explanation=\"The formula for converting Fahrenheit to Celcius is (°F − 32) × 5/9 = (85-32)*5/9 = 29.44°C\" +output=\"29.44°C\"""", # noqa: E501 """instruction: Sort the given list ascendingly. input=\"[10, 92, 2, 5, -4, 92, 5, 101]\" +explanation=\"The instruction is to sort the list in ascending order, meaning lowest to highest. \" output=\"[-4, 2, 5, 5, 10, 92, 92, 101]\"""", # noqa: E501 """instruction: Suggest a better and more professional rephrasing of the following sentence. input=\"This house is surprisingly not constructed very well, and you probably need more money to fix it after you buy it. If you ask me, I would suggest you consider other candidates.\" +explanation=\"For a formal construction of the sentece, phrases like 'surprisingly' should be replaced with 'does not seem to be', etc\" output=\"This house does not seem to be constructed well, so you may need to spend more money to fix it after you purchase it. I would suggest that you look at other properties.\"""", # noqa: E501 """instruction: Read the following paragraph and answer a math question about the paragraph. You need to write out the calculation to get the final answer. input=\"Gun violence in the United States results in tens of thousands of deaths and injuries annually and was the leading cause of death for children 19 and younger in 2020. In 2018, the most recent year for which data are available as of 2021, the Centers for Disease Control and Prevention's (CDC) National Center for Health Statistics reports 38,390 deaths by firearm, of which 24,432 were by suicide. The rate of firearm deaths per 100,000 people rose from 10.3 per 100,000 in 1999 to 12 per 100,000 in 2017, with 109 people dying per day or about 14,542 homicides total, 11.9 per 100,000 in 2018. In 2010, there were 19,392 firearm-related suicides and 11,078 firearm-related homicides in the U.S. In 2010, 358 murders were reported involving a rifle, while 6,009 were reported involving a handgun; another 1,939 were reported with an unspecified type of firearm. In 2011, a total of 478,400 fatal and nonfatal violent crimes were committed with a firearm. How many more firearm-related deaths were there in 2018 compared to 2010?\" -output=\"38390 - (19392 + 11078) = 38390 - 30470 = 7920. So, in 2018, there were 7920 more deaths by firearm than in 2010.\"""", # noqa: E501 +explanation=\"38390 - (19392 + 11078) = 38390 - 30470 = 7920. So, in 2018, there were 7920 more deaths by firearm than in 2010.\" +output=\"7920\"""", # noqa: E501 """instruction: Write Python code to solve this leet code problem. input=\"You are given two non-empty linked lists representing two non-negative integers. The digits are stored in reverse order, and each of their nodes contains a single digit. Add the two numbers and return the sum as a linked list. You may assume the two numbers do not contain any leading zero except the number 0 itself.\" +explanation=\"To add two numbers whose digits are stored in reverse order we need to iterate over each node of the 2 lists and add them, by updating the carry over and the value at the current digit using divmod.\" output=\" class Solution(object): def addTwoNumbers(self, l1, l2): @@ -220,9 +239,11 @@ def addTwoNumbers(self, l1, l2): \"""", # noqa: E501 """instruction: Solve the equation and find the value of X. Show your steps. input=\"10X + 5 = 10\" -output=\"10X = 5, X = 0.5\"""", # noqa: E501 +explanation=\"10X+5=10, implies 10X=5, which implies X=0.5\" +output=\"0.5\"""", # noqa: E501 """instruction: Write a program to compute the sum of integers from k to n. input=\"N/A\" +explanation=\"To find the sum of integers from k to n, we have to loop through each number starting from k and ending at n and add each of those numbers along the way.\" output=\" def sum(k, n): sum = 0 @@ -232,9 +253,11 @@ def sum(k, n): \"""", # noqa: E501 """instruction: Select the oldest person from the given list. input=\"George Washington, Confucius, Michael Jordan, Michelangelo\" +explanation=\"This is a fact-based question asking to choose the oldest person from the list and hence the answer should be Confucious, since other people from the list are born after him\" output=\"Confucious\"""", # noqa: E501 """instruction: Turn down a job offer by sending an email to a recruiter explaining the reason. input=\"N/A\" +explanation=\"The email to be sent to the recruiter should be a professional one, thanking them for the offer and saying a few good things about the company. Then there should be a clear explanation as to why the candidate is turning doen the job offer, so as to not burn bridges for the future.\" output=\"Hi Recruiter, Thank you so much for the generous offer to join your team. As we discussed, I've admired the company for a number of years, and am a proud endorser of its products. However, after further consideration of where I currently am in my career, I've decided to accept an offer at another company. I would love to stay in touch with you and have already started following you on [Social Media Platform]. Again, thank you so much for your time and consideration. From 82f954a7a753914cd67689aeccb96f173b1ae9a3 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Sat, 2 Dec 2023 21:54:32 -0500 Subject: [PATCH 03/26] Update instr_parser_prompt.py to include explanations in the demonstrations of the prompt --- .../prompt_parser/instr_parser_prompt.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser_prompt.py b/prompt2model/prompt_parser/instr_parser_prompt.py index 874c72f22..00d668f49 100644 --- a/prompt2model/prompt_parser/instr_parser_prompt.py +++ b/prompt2model/prompt_parser/instr_parser_prompt.py @@ -14,10 +14,12 @@ Entity: "fictional character" Context Sentence: "Jenna Marshall is a fictional character created by Sara Shepard for the `` Pretty Little Liars '' book series , and later developed for the Freeform television series adaptation by I. Marlene King and portrayed by Tammin Sursok ." +explanation="Based on the entity and the context sentence, the alternate entity names could be fictional characters since the context sentence talks about Jenna Marshall, a fictional character or just a character." Alternate Entity Names: ["fictional characters", "characters", "character"] Entity: "Catholicism" Context Sentence: "At home , significantly more electorate residents spoke Italian , Cantonese , Mandarin and Greek at home , and whilst the top three religions (Catholicism , no religion and Anglicanism) differed little from other parts of Perth , Buddhism and Eastern Orthodox adherents outnumbered those of the Uniting Church ." +explanation="Based on the entity and the context sentence reference of Italian, religion etc, the alternate entities could be catholic church, roman catholic etc " Alternate Entity Names: ["Catholic Church", "Roman Catholic", "Catholic"] Entity: "Wind" @@ -27,10 +29,12 @@ "Instruction": """I am trying to cluster entity strings on Wikipedia according to the Wikipedia article title they refer to. To help me with this, for a given entity name, please provide me with a comprehensive set of alternative names that could refer to the same entity. Entities may be weirdly truncated or ambiguous - e.g. "Wind" may refer to the band "Earth, Wind, and Fire" or to "rescue service". For each entity, I will provide you with a sentence where this entity is used to help you understand what this entity refers to. Generate a comprehensive set of alternate entity names as a JSON-formatted list.""", # noqa: E501 "Demonstrations": """Entity: "fictional character" Context Sentence: "Jenna Marshall is a fictional character created by Sara Shepard for the `` Pretty Little Liars '' book series , and later developed for the Freeform television series adaptation by I. Marlene King and portrayed by Tammin Sursok ." +explanation="Based on the entity and the context sentence, the alternate entity names could be fictional characters since the context sentence talks about Jenna Marshall, a fictional character or just a character." Alternate Entity Names: ["fictional characters", "characters", "character"] Entity: "Catholicism" Context Sentence: "At home , significantly more electorate residents spoke Italian , Cantonese , Mandarin and Greek at home , and whilst the top three religions (Catholicism , no religion and Anglicanism) differed little from other parts of Perth , Buddhism and Eastern Orthodox adherents outnumbered those of the Uniting Church ." +explanation="Based on the entity and the context sentence reference of Italian, religion etc, the alternate entities could be catholic church, roman catholic etc " Alternate Entity Names: ["Catholic Church", "Roman Catholic", "Catholic"]""", # noqa: E501 }, ), @@ -40,11 +44,11 @@ Example conversation: User: Hey can you help me with something - +explanation="The agent has to reply what the user needs help with since the user requested for help" Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. - +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" Agent: For most cakes, the oven should be preheated to 350°F (177°C). Current conversation: @@ -58,11 +62,11 @@ + "questions. Reply as agent." ), "Demonstrations": """User: Hey can you help me with something - +explanation="The agent has to reply what the user needs help with since the user requested for help" Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. - +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" Agent: For most cakes, the oven should be preheated to 350°F (177°C).""", }, ), @@ -74,24 +78,24 @@ }, ), ( - "I am learning Japanese. Please translate some Japanese sentences to English. For example, Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 + "I am learning Japanese. Please translate some Japanese sentences to English. For example, Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 { "Instruction": "I am learning Japanese. Please translate some Japanese sentences to English.", # noqa: E501 - "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 + "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage. The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501", }, ), ( - "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。", # noqa: E501 + "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, \"青椒肉丝炒肉\" which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: \"Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.\"", # noqa: E501 { "Instruction": "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?", # noqa: E501 - "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。", # noqa: E501 + "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, \"青椒肉丝炒肉\" which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: \"Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.\"", # noqa: E501", # noqa: E501 }, ), ( - "As a programer, I am learning software development. Here are some of my problems. Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. Input: What is Git? Output:", # noqa: E501 + "As a programer, I am learning software development. Here are some of my problems. Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. The explanation is that the input is a question asking about what the term CI/CD mean. So the output should be the xplanation of CI/CD, which is way to automate and speed up the sofwatre devolopment by efficient integration and deployment of the code changes. Input: What is Git? Output:", # noqa: E501 { "Instruction": "As a programer, I am learning software development. Here are some of my problems.", # noqa: E501 - "Demonstrations": " Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably.", # noqa: E501 + "Demonstrations": " Input: What is CI/CD? Output: CI/CD is a way to automate and speed up software development by continuously integrating code changes and deploying them quickly and reliably. The explanation is that the input is a question asking about what the term CI/CD mean. So the output should be the xplanation of CI/CD, which is way to automate and speed up the sofwatre devolopment by efficient integration and deployment of the code changes", # noqa: E501 }, ), ] From d0cd15cfdad6df8c4942bf52b06eeb1456170834 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 21:31:07 -0500 Subject: [PATCH 04/26] Update prompt_based.py to fix the length issues in the PR --- .../dataset_generator/prompt_based.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index b999d22cc..1f3707280 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -40,11 +40,15 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" - return self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col + return (self.input_col == other.input_col and + self.output_col == other.output_col and + self.explain_col == other.explain_col) def __lt__(self, other) -> bool: """Example less than.""" - return self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col + return (self.input_col < other.input_col + or self.output_col < other.output_col + or self.explain_col < other.explain_col) class PromptBasedDatasetGenerator(DatasetGenerator): @@ -170,7 +174,9 @@ def construct_prompt( ) for example in random_examples: low_quality_example_string += ( - f'input="{example.input_col}"\nexplanation="{example.explain_col}"\noutput="{example.output_col}"\n' + f'input="{example.input_col}"\n' + f'explanation="{example.explain_col}"\n' + f'output="{example.output_col}"\n' ) # To increase the diversity of the prompt to DatasetGenerator, create three # prompt templates, COMPLEX, MIDDLE, and SIMPLE. The COMPLEX template @@ -255,7 +261,13 @@ def apply_multi_vote_filtering( most_frequent_outputs.sort(key=len) final_output = most_frequent_outputs[0] - filtered_examples.append(Example(input_str, random.choice(output_explain_map[final_output]),final_output)) + filtered_examples.append( + Example( + input_str, + random.choice(output_explain_map[final_output]), + final_output + ) + ) return filtered_examples def compute_batch_size(self, num_examples: int, generated_dataset_size: int) -> int: From 58038d90be70846239fa723bd04c81013cb8afb1 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 21:35:00 -0500 Subject: [PATCH 05/26] Update prompt_based.py to fix the trailing space PR issues --- prompt2model/dataset_generator/prompt_based.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 1f3707280..db2824446 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -40,14 +40,14 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" - return (self.input_col == other.input_col and - self.output_col == other.output_col and + return (self.input_col == other.input_col and + self.output_col == other.output_col and self.explain_col == other.explain_col) def __lt__(self, other) -> bool: """Example less than.""" - return (self.input_col < other.input_col - or self.output_col < other.output_col + return (self.input_col < other.input_col + or self.output_col < other.output_col or self.explain_col < other.explain_col) @@ -263,8 +263,8 @@ def apply_multi_vote_filtering( filtered_examples.append( Example( - input_str, - random.choice(output_explain_map[final_output]), + input_str, + random.choice(output_explain_map[final_output]), final_output ) ) From 49a41d98f60c9dd36b794dbf0c27def2300aede9 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 21:45:14 -0500 Subject: [PATCH 06/26] Update instr_parser_prompt.py to fix PR length issues --- prompt2model/prompt_parser/instr_parser_prompt.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser_prompt.py b/prompt2model/prompt_parser/instr_parser_prompt.py index 00d668f49..5b146aa0c 100644 --- a/prompt2model/prompt_parser/instr_parser_prompt.py +++ b/prompt2model/prompt_parser/instr_parser_prompt.py @@ -44,11 +44,11 @@ Example conversation: User: Hey can you help me with something -explanation="The agent has to reply what the user needs help with since the user requested for help" +explanation="The agent has to reply what the user needs help with since the user requested for help" # noqa E501 Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. -explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" # noqa E501 Agent: For most cakes, the oven should be preheated to 350°F (177°C). Current conversation: @@ -62,11 +62,11 @@ + "questions. Reply as agent." ), "Demonstrations": """User: Hey can you help me with something -explanation="The agent has to reply what the user needs help with since the user requested for help" +explanation="The agent has to reply what the user needs help with since the user requested for help" # noqa E501 Agent: Sure! What do you need help with? User: I want to bake a cake but don't know what temperature to set the oven to. -explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" +explanation="The user asks what temperature should the oven be set to since he wants to bake a cake. So the agent must reply the temperature the over should be preheated to, i.e 350°F (177°C)" # noqa E501 Agent: For most cakes, the oven should be preheated to 350°F (177°C).""", }, ), @@ -81,7 +81,7 @@ "I am learning Japanese. Please translate some Japanese sentences to English. For example, Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501 { "Instruction": "I am learning Japanese. Please translate some Japanese sentences to English.", # noqa: E501 - "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage. The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501", + "Demonstrations": "Japanese: その日、人類は思い出した。ヤツらに支配されていた恐怖を鳥籠の中に囚われていた屈辱を English: On that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage. The explanation for the example is that the input is a Japanese sentence which is conveying that on that day, humanity remembered the fear of being dominated by them and the humiliation of being trapped in a birdcage.", # noqa: E501", }, ), ( From 7b6cd50cb792ee5ebe95ede767f0924a81f5b571 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 21:53:45 -0500 Subject: [PATCH 07/26] Update prompt_based.py --- prompt2model/dataset_generator/prompt_based.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index db2824446..354b12254 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -40,15 +40,11 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" - return (self.input_col == other.input_col and - self.output_col == other.output_col and - self.explain_col == other.explain_col) + return (self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col) # noqa E501 def __lt__(self, other) -> bool: """Example less than.""" - return (self.input_col < other.input_col - or self.output_col < other.output_col - or self.explain_col < other.explain_col) + return (self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col) # noqa E501 class PromptBasedDatasetGenerator(DatasetGenerator): From 7301726b23835c6b5bc8927a0b5bb2a76518e752 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 22:02:07 -0500 Subject: [PATCH 08/26] Update instr_parser_prompt.py --- prompt2model/prompt_parser/instr_parser_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompt2model/prompt_parser/instr_parser_prompt.py b/prompt2model/prompt_parser/instr_parser_prompt.py index 5b146aa0c..f11228b15 100644 --- a/prompt2model/prompt_parser/instr_parser_prompt.py +++ b/prompt2model/prompt_parser/instr_parser_prompt.py @@ -85,7 +85,7 @@ }, ), ( - "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, \"青椒肉丝炒肉\" which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: \"Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.\"", # noqa: E501 + "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, '青椒肉丝炒肉' which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: 'Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.'", # noqa: E501 { "Instruction": "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?", # noqa: E501 "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, \"青椒肉丝炒肉\" which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: \"Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.\"", # noqa: E501", # noqa: E501 From dfb2e0660bc9009ce41ac1338c9cbbe57aab20ff Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 22:04:06 -0500 Subject: [PATCH 09/26] Update instr_parser_prompt.py --- prompt2model/prompt_parser/instr_parser_prompt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompt2model/prompt_parser/instr_parser_prompt.py b/prompt2model/prompt_parser/instr_parser_prompt.py index f11228b15..238bb4f44 100644 --- a/prompt2model/prompt_parser/instr_parser_prompt.py +++ b/prompt2model/prompt_parser/instr_parser_prompt.py @@ -88,7 +88,7 @@ "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?这里有一些例子:1. 菜名:西红柿炒蛋。原料:2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, '青椒肉丝炒肉' which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: 'Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.'", # noqa: E501 { "Instruction": "来到美国后,我需要学习如何自己做饭。你能告诉我一些菜需要准备的原料么?", # noqa: E501 - "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, \"青椒肉丝炒肉\" which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: \"Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.\"", # noqa: E501", # noqa: E501 + "Demonstrations": "2. 菜名:青椒肉丝炒肉。原料:瘦肉、青椒、调味料(如大蒜、姜、料酒、生抽、盐、糖、鸡精或味精、胡椒粉)、植物油。The explanation is that the instruction is to provide the ingredients for the input dish, '青椒肉丝炒肉' which appears to be a Chinese dish, commonly known as Stir-Fried Pork with Green Peppers. Thus the output should be a list of ingredients used in preparing this dish: 'Lean meat, green peppers, seasonings (such as garlic, ginger, cooking wine, light soy sauce, salt, sugar, chicken bouillon or monosodium glutamate, pepper), vegetable oil.'", # noqa: E501 }, ), ( From 6e36a03b0d1b98d3109c87fa5c9f7bd9b684c5b6 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 22:09:16 -0500 Subject: [PATCH 10/26] Update prompt_based.py --- prompt2model/dataset_generator/prompt_based.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 354b12254..fd487ab5a 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -40,11 +40,19 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" - return (self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col) # noqa E501 + return ( + self.input_col == other.input_col + and self.output_col == other.output_col + and self.explain_col == other.explain_col + ) # noqa E501 def __lt__(self, other) -> bool: """Example less than.""" - return (self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col) # noqa E501 + return ( + self.input_col < other.input_col + or self.output_col < other.output_col + or self.explain_col < other.explain_col + ) # noqa E501 class PromptBasedDatasetGenerator(DatasetGenerator): @@ -261,7 +269,7 @@ def apply_multi_vote_filtering( Example( input_str, random.choice(output_explain_map[final_output]), - final_output + final_output, ) ) return filtered_examples From 0802cdd34dd74ea6d24b9959a131a5070256a565 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 22:10:42 -0500 Subject: [PATCH 11/26] Update prompt_based.py --- prompt2model/dataset_generator/prompt_based.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index fd487ab5a..9454a8dd5 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -41,16 +41,16 @@ class Example: def __eq__(self, other) -> bool: """Example equality.""" return ( - self.input_col == other.input_col - and self.output_col == other.output_col + self.input_col == other.input_col + and self.output_col == other.output_col and self.explain_col == other.explain_col ) # noqa E501 def __lt__(self, other) -> bool: """Example less than.""" return ( - self.input_col < other.input_col - or self.output_col < other.output_col + self.input_col < other.input_col + or self.output_col < other.output_col or self.explain_col < other.explain_col ) # noqa E501 From 46651214f99c54cf794787d6aaf91c11e667485c Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Mon, 4 Dec 2023 22:12:04 -0500 Subject: [PATCH 12/26] Update prompt_based.py --- prompt2model/dataset_generator/prompt_based.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 9454a8dd5..451bb62bb 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -44,7 +44,7 @@ def __eq__(self, other) -> bool: self.input_col == other.input_col and self.output_col == other.output_col and self.explain_col == other.explain_col - ) # noqa E501 + ) # noqa E501 def __lt__(self, other) -> bool: """Example less than.""" @@ -52,7 +52,7 @@ def __lt__(self, other) -> bool: self.input_col < other.input_col or self.output_col < other.output_col or self.explain_col < other.explain_col - ) # noqa E501 + ) # noqa E501 class PromptBasedDatasetGenerator(DatasetGenerator): From 2d81bc3a5b5e28a9e93f1608a95ca1120c2b3cb1 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:14:53 -0500 Subject: [PATCH 13/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 1568 +++++++++++++------------------ 1 file changed, 658 insertions(+), 910 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 44173091c..10b33a3aa 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -1,972 +1,720 @@ -"""Testing DatasetGenerator through PromptBasedDatasetGenerator.""" +"""Testing TextualizeProcessor.""" +import gc import logging -import os -import tempfile -from functools import partial +from copy import deepcopy from unittest.mock import patch import datasets import pytest -from datasets import Dataset -from prompt2model.dataset_generator.base import DatasetSplit -from prompt2model.dataset_generator.prompt_based import ( - Example, - PromptBasedDatasetGenerator, -) -from prompt2model.prompt_parser import MockPromptSpec, TaskType -from prompt2model.utils import api_tools -from test_helpers import ( - MockCompletion, - UnknownGpt3Exception, - mock_batch_api_response_identical_completions, -) -from test_helpers.mock_api import MockAPIAgent, MockBatchDifferentCompletions -from test_helpers.test_utils import temp_setattr - -logger = logging.getLogger("DatasetGenerator") - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generate_dataset(mocked_generate_example): - """Test the `generate_dataset_split()` function of `PromptBasedDatasetGenerator`.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - split = DatasetSplit.TRAIN - num_examples = 29 - # If num_examples >= max_api_calls, the returned dataset's - # length will be less than or equal to max_api_calls. - dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) - # Since each API call would return one completion object with 5 responses - # and some of the responses are invalid JSON objects, the upper bound of - # the length of the dataset is num_examples + 5, where 5 is the - # default number of responses per API call. - assert len(dataset) < num_examples + 5 - expected_columns = {"input_col", "output_col"} - assert set(dataset.column_names) == expected_columns - return dataset - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generate_dataset_dict(mocked_generate_example): - """Test the `generate_dataset_dict()` function of `PromptBasedDatasetGenerator`.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = { - DatasetSplit.TRAIN: 50, - DatasetSplit.VAL: 24, - DatasetSplit.TEST: 26, - } - dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec=prompt_spec, - num_examples=num_examples, - ) - - assert set(dataset_dict.keys()) == {"train", "val", "test"} - for split, num in num_examples.items(): - # As explained previously, the upper bound of the length of - # generated dataset is num_examples + 5, where - # 5 is the default number of responses per API call. - assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} - for dataset in dataset_dict.values(): - assert set(dataset.column_names) == expected_columns - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generator_without_filter(mocked_generate_example): - """Unlimited dataset generation using the PromptBasedDatasetGenerator.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) - dataset = dataset_generator.generate_dataset_split( - MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN - ) - assert len(dataset) == 29 - # The default responses_per_request is 5. So each API call will return - # 5 responses, i.e. 5 choices in openai.Completion.choices. - # Each API call will return 5 responses, and each response is a valid JSON. - # So the unlimited_dataset_generator will call the API 6 times. - assert dataset_generator.api_call_counter == 6 - # The default batch_size is 5. So generate_batch_completion - # will be called 2 times with first batch_size = 5 and second batch_size = 1. - assert mocked_generate_example.call_count == 2 - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generator_without_filter_dict(mocked_generate_example): - """Test generation of a dataset dict.""" - dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) - - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = { - DatasetSplit.TRAIN: 50, - DatasetSplit.VAL: 24, - DatasetSplit.TEST: 26, - } - - dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec=prompt_spec, - num_examples=num_examples, - ) - - assert set(dataset_dict.keys()) == {"train", "val", "test"} - for split, num in num_examples.items(): - # As explained previously, the upper bound of the length of - # generated dataset is num_examples + 5, where - # 5 is the default number of responses per API call. - assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} - for dataset in dataset_dict.values(): - assert set(dataset.column_names) == expected_columns - - # Each API call returns five responses. So the dataset_generator will - # call the API (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) = 21 times. - assert dataset_generator.api_call_counter == (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) - # The default batch_size is 5. So generate_batch_completion - # will be called 2 times for 50 examples in the train split, - # 1 time for 24 examples in the validation split, - # and 2 times for 26 examples in the test split. - assert mocked_generate_example.call_count == 2 + 1 + 2 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the dataset_dict will contain (50, 25, 30) examples. - assert len(dataset_dict["train"]) == 50 - assert len(dataset_dict["val"]) == 24 - assert len(dataset_dict["test"]) == 26 - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generator_max_api_calls(mocked_generate_example): - """Test generation when num_examples >= max_api_calls.""" - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=False - ) - dataset = dataset_generator.generate_dataset_split( - MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN - ) - # The max_api_calls is 3. So the limited_dataset_generator calls the - # API 3 times. Each API call returns 5 responses. So the - # limited_dataset_generator will have 3 * 5 = 15 examples. - assert len(dataset) == 15 - - # The default batch_size is 5. So generate_batch_completion - # will be called only once. - assert mocked_generate_example.call_count == 1 - - # Each API call returns 5 responses, so the limited_dataset_generator - # will use up all the available API calls. - assert dataset_generator.api_call_counter == 3 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the dataset will contain 15 examples. - assert len(dataset) == 15 - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_first_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the first batch.""" - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=2, - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples=5, - split=DatasetSplit.TRAIN, - ) +from prompt2model.dataset_processor.textualize import TextualizeProcessor +from test_helpers import create_gpt2_model_and_tokenizer, create_t5_model_and_tokenizer - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 1 - assert dataset_generator.api_call_counter == 2 +logger = logging.getLogger("DatasetProcessor") - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( +DATASET_DICTS = [ + datasets.DatasetDict( { - "input_col": ["1", "2"], - "output_col": ["a", "a"], + "train": datasets.Dataset.from_dict( + { + "input_col": ["foo", "bar"], + "explain_col": ["abc","xyz"], + "output_col": ["baz", "qux"], + } + ), + "test": datasets.Dataset.from_dict( + { + "input_col": ["foo", "bar"], + "explain_col": ["abc","xyz"], + "output_col": ["baz", "qux"], + } + ), } - ) - - # Verify the generated dataset matches the expected dataset. - assert list(generated_dataset) == list(expected_dataset) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_second_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the second batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with filter - methods in the second batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number of - API calls to 3. After running the generation process, the test checks - whether the generated dataset matches the expected result after the - second API call. The test also ensures that the number of calls to the - API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples=5, - split=DatasetSplit.TRAIN, - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 2 - assert dataset_generator.api_call_counter == 3 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( + ), + datasets.DatasetDict( { - "input_col": ["1", "2", "3"], - "output_col": ["a", "a", "a"], + "train": datasets.Dataset.from_dict( + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn","opq"], + "output_col": ["ham", "sau"], + } + ), + "val": datasets.Dataset.from_dict( + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn","opq"], + "output_col": ["ham", "sau"], + } + ), } - ) - - # Verify the generated dataset matches the expected dataset. - assert list(generated_dataset) == list(expected_dataset) - + ), +] -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_third_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the third batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with - filter methods in the third batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number - of API calls to 4. After running the generation process, the test - checks whether the generated dataset matches the expected - result after the third API call. The test also ensures that the - number of calls to the API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - The third API call's max_batch_size is 1, generating 3 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=4, - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples=5, - split=DatasetSplit.TRAIN, - ) - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 3 - assert dataset_generator.api_call_counter == 4 +INSTRUCTION = "convert to text2text" - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( +# Our support spilts are `train, val, test`. +UNEXPECTED_DATASET_DICTS_WITH_WRONG_SPLIT = [ + datasets.DatasetDict( { - "input_col": ["1", "2", "3"], - "output_col": ["b", "a", "a"], + "full": datasets.Dataset.from_dict( + {"input_col": ["foo", "bar"], "explain_col": ["abc","xyz"], "output_col": ["baz", "qux"]} + ) } - ) - - # Verify the generated dataset matches the expected dataset. - assert list(generated_dataset) == list(expected_dataset) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_forth_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the forth batch.""" - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=5, - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples=5, - split=DatasetSplit.TRAIN, - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 4 - assert dataset_generator.api_call_counter == 5 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( + ), + datasets.DatasetDict( { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], + "train": datasets.Dataset.from_dict( + {"input_col": ["spam", "eggs"], "explain_col": ["lmn","opq"], "output_col": ["ham", "sau"]} + ) } - ) + ), +] - # Verify the generated dataset matches the expected dataset. - assert list(generated_dataset) == list(expected_dataset) +# Our support columns are `input_col, output_col`. +UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + {"input_col": ["foo", "bar"], "explain_col": ["abc","xyz"], "output_col": ["baz", "qux"]} + ) + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + {"input_col": ["spam", "eggs"], "explain_col": ["lmn","opq"], "output": ["ham", "sau"]} + ) + } + ), +] -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_unlimited_api_calls(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods and unlimited API calls.""" - # Initialize the PromptBasedDatasetGenerator with - # specific settings and unlimited API calls. - dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - ) +def test_the_logging_for_provide_unnecessary_eos_token_for_t5(): + """Test the logger.info for unnecessary eos token for T5 model is logged.""" + _, t5_tokenizer = create_t5_model_and_tokenizer() - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples=5, - split=DatasetSplit.TRAIN, - ) + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + _ = TextualizeProcessor(has_encoder=True, eos_token=t5_tokenizer.eos_token) + mock_info.assert_called_once_with( + "The T5 tokenizer automatically adds eos token in the end of sequence when tokenizing. So the eos_token of encoder-decoder model tokenizer is unnecessary." # noqa E501 + ) + mock_warning.assert_not_called() + gc.collect() - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 4 - assert dataset_generator.api_call_counter == 5 - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( +def test_the_logging_for_eos_token_required_for_gpt(): + """Test the logger.warning for requiring eos token for GPT model is logged.""" + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + _ = TextualizeProcessor(has_encoder=False) + mock_info.assert_not_called() + mock_warning.assert_called_once_with( + "The autoregressive model tokenizer does not automatically add eos token in the end of the sequence. So the `eos_token` of the autoregressive model is required." # noqa E501 + ) + gc.collect() + + +def test_dataset_processor_t5_style(): + """Test the `process_dataset_dict` function of T5-type `TextualizeProcessor`.""" + t5_processor = TextualizeProcessor(has_encoder=True) + raw_dataset_dicts = deepcopy(DATASET_DICTS) + t5_modified_dataset_dicts = t5_processor.process_dataset_dict( + INSTRUCTION, DATASET_DICTS + ) + # Ensure the dataset_dicts themselves are the same after processing. + for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): + assert list(raw["train"]) == list(origin["train"]) + if "val" in raw: + assert list(raw["val"]) == list(origin["val"]) + if "test" in raw: + assert list(raw["test"]) == list(origin["test"]) + t5_expected_dataset_dicts = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + "convert to text2text\nExample:\nbar\nLabel:\n", + ], + "model_output": ["baz", "qux"], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + "convert to text2text\nExample:\nbar\nLabel:\n", + ], + "model_output": ["baz", "qux"], + } + ), + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nspam\nLabel:\n", + "convert to text2text\nExample:\neggs\nLabel:\n", + ], + "model_output": ["ham", "sau"], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nspam\nLabel:\n", + "convert to text2text\nExample:\neggs\nLabel:\n", + ], + "model_output": ["ham", "sau"], + } + ), + } + ), + ] + for exp, act in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): + assert list(exp["train"]) == list(act["train"]) + if "val" in exp: + assert list(exp["val"]) == list(act["val"]) + if "test" in exp: + assert list(exp["test"]) == list(act["test"]) + gc.collect() + + +def test_dataset_processor_with_numerical_column(): + """Test process_dataset_dict with numerical column values.""" + t5_processor = TextualizeProcessor(has_encoder=True) + raw_dataset_dicts = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "input_col": ["foo", "bar"], + "explain_col": ["abc","xyz"], + "output_col": ["baz", "qux"], + } + ), + "test": datasets.Dataset.from_dict( + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn","opq"], + "output_col": ["ham", "sau"], + } + ), + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "input_col": ["foo", "bar"], + "explain_col": ["abc","xyz"], + "output_col": [0, 1], + } + ), + "test": datasets.Dataset.from_dict( + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn","opq"], + "output_col": [1, 2], + } + ), + } + ), + ] + t5_modified_dataset_dicts = t5_processor.process_dataset_dict( + INSTRUCTION, raw_dataset_dicts + ) + expected_dataset_dict = datasets.DatasetDict( { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + "convert to text2text\nExample:\nbar\nLabel:\n", + "convert to text2text\nExample:\nfoo\nLabel:\n", + "convert to text2text\nExample:\nbar\nLabel:\n", + ], + "model_output": ["baz", "qux", "0", "1"], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nspam\nLabel:\n", + "convert to text2text\nExample:\neggs\nLabel:\n", + "convert to text2text\nExample:\nspam\nLabel:\n", + "convert to text2text\nExample:\neggs\nLabel:\n", + ], + "model_output": ["ham", "sau", "1", "2"], + } + ), } ) - - # Verify the generated dataset matches the expected dataset. - assert list(generated_dataset) == list(expected_dataset) + training_datasets = [] + test_datasets = [] + for modified_dataset_dict in t5_modified_dataset_dicts: + training_datasets.append(modified_dataset_dict["train"]) + test_datasets.append(modified_dataset_dict["test"]) + + concatenated_training_dataset = datasets.concatenate_datasets(training_datasets) + concatenated_test_dataset = datasets.concatenate_datasets(test_datasets) + actual_dataset_dict = datasets.DatasetDict( + {"train": concatenated_training_dataset, "test": concatenated_test_dataset} + ) + assert list(expected_dataset_dict["train"]) == list(actual_dataset_dict["train"]) + assert list(expected_dataset_dict["test"]) == list(actual_dataset_dict["test"]) + + +def test_dataset_processor_decoder_only_style(): + """Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`.""" + _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() + gpt_processor = TextualizeProcessor( + has_encoder=False, eos_token=gpt2_tokenizer.eos_token + ) + raw_dataset_dicts = deepcopy(DATASET_DICTS) + gpt_modified_dataset_dicts = gpt_processor.process_dataset_dict( + INSTRUCTION, DATASET_DICTS + ) + # Ensure the dataset_dicts themselves are the same after processing. + for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): + assert list(raw["train"]) == list(origin["train"]) + if "val" in raw: + assert list(raw["val"]) == list(origin["val"]) + if "test" in raw: + assert list(raw["test"]) == list(origin["test"]) + # Check that the modified dataset dicts have the expected content + gpt_expected_dataset_dicts = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\nbaz<|endoftext|>", # noqa: E501 + "convert to text2text\nExample:\nbar\nLabel:\nqux<|endoftext|>", # noqa: E501 + ], + "model_output": ["baz<|endoftext|>", "qux<|endoftext|>"], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + "convert to text2text\nExample:\nbar\nLabel:\n", + ], + "model_output": ["baz", "qux"], + } + ), + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nspam\nLabel:\nham<|endoftext|>", # noqa: E501 + "convert to text2text\nExample:\neggs\nLabel:\nsau<|endoftext|>", # noqa: E501 + ], + "model_output": ["ham<|endoftext|>", "sau<|endoftext|>"], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nspam\nLabel:\n", + "convert to text2text\nExample:\neggs\nLabel:\n", + ], + "model_output": ["ham", "sau"], + } + ), + } + ), + ] + for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) + + +def test_unexpected_dataset_split(): + """Test the error handler for unexpercted dataset split.""" + with pytest.raises(ValueError) as exc_info: + _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() + gpt_processor = TextualizeProcessor( + has_encoder=False, eos_token=gpt2_tokenizer.eos_token + ) + _ = gpt_processor.process_dataset_dict( + INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_SPLIT + ) + assert str(exc_info.value) == ("Datset split must be in train/val/test.") + gc.collect() -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions(length=5).mock_completions, -) -def test_generator_with_filter_to_generate_datasetdict(mocked_generate_example): - """Test with filter methods to generate a DatasetDict.""" - # Initialize the PromptBasedDatasetGenerator with - # specific settings and limited API calls. - dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, - max_batch_size=2, - responses_per_request=3, - max_api_calls=7, - ) - - # Generate the DatasetDict using the initialized generator. - generated_dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), - num_examples={ - DatasetSplit.TRAIN: 4, - DatasetSplit.VAL: 4, - DatasetSplit.TEST: 2, - }, - ) +def test_unexpected_columns(): + """Test the error handler for unexpercted dataset columns.""" + with pytest.raises(ValueError) as exc_info: + _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() + gpt_processor = TextualizeProcessor( + has_encoder=False, eos_token=gpt2_tokenizer.eos_token + ) + _ = gpt_processor.process_dataset_dict( + INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS + ) + assert str(exc_info.value) == ( + "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." + ) + gc.collect() - # Assertions for API call count and dataset - # dictionaries matching the expected results. - assert mocked_generate_example.call_count == 5 - assert dataset_generator.api_call_counter == 7 - # Define the expected dataset dictionaries - # based on the given mock responses. - expected_dataset_dict = datasets.DatasetDict( +DATASET_DICTS_WITH_EMPTY_COLUMNS = [ + datasets.DatasetDict( { - "train": Dataset.from_dict( + "train": datasets.Dataset.from_dict( { - "input_col": ["1", "2", "3", "4"], - "output_col": ["b", "a", "a", "c"], + "input_col": ["foo", "", "test"], + "explain_col": ["abc","","xyz"], + "output_col": ["", "qux", "key"], } ), - "val": Dataset.from_dict( + "test": datasets.Dataset.from_dict( { - "input_col": ["1", "2"], - "output_col": ["a", "a"], + "input_col": ["foo", ""], + "explain_col": ["abc",""], + "output_col": ["baz", "qux"], } ), - "test": Dataset.from_dict( + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( { - "input_col": [], - "output_col": [], + "input_col": ["", ""], + "explain_col": ["abc","xyz"], + "output_col": ["ham", "sau"], } ), } - ) - - # Verify the generated DatasetDict matches the expected DatasetDict. - assert list(generated_dataset_dict["train"]) == list(expected_dataset_dict["train"]) - assert list(generated_dataset_dict["val"]) == list(expected_dataset_dict["val"]) - assert list(generated_dataset_dict["test"]) == list(expected_dataset_dict["test"]) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generator_max_api_calls_dict(mocked_generate_example): - """Test generation of a dataset dict where we hit max api calls.""" - # Refresh the call_count and create a new limited_dataset_generator. - dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, - max_api_calls=13, - ) - - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = { - DatasetSplit.TRAIN: 50, - DatasetSplit.VAL: 24, - DatasetSplit.TEST: 26, + ), +] + + +def test_empty_filter_t5_type(): + """Test that examples with empty input_col or output_col are discarded.""" + t5_processor = TextualizeProcessor(has_encoder=True) + t5_modified_dataset_dicts = t5_processor.process_dataset_dict( + INSTRUCTION, DATASET_DICTS_WITH_EMPTY_COLUMNS + ) + t5_expected_dataset_dicts = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\ntest\nLabel:\n", + ], + "model_output": ["key"], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + ], + "model_output": [ + "baz", + ], + } + ), + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [], + "model_output": [], + } + ), + } + ), + ] + for exp, modified in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) + + +def test_empty_filter_decoder_only_style(): + """Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`.""" + _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() + gpt_processor = TextualizeProcessor( + has_encoder=False, eos_token=gpt2_tokenizer.eos_token + ) + gpt_modified_dataset_dicts = gpt_processor.process_dataset_dict( + INSTRUCTION, DATASET_DICTS_WITH_EMPTY_COLUMNS + ) + + # Check that the modified dataset dicts have the expected content + gpt_expected_dataset_dicts = [ + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\ntest\nLabel:\nkey<|endoftext|>", # noqa: E501 + ], + "model_output": ["key<|endoftext|>"], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + "convert to text2text\nExample:\nfoo\nLabel:\n", + ], + "model_output": ["baz"], + } + ), + } + ), + datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [], + "model_output": [], + } + ), + } + ), + ] + for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) + gc.collect() + + +GENERATED_DATASET = datasets.Dataset.from_dict( + { + "input_col": list(range(10000)), + "explain_col": ['a'] * 10000, + "output_col": list(range(10000, 20000)), } - - dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec=prompt_spec, - num_examples=num_examples, - ) - - # Since the max_api_calls is 13, the limited_dataset_generator cannot - # generate the whole dataset_dict and will call the API 13 times. - assert dataset_generator.api_call_counter == 13 - - # The train split has 50 examples, so it will call the API 10 times and call - # generate_batch_completion 2 times. - # The validation split has 24 examples, but there are only 3 API calls - # left, so it will call the API 3 times and call - # generate_batch_completion 1 time. - # The test split has 26 examples, but there are no more API calls left, - # so it will not call generate_batch_completion. - assert mocked_generate_example.call_count == 2 + 1 + 0 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the generated_dataset_dict will contain (50, 15, 0) examples. - assert len(dataset_dict["train"]) == 50 - assert len(dataset_dict["val"]) == 15 - assert len(dataset_dict["test"]) == 0 - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_WRONG_KEY_EXAMPLE, ) -def test_wrong_key_example(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns wrong keys.""" - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=False - ) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = 1 - split = DatasetSplit.TRAIN - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, num_examples, split - ) - assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) - assert list(expected_dataset) == list(generated_dataset) - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_INVALID_JSON, -) -def test_invalid_json_response(mocked_generate_example): - """Test when the agent returns invalid JSON responses.""" - # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. - dataset_generator = PromptBasedDatasetGenerator(3, filter_duplicated_examples=False) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = 1 - split = DatasetSplit.VAL - dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) - assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) - assert list(dataset) == list(expected_dataset) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=UnknownGpt3Exception(), +RETRIEVED_TRAIN_DATASET = datasets.Dataset.from_dict( + { + "input_col": list(range(20000, 30000)), + "explain_col": ['a'] * 10000, + "output_col": list(range(30000, 40000)), + } ) -def test_unexpected_examples_of_gpt(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns unexpected examples.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with pytest.raises(UnknownGpt3Exception): - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=False - ) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - num_examples = 1 - split = DatasetSplit.TEST - _ = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) - assert mocked_generate_example.call_count == 1 - - -def test_filter_with_duplicate_inputs_unique_outputs(): - """Test filtering with duplicate inputs, unique outputs.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="E"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - ] - filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) - expected_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - assert sorted(expected_examples) == sorted(filtered_examples) - - -def test_filter_duplicate_inputs_duplicate_outputs(): - """Test constructing a map with duplicate inputs and duplicate outputs.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="C"), - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="G"), - Example(input_col="apple", output_col="A"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="F"), - ] - filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) - expected_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - assert expected_examples == filtered_examples - - -def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): - """Test constructing a map with unique inputs and outputs.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) - assert generated_examples == filtered_examples - - -def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): - """Test constructing a map with empty inputs and outputs.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) - generated_examples = [] - filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) - assert generated_examples == filtered_examples - - -def test_compute_batch_size_with_limited_max_api_calls(): - """Test the batch size computation with limited max API calls.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(max_api_calls=28) - data_generator.api_call_counter = 26 - # Default batch size and responses_per_request are both 5. - # So each batch should contain 25 examples. - - # At least (125 - 110) / 5 = 3 API calls needed to get - # more than 125 examples. - - batch_size = data_generator.compute_batch_size( - num_examples=125, generated_dataset_size=110 - ) - assert ( - batch_size - == data_generator.max_api_calls - data_generator.api_call_counter - == 28 - 26 - ) - - data_generator.api_call_counter = 20 - batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) - assert ( - batch_size - == (125 - 110) / data_generator.responses_per_request - == (125 - 110) / 5 - ) - data_generator.api_call_counter = 0 - batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) - assert batch_size == data_generator.max_batch_size +DATASET_LIST = [GENERATED_DATASET, RETRIEVED_TRAIN_DATASET] -def test_compute_batch_size_with_unlimited_max_api_calls(): - """Test the batch size computation with unlimited max API calls.""" - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator() - # Default batch size and responses_per_request are both 5. - # So each batch should contain 25 examples. - - # At least (125 - 110) / 5 = 3 API calls needed to get - # more than 125 examples. - - batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) - assert ( - batch_size - == (125 - 110) / data_generator.responses_per_request - == (125 - 110) / 5 +def test_raise_value_error_of_process_dataset_lists(): + """Test that the ValueError is correctly raised.""" + _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() + gpt_processor = TextualizeProcessor( + has_encoder=False, eos_token=gpt2_tokenizer.eos_token ) - - batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) - assert batch_size == data_generator.max_batch_size == 5 - - -def test_extract_responses(): - """Test the extract_responses function of DatasetGenerator.""" - mock_completion_1 = MockCompletion() - mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, - ] - mock_completion_2 = MockCompletion() - mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - # Note that the following choice miss the right quote of JSON. - # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, - ] - mock_completion_3 = MockCompletion() - mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, - ] - # choices should be list of dicts. So mock_completion_4 - # is invalid. Which will be discarded and log a warning. - mock_completion_4 = MockCompletion() - mock_completion_4.choices = None - - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) - generated_examples = [] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - data_generator.extract_and_append_responses( - [mock_completion_1, mock_completion_2], generated_examples - ) - mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 - ) - # There are 5 valid examples. Each input - # and output will be logged once as info. - assert mock_info.call_count == 5 * 2 - - # The second choice in mock_completion_2 - # is invalid. So it should be discarded. - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - ] - data_generator.extract_and_append_responses([mock_completion_3], generated_examples) - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - data_generator.extract_and_append_responses( - [mock_completion_4], generated_examples - ) - mock_warning.assert_called_once_with( - "Error happened when parsing API completion: " - ) - mock_info.assert_not_called() - # The generated_examples should be the same. - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - - -def test_extract_some_empty_responses(): - """Test the extract_responses function correctly handle empty responses.""" - mock_completion_1 = MockCompletion() - mock_completion_1.choices = [ - # Note that this choice's input is empty. So it should be discarded. - {"message": {"content": '{"input": "", "output": "a"}'}}, - {"message": {"content": '{"input": "5", "output": "b"}'}}, - # Note that this choice's output is empty. So it should be discarded. - {"message": {"content": '{"input": "1", "output": ""}'}}, - ] - mock_completion_2 = MockCompletion() - mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - # Note that the following choice misses the right quote of JSON. - # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, - ] - mock_completion_3 = MockCompletion() - mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, - ] - # choices should be list of dicts. So mock_completion_4 - # is invalid. Which will be discarded and log a warning. - mock_completion_4 = MockCompletion() - mock_completion_4.choices = None - - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - generated_examples = [] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - data_generator.extract_and_append_responses( - [mock_completion_1, mock_completion_2], generated_examples - ) - mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 - ) - # There are 3 valid examples in [mock_completion_1, - # mock_completion_2] Each input - # and output will be logged once as info. - # And there are 2 examples with empty - # input or output, which should be discarded - # and be logged as info. - assert mock_info.call_count == 3 * 2 + 2 - - # The second choice in mock_completion_2 - # is invalid. So it should be discarded. - assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - ] - data_generator.extract_and_append_responses( - [mock_completion_3], generated_examples - ) - assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - data_generator.extract_and_append_responses( - [mock_completion_4], generated_examples - ) - mock_warning.assert_called_once_with( - "Error happened when parsing API completion: " - ) - mock_info.assert_not_called() - # The generated_examples should be the same. - assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - - -def test_initialize_dataset_generator_with_dynamic_temperature(): - """Test the correct initialization of the dynamic temperature strategy.""" - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator( - cache_root=cache_dir, initial_temperature=-0.2 - ) + with pytest.raises(ValueError) as exc_info: + gpt_processor.process_dataset_lists(INSTRUCTION, DATASET_LIST, 0.8, 0.2) error_info = exc_info.value.args[0] assert ( error_info - == "initial_temperature must be >= 0, but self.initial_temperature=-0.2" + == "train_proportion 0.8 + val_proportion 0.2 must be less than 1." ) - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator(cache_root=cache_dir, max_temperature=2.3) - error_info = exc_info.value.args[0] - assert ( - error_info - == "max_temperature must be <= 2,0, but self.max_temperature=2.3" - ) - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator( - cache_root=cache_dir, max_temperature=1.2, initial_temperature=1.5 - ) - error_info = exc_info.value.args[0] - assert ( - error_info - == "self.initial_temperature=1.5 must be <= self.max_temperature=1.2" - ) + t5_processor = TextualizeProcessor(has_encoder=True) + with pytest.raises(ValueError) as exc_info: + t5_processor.process_dataset_lists(INSTRUCTION, DATASET_LIST, 0.8, 0.2) + error_info = exc_info.value.args[0] + assert ( + error_info + == "train_proportion 0.8 + val_proportion 0.2 must be less than 1." + ) -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_dataset_generator_terminates(mocked_generate_example): - """Check to make sure that the dataset generator terminates.""" - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - dataset_generator = PromptBasedDatasetGenerator( - initial_temperature=0.3, - max_temperature=1.4, - responses_per_request=3, - max_api_calls=10000, - requests_per_minute=80, - filter_duplicated_examples=False, +def test_process_dataset_lists(): + """Test the `process_dataset_lists` function.""" + processor = TextualizeProcessor(has_encoder=True) + modified_dataset_dicts = processor.process_dataset_lists( + INSTRUCTION, DATASET_LIST, 0.6, 0.2 ) - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, 100, split=DatasetSplit.TRAIN + expected_modified_generated_dataset_dict = datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(6000) + ], + "model_output": [f"{output}" for output in range(10000, 16000)], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(6000, 8000) + ], + "model_output": [f"{output}" for output in range(16000, 18000)], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(8000, 10000) + ], + "model_output": [f"{output}" for output in range(18000, 20000)], + } + ), + } ) - generated_df = generated_dataset.to_pandas() - assert len(generated_dataset) == 100 - assert list(generated_df.columns) == ["input_col", "output_col"] - - -def test_generate_dataset_agent_switch(): - """Test if dataset generation can use a user-set API agent.""" - my_agent = MockAPIAgent( - default_content='{"input": "This is input.", "output": "This is an output."}' + expected_modified_retrieved_dataset_dict = datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(20000, 26000) + ], + "model_output": [f"{output}" for output in range(30000, 36000)], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(26000, 28000) + ], + "model_output": [f"{output}" for output in range(36000, 38000)], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(28000, 30000) + ], + "model_output": [f"{output}" for output in range(38000, 40000)], + } + ), + } ) - with temp_setattr(api_tools, "default_api_agent", my_agent): - prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION) - dataset_generator = PromptBasedDatasetGenerator( - initial_temperature=0.3, - max_temperature=1.4, - responses_per_request=1, - max_api_calls=100, - requests_per_minute=80, - filter_duplicated_examples=False, - ) - dataset_generator.generate_dataset_split( - prompt_spec, 100, split=DatasetSplit.TRAIN - ) - # 100 outputs, and each batch has 5 outputs so 20 api calls - assert my_agent.generate_batch_call_counter == 20 + for exp, modified in zip( + [ + expected_modified_generated_dataset_dict, + expected_modified_retrieved_dataset_dict, + ], + modified_dataset_dicts, + ): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) + + +def test_process_dataset_lists_with_maximum_example_num(): + """Test the maximum_example_num parameter.""" + processor = TextualizeProcessor(has_encoder=True) + modified_dataset_dicts = processor.process_dataset_lists( + INSTRUCTION, DATASET_LIST, 0.6, 0.2, {"train": 3000, "val": 500, "test": 1000} + ) + # Before applying the maximum_example_num, train_num = 6000, + # val_num = 2000, test_num = 2000. + # After applying the maximum_example_num, train_num = 3000, + # val_num = 2000, test_num = 2000. + expected_modified_generated_dataset_dict = datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(3000) + ], + "model_output": [f"{output}" for output in range(10000, 13000)], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(3000, 3500) + ], + "model_output": [f"{output}" for output in range(13000, 13500)], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(3500, 4500) + ], + "model_output": [f"{output}" for output in range(13500, 14500)], + } + ), + } + ) + expected_modified_retrieved_dataset_dict = datasets.DatasetDict( + { + "train": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(20000, 23000) + ], + "model_output": [f"{output}" for output in range(30000, 33000)], + } + ), + "val": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(23000, 23500) + ], + "model_output": [f"{output}" for output in range(33000, 33500)], + } + ), + "test": datasets.Dataset.from_dict( + { + "model_input": [ + f"convert to text2text\nExample:\n{input}\nLabel:\n" + for input in range(23500, 24500) + ], + "model_output": [f"{output}" for output in range(33500, 34500)], + } + ), + } + ) + for exp, modified in zip( + [ + expected_modified_generated_dataset_dict, + expected_modified_retrieved_dataset_dict, + ], + modified_dataset_dicts, + ): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) From 16fdc0eb6d190e629ae927b768b070ceb9cb7e6e Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:20:54 -0500 Subject: [PATCH 14/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 10b33a3aa..c9825c46a 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -19,14 +19,14 @@ "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc","xyz"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc","xyz"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), @@ -37,14 +37,14 @@ "train": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn","opq"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), "val": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn","opq"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -60,14 +60,14 @@ datasets.DatasetDict( { "full": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "explain_col": ["abc","xyz"], "output_col": ["baz", "qux"]} + {"input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"]} # noqa E501 ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "explain_col": ["lmn","opq"], "output_col": ["ham", "sau"]} + {"input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"]} # noqa E501 ) } ), @@ -78,14 +78,14 @@ datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "explain_col": ["abc","xyz"], "output_col": ["baz", "qux"]} + {"input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"]} # noqa E501 ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "explain_col": ["lmn","opq"], "output": ["ham", "sau"]} + {"input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], "output": ["ham", "sau"]} # noqa E501 ) } ), @@ -198,14 +198,14 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc","xyz"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn","opq"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -216,14 +216,14 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc","xyz"], + "explain_col": ["abc", "xyz"], "output_col": [0, 1], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn","opq"], + "explain_col": ["lmn", "opq"], "output_col": [1, 2], } ), @@ -371,7 +371,7 @@ def test_unexpected_columns(): INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS ) assert str(exc_info.value) == ( - "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." + "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 ) gc.collect() @@ -382,14 +382,14 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "", "test"], - "explain_col": ["abc","","xyz"], + "explain_col": ["abc", "", "xyz"], "output_col": ["", "qux", "key"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", ""], - "explain_col": ["abc",""], + "explain_col": ["abc", ""], "output_col": ["baz", "qux"], } ), @@ -400,7 +400,7 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["", ""], - "explain_col": ["abc","xyz"], + "explain_col": ["abc", "xyz"], "output_col": ["ham", "sau"], } ), From d48d252b0226839036e685db9657741c110d8515 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:25:49 -0500 Subject: [PATCH 15/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index c9825c46a..d2125024d 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -60,14 +60,22 @@ datasets.DatasetDict( { "full": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"]} # noqa E501 + { + "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], + "output_col": ["baz", "qux"] + } ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"]} # noqa E501 + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], + "output_col": ["ham", "sau"] + } ) } ), @@ -78,14 +86,22 @@ datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"]} # noqa E501 + { + "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], + "output_col": ["baz", "qux"] + } ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], "output": ["ham", "sau"]} # noqa E501 + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], + "output": ["ham", "sau"] + } ) } ), From 28aadad928b417c4a87e0d2bd7e81aa36fcca328 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:28:00 -0500 Subject: [PATCH 16/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index d2125024d..04c54a058 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -387,7 +387,7 @@ def test_unexpected_columns(): INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS ) assert str(exc_info.value) == ( - "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 + "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 ) gc.collect() @@ -528,7 +528,7 @@ def test_empty_filter_decoder_only_style(): GENERATED_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(10000)), - "explain_col": ['a'] * 10000, + "explain_col": ["a"] * 10000, "output_col": list(range(10000, 20000)), } ) @@ -536,7 +536,7 @@ def test_empty_filter_decoder_only_style(): RETRIEVED_TRAIN_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(20000, 30000)), - "explain_col": ['a'] * 10000, + "explain_col": ["a"] * 10000, "output_col": list(range(30000, 40000)), } ) From 49caf602d54b384d7066da37807dcf5e24be6043 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:29:59 -0500 Subject: [PATCH 17/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 04c54a058..cb49b009d 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -63,8 +63,8 @@ { "input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"] - } + "output_col": ["baz", "qux"], + } ) } ), @@ -74,7 +74,7 @@ { "input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"] + "output_col": ["ham", "sau"], } ) } @@ -89,7 +89,7 @@ { "input_col": ["foo", "bar"], "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"] + "output_col": ["baz", "qux"], } ) } @@ -100,7 +100,7 @@ { "input_col": ["spam", "eggs"], "explain_col": ["lmn", "opq"], - "output": ["ham", "sau"] + "output": ["ham", "sau"], } ) } From a4eebb8fadbf44acdf4ec318be18dac39d21b444 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 07:50:38 -0500 Subject: [PATCH 18/26] Update dataset_processor_test.py --- tests/dataset_processor_test.py | 39 ++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/tests/dataset_processor_test.py b/tests/dataset_processor_test.py index ba7ec29d3..cb49b009d 100644 --- a/tests/dataset_processor_test.py +++ b/tests/dataset_processor_test.py @@ -19,12 +19,14 @@ "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), @@ -35,12 +37,14 @@ "train": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), "val": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -56,14 +60,22 @@ datasets.DatasetDict( { "full": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "output_col": ["baz", "qux"]} + { + "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], + "output_col": ["baz", "qux"], + } ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "output_col": ["ham", "sau"]} + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], + "output_col": ["ham", "sau"], + } ) } ), @@ -74,14 +86,22 @@ datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["foo", "bar"], "output_col": ["baz", "qux"]} + { + "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], + "output_col": ["baz", "qux"], + } ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - {"input_col": ["spam", "eggs"], "output": ["ham", "sau"]} + { + "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], + "output": ["ham", "sau"], + } ) } ), @@ -194,12 +214,14 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -210,12 +232,14 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], + "explain_col": ["abc", "xyz"], "output_col": [0, 1], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], + "explain_col": ["lmn", "opq"], "output_col": [1, 2], } ), @@ -363,7 +387,7 @@ def test_unexpected_columns(): INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS ) assert str(exc_info.value) == ( - "Example dictionary must have 'input_col' and 'output_col' keys." + "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 ) gc.collect() @@ -374,12 +398,14 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "", "test"], + "explain_col": ["abc", "", "xyz"], "output_col": ["", "qux", "key"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", ""], + "explain_col": ["abc", ""], "output_col": ["baz", "qux"], } ), @@ -390,6 +416,7 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["", ""], + "explain_col": ["abc", "xyz"], "output_col": ["ham", "sau"], } ), @@ -501,6 +528,7 @@ def test_empty_filter_decoder_only_style(): GENERATED_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(10000)), + "explain_col": ["a"] * 10000, "output_col": list(range(10000, 20000)), } ) @@ -508,6 +536,7 @@ def test_empty_filter_decoder_only_style(): RETRIEVED_TRAIN_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(20000, 30000)), + "explain_col": ["a"] * 10000, "output_col": list(range(30000, 40000)), } ) From 1b0ffc7c6cb27346654c4d0ddcc2ff45d23685b2 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 08:07:49 -0500 Subject: [PATCH 19/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 1592 ++++++++++++++++++------------- 1 file changed, 918 insertions(+), 674 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index cb49b009d..14d077a6e 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -1,736 +1,980 @@ -"""Testing TextualizeProcessor.""" +"""Testing DatasetGenerator through PromptBasedDatasetGenerator.""" -import gc import logging -from copy import deepcopy +import os +import tempfile +from functools import partial from unittest.mock import patch import datasets import pytest +from datasets import Dataset -from prompt2model.dataset_processor.textualize import TextualizeProcessor -from test_helpers import create_gpt2_model_and_tokenizer, create_t5_model_and_tokenizer +from prompt2model.dataset_generator.base import DatasetSplit +from prompt2model.dataset_generator.prompt_based import ( + Example, + PromptBasedDatasetGenerator, +) +from prompt2model.prompt_parser import MockPromptSpec, TaskType +from prompt2model.utils import api_tools +from test_helpers import ( + MockCompletion, + UnknownGpt3Exception, + mock_batch_api_response_identical_completions, +) +from test_helpers.mock_api import MockAPIAgent, MockBatchDifferentCompletions +from test_helpers.test_utils import temp_setattr -logger = logging.getLogger("DatasetProcessor") +logger = logging.getLogger("DatasetGenerator") -DATASET_DICTS = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } - ), - "test": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"], - } - ), - "val": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"], - } - ), - } - ), -] +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1}', +) + +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1}', +) + +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1}', +) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generate_dataset(mocked_generate_example): + """Test the `generate_dataset_split()` function of `PromptBasedDatasetGenerator`.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + split = DatasetSplit.TRAIN + num_examples = 29 + # If num_examples >= max_api_calls, the returned dataset's + # length will be less than or equal to max_api_calls. + dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + # Since each API call would return one completion object with 5 responses + # and some of the responses are invalid JSON objects, the upper bound of + # the length of the dataset is num_examples + 5, where 5 is the + # default number of responses per API call. + assert len(dataset) < num_examples + 5 + expected_columns = {"input_col", "output_col"} + assert set(dataset.column_names) == expected_columns + return dataset + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generate_dataset_dict(mocked_generate_example): + """Test the `generate_dataset_dict()` function of `PromptBasedDatasetGenerator`.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, + } + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + assert set(dataset_dict.keys()) == {"train", "val", "test"} + for split, num in num_examples.items(): + # As explained previously, the upper bound of the length of + # generated dataset is num_examples + 5, where + # 5 is the default number of responses per API call. + assert len(dataset_dict[split.value]) < num + 5 + expected_columns = {"input_col", "output_col"} + for dataset in dataset_dict.values(): + assert set(dataset.column_names) == expected_columns + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_without_filter(mocked_generate_example): + """Unlimited dataset generation using the PromptBasedDatasetGenerator.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + dataset = dataset_generator.generate_dataset_split( + MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN + ) + assert len(dataset) == 29 + # The default responses_per_request is 5. So each API call will return + # 5 responses, i.e. 5 choices in openai.Completion.choices. + # Each API call will return 5 responses, and each response is a valid JSON. + # So the unlimited_dataset_generator will call the API 6 times. + assert dataset_generator.api_call_counter == 6 + # The default batch_size is 5. So generate_batch_completion + # will be called 2 times with first batch_size = 5 and second batch_size = 1. + assert mocked_generate_example.call_count == 2 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_without_filter_dict(mocked_generate_example): + """Test generation of a dataset dict.""" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, + } + + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + assert set(dataset_dict.keys()) == {"train", "val", "test"} + for split, num in num_examples.items(): + # As explained previously, the upper bound of the length of + # generated dataset is num_examples + 5, where + # 5 is the default number of responses per API call. + assert len(dataset_dict[split.value]) < num + 5 + expected_columns = {"input_col", "explain_col", "output_col"} + for dataset in dataset_dict.values(): + assert set(dataset.column_names) == expected_columns + + # Each API call returns five responses. So the dataset_generator will + # call the API (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) = 21 times. + assert dataset_generator.api_call_counter == (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) + # The default batch_size is 5. So generate_batch_completion + # will be called 2 times for 50 examples in the train split, + # 1 time for 24 examples in the validation split, + # and 2 times for 26 examples in the test split. + assert mocked_generate_example.call_count == 2 + 1 + 2 + + # Each API call returns 5 responses, and each response is a valid JSON. + # So the dataset_dict will contain (50, 25, 30) examples. + assert len(dataset_dict["train"]) == 50 + assert len(dataset_dict["val"]) == 24 + assert len(dataset_dict["test"]) == 26 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_max_api_calls(mocked_generate_example): + """Test generation when num_examples >= max_api_calls.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + dataset = dataset_generator.generate_dataset_split( + MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN + ) + # The max_api_calls is 3. So the limited_dataset_generator calls the + # API 3 times. Each API call returns 5 responses. So the + # limited_dataset_generator will have 3 * 5 = 15 examples. + assert len(dataset) == 15 + + # The default batch_size is 5. So generate_batch_completion + # will be called only once. + assert mocked_generate_example.call_count == 1 + + # Each API call returns 5 responses, so the limited_dataset_generator + # will use up all the available API calls. + assert dataset_generator.api_call_counter == 3 + # Each API call returns 5 responses, and each response is a valid JSON. + # So the dataset will contain 15 examples. + assert len(dataset) == 15 -INSTRUCTION = "convert to text2text" -# Our support spilts are `train, val, test`. -UNEXPECTED_DATASET_DICTS_WITH_WRONG_SPLIT = [ - datasets.DatasetDict( +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_first_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the first batch.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=2, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 1 + assert dataset_generator.api_call_counter == 2 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( { - "full": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } - ) + "input_col": ["1", "2"], + "explain_col": ["x", "x"], + "output_col": ["a", "a"], } - ), - datasets.DatasetDict( + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_second_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the second batch. + + This test verifies the behavior of the PromptBasedDatasetGenerator with filter + methods in the second batch of API calls. It initializes an + PromptBasedDatasetGenerator with specific settings, limiting the number of + API calls to 3. After running the generation process, the test checks + whether the generated dataset matches the expected result after the + second API call. The test also ensures that the number of calls to the + API mock matches the expected number. + + Note: The first API call's max_batch_size is 2, generating 6 responses. + The second API call's max_batch_size is 1, generating 3 responses. + + Args: + mocked_generate_example (MagicMock): The patched function representing the + @patch decorator for generating example responses. + """ + # Initialize the PromptBasedDatasetGenerator with specific settings. + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 2 + assert dataset_generator.api_call_counter == 3 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( { - "train": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"], - } - ) + "input_col": ["1", "2", "3"], + "explain_col": ["x", "x", "x"], + "output_col": ["a", "a", "a"], } - ), -] + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + -# Our support columns are `input_col, output_col`. -UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS = [ - datasets.DatasetDict( +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_third_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the third batch. + + This test verifies the behavior of the PromptBasedDatasetGenerator with + filter methods in the third batch of API calls. It initializes an + PromptBasedDatasetGenerator with specific settings, limiting the number + of API calls to 4. After running the generation process, the test + checks whether the generated dataset matches the expected + result after the third API call. The test also ensures that the + number of calls to the API mock matches the expected number. + + Note: The first API call's max_batch_size is 2, generating 6 responses. + The second API call's max_batch_size is 1, generating 3 responses. + The third API call's max_batch_size is 1, generating 3 responses. + + Args: + mocked_generate_example (MagicMock): The patched function representing the + @patch decorator for generating example responses. + """ + # Initialize the PromptBasedDatasetGenerator with specific settings. + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=4, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 3 + assert dataset_generator.api_call_counter == 4 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( { - "train": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } - ) + "input_col": ["1", "2", "3"], + "explain_col": ["x", "x", "x"], + "output_col": ["b", "a", "a"], } - ), - datasets.DatasetDict( + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_forth_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the forth batch.""" + # Initialize the PromptBasedDatasetGenerator with specific settings. + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=5, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 4 + assert dataset_generator.api_call_counter == 5 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( { - "train": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output": ["ham", "sau"], - } - ) + "input_col": ["1", "2", "3", "4", "5"], + "explain_col": ["x", "x", "x", "x", "x"], + "output_col": ["b", "a", "a", "c", "a"], } - ), -] + ) + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) -def test_the_logging_for_provide_unnecessary_eos_token_for_t5(): - """Test the logger.info for unnecessary eos token for T5 model is logged.""" - _, t5_tokenizer = create_t5_model_and_tokenizer() - - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - _ = TextualizeProcessor(has_encoder=True, eos_token=t5_tokenizer.eos_token) - mock_info.assert_called_once_with( - "The T5 tokenizer automatically adds eos token in the end of sequence when tokenizing. So the eos_token of encoder-decoder model tokenizer is unnecessary." # noqa E501 - ) - mock_warning.assert_not_called() - gc.collect() +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_unlimited_api_calls(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods and unlimited API calls.""" + # Initialize the PromptBasedDatasetGenerator with + # specific settings and unlimited API calls. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) -def test_the_logging_for_eos_token_required_for_gpt(): - """Test the logger.warning for requiring eos token for GPT model is logged.""" - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - _ = TextualizeProcessor(has_encoder=False) - mock_info.assert_not_called() - mock_warning.assert_called_once_with( - "The autoregressive model tokenizer does not automatically add eos token in the end of the sequence. So the `eos_token` of the autoregressive model is required." # noqa E501 - ) - gc.collect() - - -def test_dataset_processor_t5_style(): - """Test the `process_dataset_dict` function of T5-type `TextualizeProcessor`.""" - t5_processor = TextualizeProcessor(has_encoder=True) - raw_dataset_dicts = deepcopy(DATASET_DICTS) - t5_modified_dataset_dicts = t5_processor.process_dataset_dict( - INSTRUCTION, DATASET_DICTS - ) - # Ensure the dataset_dicts themselves are the same after processing. - for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): - assert list(raw["train"]) == list(origin["train"]) - if "val" in raw: - assert list(raw["val"]) == list(origin["val"]) - if "test" in raw: - assert list(raw["test"]) == list(origin["test"]) - t5_expected_dataset_dicts = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - "convert to text2text\nExample:\nbar\nLabel:\n", - ], - "model_output": ["baz", "qux"], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - "convert to text2text\nExample:\nbar\nLabel:\n", - ], - "model_output": ["baz", "qux"], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nspam\nLabel:\n", - "convert to text2text\nExample:\neggs\nLabel:\n", - ], - "model_output": ["ham", "sau"], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nspam\nLabel:\n", - "convert to text2text\nExample:\neggs\nLabel:\n", - ], - "model_output": ["ham", "sau"], - } - ), - } - ), - ] - for exp, act in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): - assert list(exp["train"]) == list(act["train"]) - if "val" in exp: - assert list(exp["val"]) == list(act["val"]) - if "test" in exp: - assert list(exp["test"]) == list(act["test"]) - gc.collect() - - -def test_dataset_processor_with_numerical_column(): - """Test process_dataset_dict with numerical column values.""" - t5_processor = TextualizeProcessor(has_encoder=True) - raw_dataset_dicts = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } - ), - "test": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": [0, 1], - } - ), - "test": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": [1, 2], - } - ), - } - ), - ] - t5_modified_dataset_dicts = t5_processor.process_dataset_dict( - INSTRUCTION, raw_dataset_dicts + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, ) - expected_dataset_dict = datasets.DatasetDict( + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 4 + assert dataset_generator.api_call_counter == 5 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - "convert to text2text\nExample:\nbar\nLabel:\n", - "convert to text2text\nExample:\nfoo\nLabel:\n", - "convert to text2text\nExample:\nbar\nLabel:\n", - ], - "model_output": ["baz", "qux", "0", "1"], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nspam\nLabel:\n", - "convert to text2text\nExample:\neggs\nLabel:\n", - "convert to text2text\nExample:\nspam\nLabel:\n", - "convert to text2text\nExample:\neggs\nLabel:\n", - ], - "model_output": ["ham", "sau", "1", "2"], - } - ), + "input_col": ["1", "2", "3", "4", "5"], + "explain_col": ["x", "x", "x", "x", "x"], + "output_col": ["b", "a", "a", "c", "a"], } ) - training_datasets = [] - test_datasets = [] - for modified_dataset_dict in t5_modified_dataset_dicts: - training_datasets.append(modified_dataset_dict["train"]) - test_datasets.append(modified_dataset_dict["test"]) - - concatenated_training_dataset = datasets.concatenate_datasets(training_datasets) - concatenated_test_dataset = datasets.concatenate_datasets(test_datasets) - actual_dataset_dict = datasets.DatasetDict( - {"train": concatenated_training_dataset, "test": concatenated_test_dataset} - ) - assert list(expected_dataset_dict["train"]) == list(actual_dataset_dict["train"]) - assert list(expected_dataset_dict["test"]) == list(actual_dataset_dict["test"]) - - -def test_dataset_processor_decoder_only_style(): - """Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`.""" - _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() - gpt_processor = TextualizeProcessor( - has_encoder=False, eos_token=gpt2_tokenizer.eos_token - ) - raw_dataset_dicts = deepcopy(DATASET_DICTS) - gpt_modified_dataset_dicts = gpt_processor.process_dataset_dict( - INSTRUCTION, DATASET_DICTS - ) - # Ensure the dataset_dicts themselves are the same after processing. - for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): - assert list(raw["train"]) == list(origin["train"]) - if "val" in raw: - assert list(raw["val"]) == list(origin["val"]) - if "test" in raw: - assert list(raw["test"]) == list(origin["test"]) - # Check that the modified dataset dicts have the expected content - gpt_expected_dataset_dicts = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\nbaz<|endoftext|>", # noqa: E501 - "convert to text2text\nExample:\nbar\nLabel:\nqux<|endoftext|>", # noqa: E501 - ], - "model_output": ["baz<|endoftext|>", "qux<|endoftext|>"], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - "convert to text2text\nExample:\nbar\nLabel:\n", - ], - "model_output": ["baz", "qux"], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nspam\nLabel:\nham<|endoftext|>", # noqa: E501 - "convert to text2text\nExample:\neggs\nLabel:\nsau<|endoftext|>", # noqa: E501 - ], - "model_output": ["ham<|endoftext|>", "sau<|endoftext|>"], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nspam\nLabel:\n", - "convert to text2text\nExample:\neggs\nLabel:\n", - ], - "model_output": ["ham", "sau"], - } - ), - } - ), - ] - for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): - assert list(exp["train"]) == list(modified["train"]) - if "val" in exp: - assert list(exp["val"]) == list(modified["val"]) - if "test" in exp: - assert list(exp["test"]) == list(modified["test"]) - - -def test_unexpected_dataset_split(): - """Test the error handler for unexpercted dataset split.""" - with pytest.raises(ValueError) as exc_info: - _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() - gpt_processor = TextualizeProcessor( - has_encoder=False, eos_token=gpt2_tokenizer.eos_token - ) - _ = gpt_processor.process_dataset_dict( - INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_SPLIT - ) - assert str(exc_info.value) == ("Datset split must be in train/val/test.") - gc.collect() + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) -def test_unexpected_columns(): - """Test the error handler for unexpercted dataset columns.""" - with pytest.raises(ValueError) as exc_info: - _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() - gpt_processor = TextualizeProcessor( - has_encoder=False, eos_token=gpt2_tokenizer.eos_token - ) - _ = gpt_processor.process_dataset_dict( - INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS - ) - assert str(exc_info.value) == ( - "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 - ) - gc.collect() +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions(length=5).mock_completions, +) +def test_generator_with_filter_to_generate_datasetdict(mocked_generate_example): + """Test with filter methods to generate a DatasetDict.""" + # Initialize the PromptBasedDatasetGenerator with + # specific settings and limited API calls. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + max_api_calls=7, + ) + + # Generate the DatasetDict using the initialized generator. + generated_dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples={ + DatasetSplit.TRAIN: 4, + DatasetSplit.VAL: 4, + DatasetSplit.TEST: 2, + }, + ) + + # Assertions for API call count and dataset + # dictionaries matching the expected results. + assert mocked_generate_example.call_count == 5 + assert dataset_generator.api_call_counter == 7 -DATASET_DICTS_WITH_EMPTY_COLUMNS = [ - datasets.DatasetDict( + # Define the expected dataset dictionaries + # based on the given mock responses. + expected_dataset_dict = datasets.DatasetDict( { - "train": datasets.Dataset.from_dict( + "train": Dataset.from_dict( { - "input_col": ["foo", "", "test"], - "explain_col": ["abc", "", "xyz"], - "output_col": ["", "qux", "key"], + "input_col": ["1", "2", "3", "4"], + "explain_col": ["x", "x", "x", "x"], + "output_col": ["b", "a", "a", "c"], } ), - "test": datasets.Dataset.from_dict( + "val": Dataset.from_dict( { - "input_col": ["foo", ""], - "explain_col": ["abc", ""], - "output_col": ["baz", "qux"], + "input_col": ["1", "2"], + "explain_col": ["x", "x"], + "output_col": ["a", "a"], } ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( + "test": Dataset.from_dict( { - "input_col": ["", ""], - "explain_col": ["abc", "xyz"], - "output_col": ["ham", "sau"], + "input_col": [], + "explain_col": [], + "output_col": [], } ), } - ), -] - - -def test_empty_filter_t5_type(): - """Test that examples with empty input_col or output_col are discarded.""" - t5_processor = TextualizeProcessor(has_encoder=True) - t5_modified_dataset_dicts = t5_processor.process_dataset_dict( - INSTRUCTION, DATASET_DICTS_WITH_EMPTY_COLUMNS - ) - t5_expected_dataset_dicts = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\ntest\nLabel:\n", - ], - "model_output": ["key"], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - ], - "model_output": [ - "baz", - ], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [], - "model_output": [], - } - ), - } - ), - ] - for exp, modified in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): - assert list(exp["train"]) == list(modified["train"]) - if "val" in exp: - assert list(exp["val"]) == list(modified["val"]) - if "test" in exp: - assert list(exp["test"]) == list(modified["test"]) - - -def test_empty_filter_decoder_only_style(): - """Test the `process_dataset_dict` function of a GPT-type `TextualizeProcessor`.""" - _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() - gpt_processor = TextualizeProcessor( - has_encoder=False, eos_token=gpt2_tokenizer.eos_token - ) - gpt_modified_dataset_dicts = gpt_processor.process_dataset_dict( - INSTRUCTION, DATASET_DICTS_WITH_EMPTY_COLUMNS - ) - - # Check that the modified dataset dicts have the expected content - gpt_expected_dataset_dicts = [ - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\ntest\nLabel:\nkey<|endoftext|>", # noqa: E501 - ], - "model_output": ["key<|endoftext|>"], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - "convert to text2text\nExample:\nfoo\nLabel:\n", - ], - "model_output": ["baz"], - } - ), - } - ), - datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [], - "model_output": [], - } - ), - } - ), - ] - for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): - assert list(exp["train"]) == list(modified["train"]) - if "val" in exp: - assert list(exp["val"]) == list(modified["val"]) - if "test" in exp: - assert list(exp["test"]) == list(modified["test"]) - gc.collect() - - -GENERATED_DATASET = datasets.Dataset.from_dict( - { - "input_col": list(range(10000)), - "explain_col": ["a"] * 10000, - "output_col": list(range(10000, 20000)), - } + ) + + # Verify the generated DatasetDict matches the expected DatasetDict. + assert list(generated_dataset_dict["train"]) == list(expected_dataset_dict["train"]) + assert list(generated_dataset_dict["val"]) == list(expected_dataset_dict["val"]) + assert list(generated_dataset_dict["test"]) == list(expected_dataset_dict["test"]) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, ) +def test_generator_max_api_calls_dict(mocked_generate_example): + """Test generation of a dataset dict where we hit max api calls.""" + # Refresh the call_count and create a new limited_dataset_generator. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=False, + max_api_calls=13, + ) -RETRIEVED_TRAIN_DATASET = datasets.Dataset.from_dict( - { - "input_col": list(range(20000, 30000)), - "explain_col": ["a"] * 10000, - "output_col": list(range(30000, 40000)), + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, } + + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + # Since the max_api_calls is 13, the limited_dataset_generator cannot + # generate the whole dataset_dict and will call the API 13 times. + assert dataset_generator.api_call_counter == 13 + + # The train split has 50 examples, so it will call the API 10 times and call + # generate_batch_completion 2 times. + # The validation split has 24 examples, but there are only 3 API calls + # left, so it will call the API 3 times and call + # generate_batch_completion 1 time. + # The test split has 26 examples, but there are no more API calls left, + # so it will not call generate_batch_completion. + assert mocked_generate_example.call_count == 2 + 1 + 0 + + # Each API call returns 5 responses, and each response is a valid JSON. + # So the generated_dataset_dict will contain (50, 15, 0) examples. + assert len(dataset_dict["train"]) == 50 + assert len(dataset_dict["val"]) == 15 + assert len(dataset_dict["test"]) == 0 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_WRONG_KEY_EXAMPLE, ) +def test_wrong_key_example(mocked_generate_example): + """Test PromptBasedDatasetGenerator when the agent returns wrong keys.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.TRAIN + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec, num_examples, split + ) + assert mocked_generate_example.call_count == 3 + expected_dataset = Dataset.from_dict({"input_col": [], "explain_col":[], "output_col": []}) + assert list(expected_dataset) == list(generated_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_INVALID_JSON, +) +def test_invalid_json_response(mocked_generate_example): + """Test when the agent returns invalid JSON responses.""" + # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. + dataset_generator = PromptBasedDatasetGenerator(3, filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.VAL + dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + assert mocked_generate_example.call_count == 3 + expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) #noqa E501 + assert list(dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=UnknownGpt3Exception(), +) +def test_unexpected_examples_of_gpt(mocked_generate_example): + """Test PromptBasedDatasetGenerator when the agent returns unexpected examples.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. + with pytest.raises(UnknownGpt3Exception): + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.TEST + _ = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + assert mocked_generate_example.call_count == 1 + + +def test_filter_with_duplicate_inputs_unique_outputs(): + """Test filtering with duplicate inputs, unique outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [ + Example(input_col="apple", explain_col="a", output_col="A"), #noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), #noqa E501 + Example(input_col="apple", explain_col="c", output_col="E"), #noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), #noqa E501 + Example(input_col="apple", explain_col="e", output_col="D"), #noqa E501 + ] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + expected_examples = [ + Example(input_col="apple", explain_col="a", output_col="A"), #noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), #noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), #noqa E501 + ] + assert sorted(expected_examples) == sorted(filtered_examples) + + +def test_filter_duplicate_inputs_duplicate_outputs(): + """Test constructing a map with duplicate inputs and duplicate outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [ + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="C"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="apple", explain_col="a", output_col="G"), + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="orange", explain_col="a", output_col="O"), + Example(input_col="apple", explain_col="a", output_col="D"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="F"), + ] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + expected_examples = [ + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="O"), + ] + assert expected_examples == filtered_examples + + +def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): + """Test constructing a map with unique inputs and outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [ + Example(input_col="apple", explain_col="a", output_col="A"), + Example(input_col="banana", explain_col="a", output_col="B"), + Example(input_col="orange", explain_col="a", output_col="O"), + ] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + assert generated_examples == filtered_examples + + +def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): + """Test constructing a map with empty inputs and outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + assert generated_examples == filtered_examples + -DATASET_LIST = [GENERATED_DATASET, RETRIEVED_TRAIN_DATASET] +def test_compute_batch_size_with_limited_max_api_calls(): + """Test the batch size computation with limited max API calls.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(max_api_calls=28) + data_generator.api_call_counter = 26 + # Default batch size and responses_per_request are both 5. + # So each batch should contain 25 examples. + # At least (125 - 110) / 5 = 3 API calls needed to get + # more than 125 examples. -def test_raise_value_error_of_process_dataset_lists(): - """Test that the ValueError is correctly raised.""" - _, gpt2_tokenizer = create_gpt2_model_and_tokenizer() - gpt_processor = TextualizeProcessor( - has_encoder=False, eos_token=gpt2_tokenizer.eos_token + batch_size = data_generator.compute_batch_size( + num_examples=125, generated_dataset_size=110 ) - with pytest.raises(ValueError) as exc_info: - gpt_processor.process_dataset_lists(INSTRUCTION, DATASET_LIST, 0.8, 0.2) - error_info = exc_info.value.args[0] - assert ( - error_info - == "train_proportion 0.8 + val_proportion 0.2 must be less than 1." - ) + assert ( + batch_size + == data_generator.max_api_calls - data_generator.api_call_counter + == 28 - 26 + ) + + data_generator.api_call_counter = 20 + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) + assert ( + batch_size + == (125 - 110) / data_generator.responses_per_request + == (125 - 110) / 5 + ) + + data_generator.api_call_counter = 0 + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) + assert batch_size == data_generator.max_batch_size + - t5_processor = TextualizeProcessor(has_encoder=True) - with pytest.raises(ValueError) as exc_info: - t5_processor.process_dataset_lists(INSTRUCTION, DATASET_LIST, 0.8, 0.2) +def test_compute_batch_size_with_unlimited_max_api_calls(): + """Test the batch size computation with unlimited max API calls.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator() + # Default batch size and responses_per_request are both 5. + # So each batch should contain 25 examples. + + # At least (125 - 110) / 5 = 3 API calls needed to get + # more than 125 examples. + + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) + assert ( + batch_size + == (125 - 110) / data_generator.responses_per_request + == (125 - 110) / 5 + ) + + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) + assert batch_size == data_generator.max_batch_size == 5 + + +def test_extract_responses(): + """Test the extract_responses function of DatasetGenerator.""" + mock_completion_1 = MockCompletion() + mock_completion_1.choices = [ + {"message": {"content": '{"input": "1", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explain": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explain": "x", "output": "a"}'}}, + ] + mock_completion_2 = MockCompletion() + mock_completion_2.choices = [ + {"message": {"content": '{"input": "3", "explain": "x", "output": "a"}'}}, + # Note that the following choice miss the right quote of JSON. + # So it should be discarded. And will log a warning. + {"message": {"content": '{"input": "3", "explain": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3", "explain": "x", "output": "b"}'}}, + ] + mock_completion_3 = MockCompletion() + mock_completion_3.choices = [ + {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explain": "x", "output": "a"}'}}, + ] + # choices should be list of dicts. So mock_completion_4 + # is invalid. Which will be discarded and log a warning. + mock_completion_4 = MockCompletion() + mock_completion_4.choices = None + + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_1, mock_completion_2], generated_examples + ) + mock_warning.assert_called_once_with( + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 + ) + # There are 5 valid examples. Each input + # and output will be logged once as info. + assert mock_info.call_count == 5 * 2 + + # The second choice in mock_completion_2 + # is invalid. So it should be discarded. + assert generated_examples == [ + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + ] + data_generator.extract_and_append_responses([mock_completion_3], generated_examples) + assert generated_examples == [ + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), + ] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_4], generated_examples + ) + mock_warning.assert_called_once_with( + "Error happened when parsing API completion: " + ) + mock_info.assert_not_called() + # The generated_examples should be the same. + assert generated_examples == [ + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="1", explain_col="x", output_col="b"), + Example(input_col="1", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), + ] + + +def test_extract_some_empty_responses(): + """Test the extract_responses function correctly handle empty responses.""" + mock_completion_1 = MockCompletion() + mock_completion_1.choices = [ + # Note that this choice's input is empty. So it should be discarded. + {"message": {"content": '{"input": "", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "5", "explain": "x", "output": "b"}'}}, + # Note that this choice's output is empty. So it should be discarded. + {"message": {"content": '{"input": "1", "explain": "x", "output": ""}'}}, + ] + mock_completion_2 = MockCompletion() + mock_completion_2.choices = [ + {"message": {"content": '{"input": "3", "explain": "x", "output": "a"}'}}, + # Note that the following choice misses the right quote of JSON. + # So it should be discarded. And will log a warning. + {"message": {"content": '{"input": "3", "explain": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3","explain": "x", "output": "b"}'}}, + ] + mock_completion_3 = MockCompletion() + mock_completion_3.choices = [ + {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5","explain": "x", "output": "a"}'}}, + ] + # choices should be list of dicts. So mock_completion_4 + # is invalid. Which will be discarded and log a warning. + mock_completion_4 = MockCompletion() + mock_completion_4.choices = None + + with tempfile.TemporaryDirectory() as cache_dir: + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator( + cache_root=cache_dir, filter_duplicated_examples=True + ) + generated_examples = [] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_1, mock_completion_2], generated_examples + ) + mock_warning.assert_called_once_with( + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explain": "x", "output": "a}\'}}' # noqa E501 + ) + # There are 3 valid examples in [mock_completion_1, + # mock_completion_2] Each input + # and output will be logged once as info. + # And there are 2 examples with empty + # input or output, which should be discarded + # and be logged as info. + assert mock_info.call_count == 3 * 2 + 2 + + # The second choice in mock_completion_2 + # is invalid. So it should be discarded. + assert generated_examples == [ + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + ] + data_generator.extract_and_append_responses( + [mock_completion_3], generated_examples + ) + assert generated_examples == [ + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), + ] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_4], generated_examples + ) + mock_warning.assert_called_once_with( + "Error happened when parsing API completion: " + ) + mock_info.assert_not_called() + # The generated_examples should be the same. + assert generated_examples == [ + Example(input_col="5", explain_col="x", output_col="b"), + Example(input_col="3", explain_col="x", output_col="a"), + Example(input_col="3", explain_col="x", output_col="b"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="4", explain_col="x", output_col="c"), + Example(input_col="5", explain_col="x", output_col="a"), + ] + + +def test_initialize_dataset_generator_with_dynamic_temperature(): + """Test the correct initialization of the dynamic temperature strategy.""" + with tempfile.TemporaryDirectory() as cache_dir: + os.environ["OPENAI_API_KEY"] = "fake_api_key" + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator( + cache_root=cache_dir, initial_temperature=-0.2 + ) error_info = exc_info.value.args[0] assert ( error_info - == "train_proportion 0.8 + val_proportion 0.2 must be less than 1." + == "initial_temperature must be >= 0, but self.initial_temperature=-0.2" ) + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator(cache_root=cache_dir, max_temperature=2.3) + error_info = exc_info.value.args[0] + assert ( + error_info + == "max_temperature must be <= 2,0, but self.max_temperature=2.3" + ) + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator( + cache_root=cache_dir, max_temperature=1.2, initial_temperature=1.5 + ) + error_info = exc_info.value.args[0] + assert ( + error_info + == "self.initial_temperature=1.5 must be <= self.max_temperature=1.2" + ) -def test_process_dataset_lists(): - """Test the `process_dataset_lists` function.""" - processor = TextualizeProcessor(has_encoder=True) - modified_dataset_dicts = processor.process_dataset_lists( - INSTRUCTION, DATASET_LIST, 0.6, 0.2 - ) - expected_modified_generated_dataset_dict = datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(6000) - ], - "model_output": [f"{output}" for output in range(10000, 16000)], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(6000, 8000) - ], - "model_output": [f"{output}" for output in range(16000, 18000)], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(8000, 10000) - ], - "model_output": [f"{output}" for output in range(18000, 20000)], - } - ), - } - ) - expected_modified_retrieved_dataset_dict = datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(20000, 26000) - ], - "model_output": [f"{output}" for output in range(30000, 36000)], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(26000, 28000) - ], - "model_output": [f"{output}" for output in range(36000, 38000)], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(28000, 30000) - ], - "model_output": [f"{output}" for output in range(38000, 40000)], - } - ), - } + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_dataset_generator_terminates(mocked_generate_example): + """Check to make sure that the dataset generator terminates.""" + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + dataset_generator = PromptBasedDatasetGenerator( + initial_temperature=0.3, + max_temperature=1.4, + responses_per_request=3, + max_api_calls=10000, + requests_per_minute=80, + filter_duplicated_examples=False, ) - for exp, modified in zip( - [ - expected_modified_generated_dataset_dict, - expected_modified_retrieved_dataset_dict, - ], - modified_dataset_dicts, - ): - assert list(exp["train"]) == list(modified["train"]) - if "val" in exp: - assert list(exp["val"]) == list(modified["val"]) - if "test" in exp: - assert list(exp["test"]) == list(modified["test"]) - - -def test_process_dataset_lists_with_maximum_example_num(): - """Test the maximum_example_num parameter.""" - processor = TextualizeProcessor(has_encoder=True) - modified_dataset_dicts = processor.process_dataset_lists( - INSTRUCTION, DATASET_LIST, 0.6, 0.2, {"train": 3000, "val": 500, "test": 1000} - ) - # Before applying the maximum_example_num, train_num = 6000, - # val_num = 2000, test_num = 2000. - # After applying the maximum_example_num, train_num = 3000, - # val_num = 2000, test_num = 2000. - expected_modified_generated_dataset_dict = datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(3000) - ], - "model_output": [f"{output}" for output in range(10000, 13000)], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(3000, 3500) - ], - "model_output": [f"{output}" for output in range(13000, 13500)], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(3500, 4500) - ], - "model_output": [f"{output}" for output in range(13500, 14500)], - } - ), - } + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec, 100, split=DatasetSplit.TRAIN ) - expected_modified_retrieved_dataset_dict = datasets.DatasetDict( - { - "train": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(20000, 23000) - ], - "model_output": [f"{output}" for output in range(30000, 33000)], - } - ), - "val": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(23000, 23500) - ], - "model_output": [f"{output}" for output in range(33000, 33500)], - } - ), - "test": datasets.Dataset.from_dict( - { - "model_input": [ - f"convert to text2text\nExample:\n{input}\nLabel:\n" - for input in range(23500, 24500) - ], - "model_output": [f"{output}" for output in range(33500, 34500)], - } - ), - } + generated_df = generated_dataset.to_pandas() + assert len(generated_dataset) == 100 + assert list(generated_df.columns) == ["input_col", "output_col"] + + +def test_generate_dataset_agent_switch(): + """Test if dataset generation can use a user-set API agent.""" + my_agent = MockAPIAgent( + default_content='{"input": "This is input.", "output": "This is an output."}' ) - for exp, modified in zip( - [ - expected_modified_generated_dataset_dict, - expected_modified_retrieved_dataset_dict, - ], - modified_dataset_dicts, - ): - assert list(exp["train"]) == list(modified["train"]) - if "val" in exp: - assert list(exp["val"]) == list(modified["val"]) - if "test" in exp: - assert list(exp["test"]) == list(modified["test"]) + with temp_setattr(api_tools, "default_api_agent", my_agent): + prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION) + dataset_generator = PromptBasedDatasetGenerator( + initial_temperature=0.3, + max_temperature=1.4, + responses_per_request=1, + max_api_calls=100, + requests_per_minute=80, + filter_duplicated_examples=False, + ) + dataset_generator.generate_dataset_split( + prompt_spec, 100, split=DatasetSplit.TRAIN + ) + # 100 outputs, and each batch has 5 outputs so 20 api calls + assert my_agent.generate_batch_call_counter == 20 From 057d42b017da19d3aa8c7c7966cc3b61207fb88e Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 08:10:56 -0500 Subject: [PATCH 20/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 14d077a6e..0f297c740 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -565,7 +565,7 @@ def test_wrong_key_example(mocked_generate_example): prompt_spec, num_examples, split ) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "explain_col":[], "output_col": []}) + expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) # noqa E501 assert list(expected_dataset) == list(generated_dataset) @@ -582,7 +582,7 @@ def test_invalid_json_response(mocked_generate_example): split = DatasetSplit.VAL dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) #noqa E501 + expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) # noqa E501 assert list(dataset) == list(expected_dataset) @@ -610,11 +610,11 @@ def test_filter_with_duplicate_inputs_unique_outputs(): os.environ["OPENAI_API_KEY"] = "fake_api_key" data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ - Example(input_col="apple", explain_col="a", output_col="A"), #noqa E501 - Example(input_col="banana", explain_col="b", output_col="B"), #noqa E501 - Example(input_col="apple", explain_col="c", output_col="E"), #noqa E501 - Example(input_col="orange", explain_col="d", output_col="O"), #noqa E501 - Example(input_col="apple", explain_col="e", output_col="D"), #noqa E501 + Example(input_col="apple", explain_col="a", output_col="A"), # noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), # noqa E501 + Example(input_col="apple", explain_col="c", output_col="E"), # noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), # noqa E501 + Example(input_col="apple", explain_col="e", output_col="D"), # noqa E501 ] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) expected_examples = [ From f24b344cf679bccd28e65cb3d92238edc8fec099 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 08:14:01 -0500 Subject: [PATCH 21/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 0f297c740..796365ab8 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -565,7 +565,9 @@ def test_wrong_key_example(mocked_generate_example): prompt_spec, num_examples, split ) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) # noqa E501 + expected_dataset = Dataset.from_dict( + {"input_col": [], "explain_col": [], "output_col": []} + ) # noqa E501 assert list(expected_dataset) == list(generated_dataset) @@ -582,7 +584,9 @@ def test_invalid_json_response(mocked_generate_example): split = DatasetSplit.VAL dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "explain_col": [], "output_col": []}) # noqa E501 + expected_dataset = Dataset.from_dict( + {"input_col": [], "explain_col": [], "output_col": []} + ) # noqa E501 assert list(dataset) == list(expected_dataset) @@ -618,9 +622,9 @@ def test_filter_with_duplicate_inputs_unique_outputs(): ] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) expected_examples = [ - Example(input_col="apple", explain_col="a", output_col="A"), #noqa E501 - Example(input_col="banana", explain_col="b", output_col="B"), #noqa E501 - Example(input_col="orange", explain_col="d", output_col="O"), #noqa E501 + Example(input_col="apple", explain_col="a", output_col="A"), # noqa E501 + Example(input_col="banana", explain_col="b", output_col="B"), # noqa E501 + Example(input_col="orange", explain_col="d", output_col="O"), # noqa E501 ] assert sorted(expected_examples) == sorted(filtered_examples) From 79af8e27e0edee7249dab22917465403a0880582 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:28:18 -0500 Subject: [PATCH 22/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 796365ab8..046d43cd9 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -29,41 +29,41 @@ MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', + content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', + content='{"input": "This is a great movie!", "explain":"x", "output": "1}', ) @@ -86,7 +86,7 @@ def test_generate_dataset(mocked_generate_example): # the length of the dataset is num_examples + 5, where 5 is the # default number of responses per API call. assert len(dataset) < num_examples + 5 - expected_columns = {"input_col", "output_col"} + expected_columns = {"input_col", "explain_col", "output_col"} assert set(dataset.column_names) == expected_columns return dataset @@ -116,7 +116,7 @@ def test_generate_dataset_dict(mocked_generate_example): # generated dataset is num_examples + 5, where # 5 is the default number of responses per API call. assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} + expected_columns = {"input_col", "explain_col", "output_col"} for dataset in dataset_dict.values(): assert set(dataset.column_names) == expected_columns From cb0bf9922028c2ed0b119c8bcb7e76156c938ae5 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:49:28 -0500 Subject: [PATCH 23/26] Update mock_api.py --- test_helpers/mock_api.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test_helpers/mock_api.py b/test_helpers/mock_api.py index abec258ed..6dbf0a9c2 100644 --- a/test_helpers/mock_api.py +++ b/test_helpers/mock_api.py @@ -100,15 +100,15 @@ def __init__(self, length: int = 4) -> None: self.current_index = 0 mock_completion_1 = MockCompletion() mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "1", "output": "c"}'}}, - {"message": {"content": '{"input": "2", "output": "a"}'}}, - {"message": {"content": '{"input": "2", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "2", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "2", "explanation": "x", "output": "b"}'}}, ] self.mock_completions.append( [ @@ -118,24 +118,24 @@ def __init__(self, length: int = 4) -> None: ) mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, ] self.mock_completions.append([mock_completion_3]) mock_completion_4 = MockCompletion() mock_completion_4.choices = [ - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, ] self.mock_completions.append([mock_completion_4]) mock_completion_5 = MockCompletion() mock_completion_5.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, ] self.mock_completions.append([mock_completion_5]) if length == 5: From 1af0de65d19c84c67d1bd9a2264dda281bc65e43 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 09:55:02 -0500 Subject: [PATCH 24/26] Update dataset_generator_test.py --- tests/dataset_generator_test.py | 66 ++++++++++++++++----------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index 046d43cd9..ab5e0b38a 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -29,41 +29,41 @@ MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) MOCK_CLASSIFICATION_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1"}', ) MOCK_WRONG_KEY_EXAMPLE = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "label": "1"}', + content='{"input": "This is a great movie!", "explanation":"x", "label": "1"}', ) MOCK_INVALID_JSON = partial( mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "explain":"x", "output": "1}', + content='{"input": "This is a great movie!", "explanation":"x", "output": "1}', ) @@ -734,23 +734,23 @@ def test_extract_responses(): """Test the extract_responses function of DatasetGenerator.""" mock_completion_1 = MockCompletion() mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "explain": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "explain": "x", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, # Note that the following choice miss the right quote of JSON. # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "explain": "x", "output": "a}'}}, - {"message": {"content": '{"input": "3", "explain": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, ] mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, ] # choices should be list of dicts. So mock_completion_4 # is invalid. Which will be discarded and log a warning. @@ -767,11 +767,11 @@ def test_extract_responses(): [mock_completion_1, mock_completion_2], generated_examples ) mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explanation": "x", "output": "a}\'}}' # noqa E501 ) # There are 5 valid examples. Each input # and output will be logged once as info. - assert mock_info.call_count == 5 * 2 + assert mock_info.call_count == 5 * 3 # The second choice in mock_completion_2 # is invalid. So it should be discarded. @@ -821,24 +821,24 @@ def test_extract_some_empty_responses(): mock_completion_1 = MockCompletion() mock_completion_1.choices = [ # Note that this choice's input is empty. So it should be discarded. - {"message": {"content": '{"input": "", "explain": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "5", "explain": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "", "explanation": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "b"}'}}, # Note that this choice's output is empty. So it should be discarded. - {"message": {"content": '{"input": "1", "explain": "x", "output": ""}'}}, + {"message": {"content": '{"input": "1", "explanation": "x", "output": ""}'}}, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, # Note that the following choice misses the right quote of JSON. # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "explain": "x", "output": "a}'}}, - {"message": {"content": '{"input": "3","explain": "x", "output": "b"}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "a}'}}, + {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, ] mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "explain": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "5","explain": "x", "output": "a"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, ] # choices should be list of dicts. So mock_completion_4 # is invalid. Which will be discarded and log a warning. @@ -858,7 +858,7 @@ def test_extract_some_empty_responses(): [mock_completion_1, mock_completion_2], generated_examples ) mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explain": "x", "output": "a}\'}}' # noqa E501 + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "explanation": "x", "output": "a}\'}}' # noqa E501 ) # There are 3 valid examples in [mock_completion_1, # mock_completion_2] Each input @@ -866,7 +866,7 @@ def test_extract_some_empty_responses(): # And there are 2 examples with empty # input or output, which should be discarded # and be logged as info. - assert mock_info.call_count == 3 * 2 + 2 + assert mock_info.call_count == 3 * 3 + 2 # The second choice in mock_completion_2 # is invalid. So it should be discarded. @@ -959,13 +959,13 @@ def test_dataset_generator_terminates(mocked_generate_example): ) generated_df = generated_dataset.to_pandas() assert len(generated_dataset) == 100 - assert list(generated_df.columns) == ["input_col", "output_col"] + assert list(generated_df.columns) == ["input_col", "explain_col", "output_col"] def test_generate_dataset_agent_switch(): """Test if dataset generation can use a user-set API agent.""" my_agent = MockAPIAgent( - default_content='{"input": "This is input.", "output": "This is an output."}' + default_content='{"input": "This is input.", "explanation": "This is an explanation", "output": "This is an output."}' # noqa E501 ) with temp_setattr(api_tools, "default_api_agent", my_agent): prompt_spec = MockPromptSpec(TaskType.CLASSIFICATION) From 74213f8ea9be1a6a5829a1eb8bcdada86a07fe33 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 10:09:34 -0500 Subject: [PATCH 25/26] Update dataset_processor_test.py --- tests/dataset_processor_test.py | 39 +++++---------------------------- 1 file changed, 5 insertions(+), 34 deletions(-) diff --git a/tests/dataset_processor_test.py b/tests/dataset_processor_test.py index cb49b009d..ba7ec29d3 100644 --- a/tests/dataset_processor_test.py +++ b/tests/dataset_processor_test.py @@ -19,14 +19,12 @@ "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), @@ -37,14 +35,12 @@ "train": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), "val": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -60,22 +56,14 @@ datasets.DatasetDict( { "full": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } + {"input_col": ["foo", "bar"], "output_col": ["baz", "qux"]} ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output_col": ["ham", "sau"], - } + {"input_col": ["spam", "eggs"], "output_col": ["ham", "sau"]} ) } ), @@ -86,22 +74,14 @@ datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - { - "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], - "output_col": ["baz", "qux"], - } + {"input_col": ["foo", "bar"], "output_col": ["baz", "qux"]} ) } ), datasets.DatasetDict( { "train": datasets.Dataset.from_dict( - { - "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], - "output": ["ham", "sau"], - } + {"input_col": ["spam", "eggs"], "output": ["ham", "sau"]} ) } ), @@ -214,14 +194,12 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], "output_col": ["baz", "qux"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], "output_col": ["ham", "sau"], } ), @@ -232,14 +210,12 @@ def test_dataset_processor_with_numerical_column(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "bar"], - "explain_col": ["abc", "xyz"], "output_col": [0, 1], } ), "test": datasets.Dataset.from_dict( { "input_col": ["spam", "eggs"], - "explain_col": ["lmn", "opq"], "output_col": [1, 2], } ), @@ -387,7 +363,7 @@ def test_unexpected_columns(): INSTRUCTION, UNEXPECTED_DATASET_DICTS_WITH_WRONG_COLUMNS ) assert str(exc_info.value) == ( - "Example dictionary must have 'input_col', 'explain_col' and 'output_col' keys." # noqa E501 + "Example dictionary must have 'input_col' and 'output_col' keys." ) gc.collect() @@ -398,14 +374,12 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["foo", "", "test"], - "explain_col": ["abc", "", "xyz"], "output_col": ["", "qux", "key"], } ), "test": datasets.Dataset.from_dict( { "input_col": ["foo", ""], - "explain_col": ["abc", ""], "output_col": ["baz", "qux"], } ), @@ -416,7 +390,6 @@ def test_unexpected_columns(): "train": datasets.Dataset.from_dict( { "input_col": ["", ""], - "explain_col": ["abc", "xyz"], "output_col": ["ham", "sau"], } ), @@ -528,7 +501,6 @@ def test_empty_filter_decoder_only_style(): GENERATED_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(10000)), - "explain_col": ["a"] * 10000, "output_col": list(range(10000, 20000)), } ) @@ -536,7 +508,6 @@ def test_empty_filter_decoder_only_style(): RETRIEVED_TRAIN_DATASET = datasets.Dataset.from_dict( { "input_col": list(range(20000, 30000)), - "explain_col": ["a"] * 10000, "output_col": list(range(30000, 40000)), } ) From bc5de3d9e1653f9446aa875e6c9a990584f01896 Mon Sep 17 00:00:00 2001 From: VanyaBK <37258663+VanyaBK@users.noreply.github.com> Date: Fri, 8 Dec 2023 10:22:37 -0500 Subject: [PATCH 26/26] Update mock_api.py --- test_helpers/mock_api.py | 90 +++++++++++++++++++++++++++++++++------- 1 file changed, 75 insertions(+), 15 deletions(-) diff --git a/test_helpers/mock_api.py b/test_helpers/mock_api.py index 6dbf0a9c2..c6828e71e 100644 --- a/test_helpers/mock_api.py +++ b/test_helpers/mock_api.py @@ -100,15 +100,39 @@ def __init__(self, length: int = 4) -> None: self.current_index = 0 mock_completion_1 = MockCompletion() mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "explanation": "x", "output": "a"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "a"}' + } + }, ] mock_completion_2 = MockCompletion() mock_completion_2.choices = [ - {"message": {"content": '{"input": "1", "explanation": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "2", "explanation": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "2", "explanation": "x", "output": "b"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "2", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "2", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append( [ @@ -118,24 +142,60 @@ def __init__(self, length: int = 4) -> None: ) mock_completion_3 = MockCompletion() mock_completion_3.choices = [ - {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "explanation": "x", "output": "a"}'}}, - {"message": {"content": '{"input": "3", "explanation": "x", "output": "b"}'}}, + { + "message": { + "content": '{"input": "3", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "3", "explanation": "x", "output": "a"}' + } + }, + { + "message": { + "content": '{"input": "3", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append([mock_completion_3]) mock_completion_4 = MockCompletion() mock_completion_4.choices = [ - {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "explanation": "x", "output": "b"}'}}, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, + { + "message": { + "content": '{"input": "1", "explanation": "x", "output": "b"}' + } + }, ] self.mock_completions.append([mock_completion_4]) mock_completion_5 = MockCompletion() mock_completion_5.choices = [ - {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "explanation": "x", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "explanation": "x", "output": "a"}'}}, + { + "message": { + "content": '{"input": "4", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "4", "explanation": "x", "output": "c"}' + } + }, + { + "message": { + "content": '{"input": "5", "explanation": "x", "output": "a"}' + } + }, ] self.mock_completions.append([mock_completion_5]) if length == 5: