Skip to content

Commit

Permalink
Decrease the number of annotation lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
renatahodovan committed Oct 28, 2024
1 parent 9364362 commit c6dd9c1
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions grammarinator/tool/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,15 @@ def regenerate_rule(self):

# Filter items from the nodes of the selected tree that can be regenerated
# within the current maximum depth and token limit (except immutable nodes).
root_token_counts = annot.token_counts[root]
options = [node for nodes in annot.rules_by_name.values() for node in nodes
if (node.parent is not None
and annot.node_levels[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).depth < self._limit.depth
and annot.token_counts[root] - annot.token_counts[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).tokens < self._limit.tokens)]
and root_token_counts - annot.token_counts[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).tokens < self._limit.tokens)]
if options:
mutated_node = random.choice(options)
reserve = RuleSize(depth=annot.node_levels[mutated_node],
tokens=annot.token_counts[root] - annot.token_counts[mutated_node])
tokens=root_token_counts - annot.token_counts[mutated_node])
mutated_node = mutated_node.replace(self.generate(rule=mutated_node.name, reserve=reserve))
return mutated_node.root

Expand Down Expand Up @@ -350,13 +351,16 @@ def replace_node(self):
common_types = sorted(set(recipient_lookup.keys()) & set(donor_lookup.keys()))

recipient_options = [(rule_name, node) for rule_name in common_types for node in recipient_lookup[rule_name] if node.parent]
recipient_root_token_counts = recipient_annot.token_counts[recipient_root]
# Shuffle suitable nodes with sample.
for rule_name, recipient_node in random.sample(recipient_options, k=len(recipient_options)):
donor_options = donor_lookup[rule_name]
recipient_node_level = recipient_annot.node_levels[recipient_node]
recipient_node_tokens = recipient_annot.token_counts[recipient_node]
for donor_node in random.sample(donor_options, k=len(donor_options)):
# Make sure that the output tree won't exceed the depth limit.
if (recipient_annot.node_levels[recipient_node] + donor_annot.node_depths[donor_node] <= self._limit.depth
and recipient_annot.token_counts[recipient_root] - recipient_annot.token_counts[recipient_node] + donor_annot.token_counts[donor_node] < self._limit.tokens):
if (recipient_node_level + donor_annot.node_depths[donor_node] <= self._limit.depth
and recipient_root_token_counts - recipient_node_tokens + donor_annot.token_counts[donor_node] < self._limit.tokens):
recipient_node.replace(donor_node)
return recipient_root

Expand All @@ -379,12 +383,14 @@ def insert_quantified(self):

common_types = sorted(set(recipient_annot.quants_by_name.keys()) & set(donor_annot.quants_by_name.keys()))
recipient_options = [(name, node) for name in common_types for node in recipient_annot.quants_by_name[name] if len(node.children) < node.stop]
recipient_root_token_counts = recipient_annot.token_counts[recipient_root]
for rule_name, recipient_node in random.sample(recipient_options, k=len(recipient_options)):
recipient_node_level = recipient_annot.node_levels[recipient_node]
donor_options = [quantified for quantifier in donor_annot.quants_by_name[rule_name] for quantified in quantifier.children]
for donor_node in random.sample(donor_options, k=len(donor_options)):
# Make sure that the output tree won't exceed the depth and token limits.
if (recipient_annot.node_levels[recipient_node] + donor_annot.node_depths[donor_node] <= self._limit.depth
and recipient_annot.token_counts[recipient_root] + donor_annot.token_counts[donor_node] < self._limit.tokens):
if (recipient_node_level + donor_annot.node_depths[donor_node] <= self._limit.depth
and recipient_root_token_counts + donor_annot.token_counts[donor_node] < self._limit.tokens):
recipient_node.insert_child(random.randint(0, len(recipient_node.children)), donor_node)
return recipient_root

Expand Down Expand Up @@ -418,8 +424,9 @@ def replicate_quantified(self):
"""
root, annot = self._select_individual()
root_options = [node for node in annot.quants if node.stop > len(node.children)]
recipient_root_token_counts = annot.token_counts[root]
node_options = [child for root in root_options for child in root.children if
annot.token_counts[root] + annot.token_counts[child] <= self._limit.tokens]
recipient_root_token_counts + annot.token_counts[child] <= self._limit.tokens]
if node_options:
node_to_repeat = random.choice(node_options)
node_to_repeat.parent.insert_child(idx=random.randint(0, len(node_to_repeat.parent.children)), node=deepcopy(node_to_repeat))
Expand Down

0 comments on commit c6dd9c1

Please sign in to comment.