Skip to content

Commit

Permalink
Merge pull request fastmachinelearning#83 from iksnagreb/feature/find…
Browse files Browse the repository at this point in the history
…_upstream_keep_visited

Add option to find_upstream to keep nodes visited even if not found
  • Loading branch information
maltanar authored Feb 23, 2024
2 parents fad667f + b3186cb commit fe4aa37
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/qonnx/core/modelwrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,16 +346,19 @@ def find_producer(self, tensor_name):
return x
return None

def find_upstream(self, tensor_name, finder_fxn):
def find_upstream(self, tensor_name, finder_fxn, keep_if_not_found=False):
"""Follow the producer chain upstream, calling finder_fxn on each upstream
node until it returns True or there are no nodes left. Returns the list
of nodes visited, or None if finder_fxn did not return True."""
of nodes visited, or None if finder_fxn did not return True. If
keep_if_not_found is specified, returns the list of nodes visited, even
if finder_fxn never returned True, i.e., if the search terminated at an
input or initializer."""
visit_list = []
current_tensor = tensor_name
while True:
current_producer = self.find_producer(current_tensor)
if current_producer is None:
return []
return visit_list if keep_if_not_found else []
else:
found = finder_fxn(current_producer)
visit_list.append(current_producer)
Expand All @@ -364,7 +367,7 @@ def find_upstream(self, tensor_name, finder_fxn):
elif len(current_producer.input) > 0:
current_tensor = current_producer.input[0]
else:
return None
return visit_list if keep_if_not_found else None

def find_consumer(self, tensor_name):
"""Finds and returns the node that consumes the tensor with given name.
Expand Down

0 comments on commit fe4aa37

Please sign in to comment.