Skip to content

Commit

Permalink
restructured : visualization tools
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Dec 8, 2023
1 parent 8198f58 commit 363333e
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 143 deletions.
147 changes: 45 additions & 102 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
swc_path,
img_path=None,
label_mask=None,
optimize_depth=10,
optimize_depth=8,
optimize_proposals=False,
origin=None,
shape=None,
Expand Down Expand Up @@ -81,10 +81,15 @@ def __init__(
else:
self.bbox = None

def init_immutable_graph(self):
def init_immutable_graph(self, add_attrs=False):
immutable_graph = nx.Graph()
immutable_graph.add_nodes_from(self)
immutable_graph.add_edges_from(self.immutable_edges)
immutable_graph.add_nodes_from(self.nodes(data=add_attrs))
if add_attrs:
for edge in self.immutable_edges:
i, j = tuple(edge)
immutable_graph.add_edge(i, j, **self.get_edge_data(i, j))
else:
immutable_graph.add_edges_from(self.immutable_edges)
return immutable_graph

def init_predicted_graph(self):
Expand Down Expand Up @@ -262,11 +267,11 @@ def add_immutable_node(self, edge, attrs, idx):
radius=attrs["radius"][idx],
swc_id=attrs["swc_id"],
)
self._add_edge((i, node_id), attrs, np.arange(0, idx + 1))
self._add_edge((node_id, j), attrs, np.arange(idx, len(attrs["xyz"])))
self.__add_edge((i, node_id), attrs, np.arange(0, idx + 1))
self.__add_edge((node_id, j), attrs, np.arange(idx, len(attrs["xyz"])))
return node_id

def _add_edge(self, edge, attrs, idxs):
def __add_edge(self, edge, attrs, idxs):
self.add_edge(
edge[0],
edge[1],
Expand Down Expand Up @@ -334,13 +339,17 @@ def optimize_simple_edge(self, img, edge):
depth = self.optimize_depth

# Get image patch
hat_xyz_i = self.to_img(branch_i[depth])
hat_xyz_j = self.to_img(branch_j[depth])
idx_i = min(depth, branch_i.shape[0] - 1)
idx_j = min(depth, branch_j.shape[0] - 1)
hat_xyz_i = self.to_img(branch_i[idx_i])
hat_xyz_j = self.to_img(branch_j[idx_j])
patch_dims = geometry_utils.get_optimal_patch(hat_xyz_i, hat_xyz_j)
center = geometry_utils.get_midpoint(hat_xyz_i, hat_xyz_j).astype(int)
img_chunk = utils.get_chunk(img, center, patch_dims)

# Optimize
if (np.array(hat_xyz_i) < 0).any() or (np.array(hat_xyz_j) < 0).any():
return False
path = geometry_utils.shortest_path(
img_chunk,
utils.img_to_patch(hat_xyz_i, center, patch_dims),
Expand All @@ -349,7 +358,7 @@ def optimize_simple_edge(self, img, edge):
origin = utils.apply_anisotropy(self.origin, return_int=True)
path = geometry_utils.transform_path(path, origin, center, patch_dims)
self.edges[edge]["xyz"] = np.vstack(
[branch_i[depth], path, branch_j[depth]]
[branch_i[idx_i], path, branch_j[idx_j]]
)

def optimize_complex_edge(self, img, edge):
Expand All @@ -364,13 +373,13 @@ def optimize_complex_edge(self, img, edge):
#if len(branches) == 2:


def get_branch(self, xyz):
edge = self.xyz_to_edge[tuple(xyz)]
branch = self.edges[edge]["xyz"]
if not (branch[0] == xyz).all():
return np.flip(branch, axis=0)
def get_branch(self, xyz_or_node):
if type(xyz_or_node) is int:
nb = self.get_immutable_nbs(xyz_or_node)[0]
return self.orient_edge((xyz_or_node, nb), xyz_or_node)
else:
return branch
edge = self.xyz_to_edge[tuple(xyz_or_node)]
return deepcopy(self.edges[edge]["xyz"])

def get_branches(self, i):
branches = []
Expand Down Expand Up @@ -473,93 +482,20 @@ def is_target(
)
return True if aligned else False

# --- Visualization ---
def visualize_immutables(self, title="Immutable Graph", return_data=False):
"""
Parameters
----------
node_ids : list[int], optional
List of node ids to be plotted. The default is None.
edge_ids : list[tuple], optional
List of edge ids to be plotted. The default is None.
Returns
-------
None.
"""
data = self._plot_edges(self.immutable_edges)
data.append(self._plot_nodes())
if return_data:
return data
else:
utils.plot(data, title)

def visualize_proposals(self, title="Mutable Graph", return_data=False):
data = [self._plot_nodes()]
data.extend(self._plot_edges(self.immutable_edges, color="black"))
data.extend(self._plot_edges(self.mutable_edges))
if return_data:
return data
else:
utils.plot(data, title)

def visualize_targets(
self, target_graph=None, title="Target Edges", return_data=False
):
data = [self._plot_nodes()]
data.extend(self._plot_edges(self.immutable_edges, color="black"))
data.extend(self._plot_edges(self.target_edges))
if target_graph is not None:
data.extend(
target_graph._plot_edges(
target_graph.immutable_edges, color="blue"
)
)
if return_data:
return data
else:
utils.plot(data, title)

def visualize_subset(self, edges, target_graph=None, title=""):
data = [self._plot_nodes()]
data.extend(self._plot_edges(self.immutable_edges, color="black"))
data.extend(self._plot_edges(edges))
if target_graph is not None:
data.extend(
target_graph._plot_edges(
target_graph.immutable_edges, color="blue"
)
)
utils.plot(data, title)

def _plot_nodes(self):
xyz = nx.get_node_attributes(self, "xyz")
xyz = np.array(list(xyz.values()))
points = go.Scatter3d(
x=xyz[:, 0],
y=xyz[:, 1],
z=xyz[:, 2],
mode="markers",
name="Nodes",
marker=dict(size=3, color="red"),
)
return points

def _plot_edges(self, edges, color=None):
traces = []
line = dict(width=4) if color is None else dict(color=color, width=3)
for i, j in edges:
trace = go.Scatter3d(
x=self.edges[(i, j)]["xyz"][:, 0],
y=self.edges[(i, j)]["xyz"][:, 1],
z=self.edges[(i, j)]["xyz"][:, 2],
mode="lines",
line=line,
name="({},{})".format(i, j),
# --- Generate reconstructions post-inference
def get_reconstruction(self, proposals, upd_self=False):
reconstruction = self.init_immutable_graph(add_attrs=True)
for edge in proposals:
i, j = tuple(edge)
r_i = self.nodes[i]["radius"]
r_j = self.nodes[j]["radius"]
reconstruction.add_edge(
i,
j,
xyz=self.edges[i, j]["xyz"],
radius=[r_i, r_j],
)
traces.append(trace)
return traces
return reconstruction

# --- Utils ---
def num_nodes(self):
Expand Down Expand Up @@ -633,6 +569,13 @@ def immutable_degree(self, i):
degree += 1
return degree

def get_immutable_nbs(self, i):
nbs = []
for j in self.neighbors(i):
if frozenset((i, j)) in self.immutable_edges:
nbs.append(j)
return nbs

def compute_length(self, edge, metric="l2"):
i, j = tuple(edge)
xyz_1, xyz_2 = self.get_edge_attr("xyz", i, j)
Expand Down
41 changes: 0 additions & 41 deletions src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,47 +285,6 @@ def write_json(path, contents):
json.dump(contents, f)


# --- plot utils ---
def plot(data, title):
fig = go.Figure(data=data)
fig.update_layout(
plot_bgcolor="white",
title=title,
scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
)
fig.update_layout(
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
width=1200,
height=600,
)
fig.show()


def subplot(data1, data2, title):
fig = make_subplots(
rows=1, cols=2, specs=[[{"type": "scene"}, {"type": "scene"}]]
)
fig.add_trace(data1, row=1, col=1)
fig.add_trace(data2, row=1, col=2)
fig.update_layout(title_text=title, showlegend=True)

fig.update_xaxes(row=1, col=1, matches="y", showgrid=False)
fig.update_yaxes(row=1, col=1, matches="x", showgrid=False)
fig.update_layout(
scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1)
)

# Update the size of the second subplot
fig.update_xaxes(row=1, col=2, matches="y")
fig.update_yaxes(row=1, col=2, matches="x")
fig.update_layout(
scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1)
)

fig.update_layout(width=1200, height=800)
fig.show()


# --- coordinate conversions ---
def world_to_img(neurograph, node_or_xyz):
if type(node_or_xyz) == int:
Expand Down
115 changes: 115 additions & 0 deletions src/deep_neurographs/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Created on Sat July 15 9:00:00 2023
@author: Anna Grim
@email: anna.grim@alleninstitute.org
"""

import networkx as nx
import numpy as np
import plotly.graph_objects as go


def visualize_connected_components(graph):
pass


def visualize_immutables(graph, title="Immutable Graph"):
data = plot_edges(graph, graph.immutable_edges)
data.append(plot_nodes(graph))
plot(data, title)


def visualize_proposals(graph, title="Edge Proposals"):
visualize_subset(graph, graph.mutable_edges, title=title)


def visualize_targets(graph, target_graph=None, title="Target Edges"):
visualize_subset(
graph,
graph.target_edges,
target_graph=target_graph,
title=title,
)


def visualize_subset(graph, edges, target_graph=None, title=""):
data = plot_edges(graph, graph.immutable_edges, color="black")
data.extend(plot_edges(graph, edges))
data.append(plot_nodes(graph))
if target_graph:
edges = target_graph.immutable_edges
data.extend(plot_edges(target_graph, edges, color="blue"))
plot(data, title)


# utils
def plot_nodes(graph):
xyz = nx.get_node_attributes(graph, "xyz")
xyz = np.array(list(xyz.values()))
return go.Scatter3d(
x=xyz[:, 0],
y=xyz[:, 1],
z=xyz[:, 2],
mode="markers",
name="Nodes",
marker=dict(size=3, color="red"),
)


def plot_edges(graph, edges, color=None):
traces = []
line = dict(width=4) if color is None else dict(color=color, width=3)
for i, j in edges:
trace = go.Scatter3d(
x=graph.edges[(i, j)]["xyz"][:, 0],
y=graph.edges[(i, j)]["xyz"][:, 1],
z=graph.edges[(i, j)]["xyz"][:, 2],
mode="lines",
line=line,
name="({},{})".format(i, j),
)
traces.append(trace)
return traces


def plot(data, title):
fig = go.Figure(data=data)
fig.update_layout(
plot_bgcolor="white",
title=title,
scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
)
fig.update_layout(
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=1)),
width=1200,
height=600,
)
fig.show()


def subplot(data1, data2, title):
fig = make_subplots(
rows=1, cols=2, specs=[[{"type": "scene"}, {"type": "scene"}]]
)
fig.add_trace(data1, row=1, col=1)
fig.add_trace(data2, row=1, col=2)
fig.update_layout(title_text=title, showlegend=True)

fig.update_xaxes(row=1, col=1, matches="y", showgrid=False)
fig.update_yaxes(row=1, col=1, matches="x", showgrid=False)
fig.update_layout(
scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1)
)

# Update the size of the second subplot
fig.update_xaxes(row=1, col=2, matches="y")
fig.update_yaxes(row=1, col=2, matches="x")
fig.update_layout(
scene_aspectmode="manual", scene_aspectratio=dict(x=1, y=1, z=1)
)

fig.update_layout(width=1200, height=800)
fig.show()

0 comments on commit 363333e

Please sign in to comment.