diff --git a/grammarinator/tool/generator.py b/grammarinator/tool/generator.py index 817923f..5ef303d 100644 --- a/grammarinator/tool/generator.py +++ b/grammarinator/tool/generator.py @@ -311,14 +311,15 @@ def regenerate_rule(self): # Filter items from the nodes of the selected tree that can be regenerated # within the current maximum depth and token limit (except immutable nodes). + root_token_counts = annot.token_counts[root] options = [node for nodes in annot.rules_by_name.values() for node in nodes if (node.parent is not None and annot.node_levels[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).depth < self._limit.depth - and annot.token_counts[root] - annot.token_counts[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).tokens < self._limit.tokens)] + and root_token_counts - annot.token_counts[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).tokens < self._limit.tokens)] if options: mutated_node = random.choice(options) reserve = RuleSize(depth=annot.node_levels[mutated_node], - tokens=annot.token_counts[root] - annot.token_counts[mutated_node]) + tokens=root_token_counts - annot.token_counts[mutated_node]) mutated_node = mutated_node.replace(self.generate(rule=mutated_node.name, reserve=reserve)) return mutated_node.root @@ -350,13 +351,16 @@ def replace_node(self): common_types = sorted(set(recipient_lookup.keys()) & set(donor_lookup.keys())) recipient_options = [(rule_name, node) for rule_name in common_types for node in recipient_lookup[rule_name] if node.parent] + recipient_root_token_counts = recipient_annot.token_counts[recipient_root] # Shuffle suitable nodes with sample. for rule_name, recipient_node in random.sample(recipient_options, k=len(recipient_options)): donor_options = donor_lookup[rule_name] + recipient_node_level = recipient_annot.node_levels[recipient_node] + recipient_node_tokens = recipient_annot.token_counts[recipient_node] for donor_node in random.sample(donor_options, k=len(donor_options)): # Make sure that the output tree won't exceed the depth limit. - if (recipient_annot.node_levels[recipient_node] + donor_annot.node_depths[donor_node] <= self._limit.depth - and recipient_annot.token_counts[recipient_root] - recipient_annot.token_counts[recipient_node] + donor_annot.token_counts[donor_node] < self._limit.tokens): + if (recipient_node_level + donor_annot.node_depths[donor_node] <= self._limit.depth + and recipient_root_token_counts - recipient_node_tokens + donor_annot.token_counts[donor_node] < self._limit.tokens): recipient_node.replace(donor_node) return recipient_root @@ -379,12 +383,14 @@ def insert_quantified(self): common_types = sorted(set(recipient_annot.quants_by_name.keys()) & set(donor_annot.quants_by_name.keys())) recipient_options = [(name, node) for name in common_types for node in recipient_annot.quants_by_name[name] if len(node.children) < node.stop] + recipient_root_token_counts = recipient_annot.token_counts[recipient_root] for rule_name, recipient_node in random.sample(recipient_options, k=len(recipient_options)): + recipient_node_level = recipient_annot.node_levels[recipient_node] donor_options = [quantified for quantifier in donor_annot.quants_by_name[rule_name] for quantified in quantifier.children] for donor_node in random.sample(donor_options, k=len(donor_options)): # Make sure that the output tree won't exceed the depth and token limits. - if (recipient_annot.node_levels[recipient_node] + donor_annot.node_depths[donor_node] <= self._limit.depth - and recipient_annot.token_counts[recipient_root] + donor_annot.token_counts[donor_node] < self._limit.tokens): + if (recipient_node_level + donor_annot.node_depths[donor_node] <= self._limit.depth + and recipient_root_token_counts + donor_annot.token_counts[donor_node] < self._limit.tokens): recipient_node.insert_child(random.randint(0, len(recipient_node.children)), donor_node) return recipient_root @@ -418,8 +424,9 @@ def replicate_quantified(self): """ root, annot = self._select_individual() root_options = [node for node in annot.quants if node.stop > len(node.children)] + recipient_root_token_counts = annot.token_counts[root] node_options = [child for root in root_options for child in root.children if - annot.token_counts[root] + annot.token_counts[child] <= self._limit.tokens] + recipient_root_token_counts + annot.token_counts[child] <= self._limit.tokens] if node_options: node_to_repeat = random.choice(node_options) node_to_repeat.parent.insert_child(idx=random.randint(0, len(node_to_repeat.parent.children)), node=deepcopy(node_to_repeat))