Skip to content

Commit

Permalink
upd : extracting irreducibles
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Jan 11, 2024
1 parent 83c3e4e commit be304c0
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 188 deletions.
4 changes: 1 addition & 3 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def init_graphs(self, swc_dir):
# Construct Graph
path = os.path.join(swc_dir, f)
swc_dict = swc_utils.parse_local_swc(path)
graph, xyz_to_node = swc_utils.file_to_graph(
swc_dict, set_attrs=True, return_dict=True
)
graph, xyz_to_node = swc_utils.to_graph(swc_dict, set_attrs=True)

# Store
xyz_to_id = dict(zip_broadcast(swc_dict["xyz"], f))
Expand Down
250 changes: 141 additions & 109 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,143 +5,175 @@
@email: anna.grim@alleninstitute.org
Routines for working with graphs.
Routines that extract the irreducible components of a graph.
"""

from copy import deepcopy
from random import sample

import networkx as nx
import numpy as np

from deep_neurographs import geometry_utils, swc_utils, utils


def get_irreducibles(swc_dict, prune=True, depth=16, smooth=True):
"""
Gets irreducible components of the graph stored in "swc_dict". The
irreducible components consist of the leaf and junction nodes along with
the edges among this set of nodes.
Parameters
----------
swc_dict : dict
Contents of an swc file.
prune : True
Indication of whether to prune short branches.
depth : int
Path length that determines whether a branch is short.
smooth : bool
Indication of whether to smooth each branch.
Returns
-------
leafs : set
Nodes with degreee 1.
junctions : set
Nodes with degree > 2.
edges : dict
Set of edges connecting nodes in leafs and junctions. The keys are
pairs of nodes connected by an edge and values are a dictionary of
attributes.
"""
# Initializations
dense_graph = swc_utils.to_graph(swc_dict)
if prune:
dense_graph = prune_short_branches(dense_graph, depth)

# Extract irreducibles
leafs, junctions = get_irreducible_nodes(dense_graph)
source = sample(leafs, 1)[0]
root = None
edges = dict()
nbs = dict()
for (i, j) in nx.dfs_edges(dense_graph, source=source):
# Check if start of path is valid
if root is None:
root = i
attrs = __init_edge_attrs(swc_dict, i)

# Visit j
attrs = __upd_edge_attrs(swc_dict, attrs, j)
if j in leafs or j in junctions:
if smooth:
swc_dict, edges = __smooth_branch(
swc_dict, attrs, edges, nbs, root, j
)
else:
edges[(root, j)] = attrs
nbs = append_value(nbs, root, j)
nbs = append_value(nbs, j, root)
root = None
return leafs, junctions, edges

from deep_neurographs import swc_utils, utils

def get_irreducible_nodes(graph):
"""
Gets irreducible nodes (i.e. leafs and junctions) of a graph.
def extract_irreducible_graph(swc_dict, prune=True, prune_depth=16):
graph = swc_utils.to_graph(swc_dict)
leafs, junctions = get_irreducibles(graph)
irreducible_edges, leafs = extract_irreducible_edges(
graph, leafs, junctions, swc_dict, prune=prune, prune_depth=prune_depth
)
if prune:
irreducible_edges, junctions = check_irreducibility(
junctions, irreducible_edges
)
return leafs, junctions, irreducible_edges
Parameters
----------
graph : networkx.Graph
Graph to be searched.
Returns
-------
leafs : set
Nodes with degreee 1.
junctions : set
Nodes with degree > 2.
def get_irreducibles(graph):
leafs = []
junctions = []
"""
leafs = set()
junctions = set()
for i in graph.nodes:
if graph.degree[i] == 1:
leafs.append(i)
leafs.add(i)
elif graph.degree[i] > 2:
junctions.append(i)
junctions.add(i)
return leafs, junctions


def extract_irreducible_edges(
graph, leafs, junctions, swc_dict, prune=True, prune_depth=16
):
root = None
irreducible_edges = dict()
for (i, j) in nx.dfs_edges(graph, source=leafs[0]):
# Check start of path is valid
if root is None:
root = i
edge = _init_edge(swc_dict=swc_dict, node=i)
path_length = 0
def prune_short_branches(graph, depth):
remove_nodes = []
for leaf in get_leafs(graph):
remove_nodes.extend(inspect_branch(graph, leaf, depth))
graph.remove_nodes_from(remove_nodes)
return graph

# Add to path
edge["radius"].append(swc_dict["radius"][j])
edge["xyz"].append(swc_dict["xyz"][j])
path_length += 1

# Check whether to end path
if j in leafs or j in junctions:
if prune and path_length <= prune_depth:
condition1 = j in leafs and root in junctions
condition2 = root in leafs and j in junctions
if condition1 or condition2:
leafs.remove(j if condition1 else root)
else:
irreducible_edges[(root, j)] = edge
else:
irreducible_edges[(root, j)] = edge
root = None
return irreducible_edges, leafs


def check_irreducibility(junctions, irreducible_edges):
graph = nx.Graph()
graph.add_edges_from(irreducible_edges.keys())
nx.set_edge_attributes(graph, irreducible_edges)
for j in junctions:
if j not in graph.nodes:
junctions.remove(j)
elif graph.degree[j] == 2:
# Get join edges
nbs = list(graph.neighbors(j))
edge1 = graph.get_edge_data(j, nbs[0])
edge2 = graph.get_edge_data(j, nbs[1])
edge = join_edges(edge1, edge2)

# Update irreducible edges
junctions.remove(j)
irreducible_edges = utils.remove_key(
irreducible_edges, (j, nbs[0])
)
irreducible_edges = utils.remove_key(
irreducible_edges, (j, nbs[1])
)
irreducible_edges[tuple(nbs)] = edge
def inspect_branch(graph, leaf, depth):
path = [leaf]
for (i, j) in nx.dfs_edges(graph, source=leaf, depth_limit=depth):
if graph.degree(j) > 2:
return path
elif graph.degree(j) == 2:
path.append(j)
return []


graph.remove_edge(j, nbs[0])
graph.remove_edge(j, nbs[1])
graph.remove_node(j)
graph.add_edge(*tuple(nbs), xyz=edge["xyz"], radius=edge["radius"])
if graph.degree[nbs[0]] > 2:
junctions.append(nbs[0])
def get_leafs(graph):
return [i for i in graph.nodes if graph.degree[i] == 1]

if graph.degree[nbs[1]] > 2:
junctions.append(nbs[1])

return irreducible_edges, junctions
def __smooth_branch(swc_dict, attrs, edges, nbs, root, j):
attrs["xyz"] = geometry_utils.smooth_branch(np.array(attrs["xyz"]))
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, root, 0)
swc_dict, edges = upd_xyz(swc_dict, attrs, edges, nbs, j, -1)
edges[(root, j)] = attrs
return swc_dict, edges


def join_edges(edge1, edge2):
# Last point in edge1 must connect to first point in edge2
if edge1["xyz"][0] == edge2["xyz"][0]:
edge1 = reverse_edge(edge1)
elif edge1["xyz"][-1] == edge2["xyz"][-1]:
edge2 = reverse_edge(edge2)
elif edge1["xyz"][0] == edge2["xyz"][-1]:
edge1 = reverse_edge(edge1)
edge2 = reverse_edge(edge2)
edge = {
"xyz": edge1["xyz"] + edge2["xyz"][1:],
"radius": edge1["radius"] + edge2["radius"],
}
return edge
def upd_xyz(swc_dict, attrs, edges, nbs, i, start_or_end):
if i in nbs.keys():
for j in nbs[i]:
key = (i, j) if (i, j) in edges.keys() else (j, i)
edges = upd_branch_endpoint(
edges, key, swc_dict["xyz"][i], attrs["xyz"][start_or_end]
)
swc_dict["xyz"][i] = attrs["xyz"][start_or_end]
return swc_dict, edges


def reverse_edge(edge):
edge["xyz"].reverse()
edge["radius"].reverse()
return edge
def append_value(my_dict, key, value):
if key in my_dict.keys():
my_dict[key].append(value)
else:
my_dict[key] = [value]
return my_dict


def _init_edge(swc_dict=None, node=None):
edge = {"radius": [], "xyz": []}
if node is not None:
edge["radius"].append(swc_dict["radius"][node])
edge["xyz"].append(swc_dict["xyz"][node])
return edge
def upd_branch_endpoint(edges, key, old_xyz, new_xyz):
if all(edges[key]["xyz"][0] == old_xyz):
edges[key]["xyz"][0] = new_xyz
else:
edges[key]["xyz"][-1] = new_xyz
return edges


# -- attribute utils --
def __init_edge_attrs(swc_dict, i):
return {"radius": [swc_dict["radius"][i]], "xyz": [swc_dict["xyz"][i]]}


def __upd_edge_attrs(swc_dict, attrs, i):
attrs["radius"].append(swc_dict["radius"][i])
attrs["xyz"].append(swc_dict["xyz"][i])
return attrs


def get_edge_attr(graph, edge, attr):
edge_data = graph.get_edge_data(*edge)
return edge_data[attr]


def is_leaf(graph, i):
nbs = [j for j in graph.neighbors(i)]
return True if len(nbs) == 1 else False
Loading

0 comments on commit be304c0

Please sign in to comment.