Skip to content

Commit

Permalink
added visualization and black format
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Aug 10, 2023
1 parent e874c0d commit 1443ba9
Show file tree
Hide file tree
Showing 10 changed files with 92,843 additions and 97 deletions.
4 changes: 1 addition & 3 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def compute_num_features(features):
return num_features


def extract_feature_vec(
features,
):
def extract_feature_vec(features,):
feature_vec = None
for key in features.keys():
if feature_vec is None:
Expand Down
8 changes: 2 additions & 6 deletions src/deep_neurographs/graph_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,7 @@ def _extract_edges(self, query_id, query_xyz, nbs_xyz):
best_dist = self._get_best_edges(best_dist)
for nb_id in best_dist.keys():
self._add_edge(
query_id,
query_xyz,
nb_id,
best_xyz[nb_id],
best_dist[nb_id],
query_id, query_xyz, nb_id, best_xyz[nb_id], best_dist[nb_id]
)

def _get_best_edges(self, best_dist):
Expand Down Expand Up @@ -391,7 +387,7 @@ def num_edges(self):
"""
return self.number_of_edges()

def to_line_graph(self):
"""
Converts graph to a line graph.
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _init_edge(swc_dict=None, node=None):
edge["radius"].append(swc_dict["radius"][node])
edge["xyz"].append(swc_dict["xyz"][node])
return edge


def get_edge_attr(graph, edge, attr):
edge_data = graph.get_edge_data(*edge)
Expand Down
14 changes: 7 additions & 7 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ def generate_immutables(
neurograph.generate_immutables(swc_id, swc_dict)
return neurograph


def build_immutable_from_local(
neurograph,
swc_dir,
anisotropy=[1.0, 1.0, 1.0],
neurograph, swc_dir, anisotropy=[1.0, 1.0, 1.0]
):
"""
To do...
Expand Down Expand Up @@ -125,16 +124,17 @@ def read_mistake_log(bucket, file_key, s3_client):
hash_table = dict()
mistake_log = s3_utils.read_from_s3(bucket, file_key, s3_client)
for entry in mistake_log:
entry= entry.replace("[", "")
entry= entry.replace("]", "")
entry = entry.replace("[", "")
entry = entry.replace("]", "")
entry = entry.split(",")
entry = list(map(float, entry))

edge = (int(entry[0]), int(entry[1]))
xyz_coords = (entry[2:5], entry[5:])
hash_table[edge] = xyz_coords
return hash_table


"""
def build_supergraph(
bucket,
Expand Down Expand Up @@ -176,4 +176,4 @@ def create_nodes_from_swc(
graph.old_node_ids[node_id] = int(f.split(".")[0])
graph.upd_xyz_to_id(node_id)
return graph
"""
"""
6 changes: 3 additions & 3 deletions src/deep_neurographs/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@ class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
#self.conv2 = GCNConv(hidden_channels, out_channels)
# self.conv2 = GCNConv(hidden_channels, out_channels)

def encode(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
#x = self.conv2(x, edge_index)
# x = self.conv2(x, edge_index)
return x

def decode(self, z, edge_label_index):
return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()
return (prob_adj > 0).nonzero(as_tuple=False).t()
102 changes: 61 additions & 41 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.tools as tls
import plotly.graph_objects as go
from more_itertools import zip_broadcast
from scipy.spatial import KDTree
from deep_neurographs import graph_utils as gutils, swc_utils, utils
from plotly.subplots import make_subplots

COLORS = list(mcolors.TABLEAU_COLORS.keys())
nCOLORS = len(COLORS)
Expand All @@ -28,7 +31,13 @@ class NeuroGraph(nx.Graph):
"""

def __init__(self, max_mutable_degree=5, max_mutable_edge_dist=150.0, prune=True, prune_depth=10):
def __init__(
self,
max_mutable_degree=5,
max_mutable_edge_dist=120.0,
prune=True,
prune_depth=10,
):
"""
Parameters
----------
Expand Down Expand Up @@ -95,11 +104,11 @@ def generate_immutables(self, swc_id, swc_dict):
self.add_edge(
node_id[i],
node_id[j],
xyz=np.array(edges[(i,j)]["xyz"]),
radius=np.array(edges[(i,j)]["radius"]),
xyz=np.array(edges[(i, j)]["xyz"]),
radius=np.array(edges[(i, j)]["radius"]),
swc_id=swc_id,
)
xyz_to_edge = dict([(xyz, edge) for xyz in edges[(i, j)]["xyz"]])
xyz_to_edge = dict((xyz, edge) for xyz in edges[(i, j)]["xyz"])
self.xyz_to_edge.update(xyz_to_edge)

# Update leafs and junctions
Expand All @@ -123,7 +132,7 @@ def generate_mutables(self):
for leaf in self.leafs:
xyz_leaf = self.nodes[leaf]["xyz"]
for xyz in self._get_mutables(leaf, xyz_leaf):
# Extract info on mutable connection
# Extract info on mutable connection
(i, j) = self.xyz_to_edge[xyz]
attrs = self.get_edge_data(i, j)

Expand Down Expand Up @@ -192,7 +201,7 @@ def _get_best_edges(self, dist, xyz):
"""
if len(dist.keys()) > self.max_mutable_degree:
keys = sorted(dist, key=dist.__getitem__)
return [xyz[key] for key in keys[:self.max_mutable_degree]]
return [xyz[key] for key in keys[: self.max_mutable_degree]]
else:
return list(xyz.values())

Expand All @@ -210,10 +219,10 @@ def add_immutable_node(self, edge, attrs, idx):
radius=attrs["radius"][idx],
swc_id=attrs["swc_id"],
)
self._add_edge((i, node_id), attrs, np.arange(0, idx+1))
self._add_edge((i, node_id), attrs, np.arange(0, idx + 1))
self._add_edge((node_id, j), attrs, np.arange(idx, len(attrs["xyz"])))
return node_id

def _add_edge(self, edge, attrs, idxs):
self.add_edge(
edge[0],
Expand Down Expand Up @@ -260,14 +269,6 @@ def _query_kdtree(self, query):
return self.kdtree.data[idxs]

# --- Visualization ---
def _init_figure(self):
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
return fig, ax

def visualize_immutables(self):
"""
Parameters
Expand All @@ -282,36 +283,55 @@ def visualize_immutables(self):
None.
"""
_, ax = self._init_figure()
self._plot_edges(ax, self.immutable_edges)
plt.show()
data = [self._plot_nodes()]
data.extend(self._plot_edges(self.immutable_edges))
self._plot(data, "Immutable Graph")

def visualize_mutables(self):
_, ax = self._init_figure()
self._plot_edges(ax, self.immutable_edges, color="k")
self._plot_edges(ax, self.mutable_edges)
plt.show()

def _plot_node(self, ax, i, color="r"):
ax.scatter(
self.nodes[i]["xyz"][0],
self.nodes[i]["xyz"][1],
self.nodes[i]["xyz"][2],
color=color,
data = [self._plot_nodes()]
data.extend(self._plot_edges(self.immutable_edges, color="black"))
data.extend(self._plot_edges(self.mutable_edges))
self._plot(data, "Mutable Graph")

def _plot(self, data, title):
fig = go.Figure(data=data)
fig.update_layout(
title=title,
scene=dict(xaxis_title="X", yaxis_title="Y", zaxis_title="Z"),
)
fig.update_layout(
scene=dict(aspectmode="manual", aspectratio=dict(x=1, y=1, z=0.8)),
height=600,
)
fig.show()

def _plot_nodes(self):
xyz = nx.get_node_attributes(self, "xyz")
xyz = np.array(list(xyz.values()))
points = go.Scatter3d(
x=xyz[:, 0],
y=xyz[:, 1],
z=xyz[:, 2],
mode="markers",
name="Nodes"
marker=dict(size=3, color="red"),
)
return points

def _plot_edges(self, ax, edges, color=None):
def _plot_edges(self, edges, color=None):
traces = []
line = dict(color=color) if color is not None else dict()
for (i, j) in edges:
ax.plot(
self.edges[(i, j)]["xyz"][:, 0],
self.edges[(i, j)]["xyz"][:, 1],
self.edges[(i, j)]["xyz"][:, 2],
color=color,
trace = go.Scatter3d(
x=self.edges[(i, j)]["xyz"][:, 0],
y=self.edges[(i, j)]["xyz"][:, 1],
z=self.edges[(i, j)]["xyz"][:, 2],
mode="lines",
line=line,
)
self._plot_node(ax, i)
self._plot_node(ax, j)


traces.append(trace)
return traces

# --- Utils ---
def num_nodes(self):
"""
Expand Down Expand Up @@ -344,7 +364,7 @@ def num_edges(self):
"""
return self.number_of_edges()

def to_line_graph(self):
"""
Converts graph to a line graph.
Expand Down
7 changes: 1 addition & 6 deletions src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ def parse(raw_swc, anisotropy=[1.0, 1.0, 1.0]):
"""
# Initialize swc
swc_dict = {
"id": [],
"xyz": [],
"radius": [],
"pid": [],
}
swc_dict = {"id": [], "xyz": [], "radius": [], "pid": []}

# Parse raw data
min_id = np.inf
Expand Down
23 changes: 14 additions & 9 deletions src/deep_neurographs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,22 @@ def train(model, optimizer, criterion, train_data):

# We perform a new round of negative sampling for every training epoch:
neg_edge_index = negative_sampling(
edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
edge_index=train_data.edge_index,
num_nodes=train_data.num_nodes,
num_neg_samples=train_data.edge_label_index.size(1),
method="sparse",
)

edge_label_index = torch.cat(
[train_data.edge_label_index, neg_edge_index],
dim=-1,
[train_data.edge_label_index, neg_edge_index], dim=-1
)
edge_label = torch.cat(
[
train_data.edge_label,
train_data.edge_label.new_zeros(neg_edge_index.size(1)),
],
dim=0,
)
edge_label = torch.cat([
train_data.edge_label,
train_data.edge_label.new_zeros(neg_edge_index.size(1))
], dim=0)

out = model.decode(z, edge_label_index).view(-1)
loss = criterion(out, edge_label)
Expand Down Expand Up @@ -87,4 +92,4 @@ def test(model, data):
z = model.encode(test_data.x, test_data.edge_index)
final_edge_index = model.decode_all(z)
"""
"""
Loading

0 comments on commit 1443ba9

Please sign in to comment.