Major update: evaluation pipeline
anna-grim committed Dec 7, 2023
1 parent 78f495e commit 8198f58
Showing 5 changed files with 228 additions and 136 deletions.
4 changes: 2 additions & 2 deletions src/deep_neurographs/deep_learning/train.py
@@ -141,12 +141,12 @@ def random_split(train_set, train_ratio=0.85):
return torch_data.random_split(train_set, [train_set_size, valid_set_size])


def eval_network(X, model, threshold=0.5):
def eval_network(X, model):
model.eval()
X = torch.tensor(X, dtype=torch.float32)
with torch.no_grad():
y_pred = sigmoid(model.net(X))
return np.array(y_pred > threshold, dtype=int)
return np.array(y_pred)


# Lightning Module
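With this change, eval_network returns raw sigmoid probabilities rather than thresholded labels, so a caller that previously relied on the removed threshold argument has to binarize the output itself. A minimal usage sketch (not part of the commit; the feature array X, the trained model, and the 0.5 cutoff are assumptions for illustration, and the import path assumes the package layout matches the file path above):

    from deep_neurographs.deep_learning.train import eval_network

    y_proba = eval_network(X, model)        # probabilities in [0, 1]
    y_hat = (y_proba > 0.5).astype(int)     # apply an assumed 0.5 cutoff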
269 changes: 177 additions & 92 deletions src/deep_neurographs/evaluation.py
@@ -4,12 +4,12 @@
@author: Anna Grim
@email: anna.grim@alleninstitute.org
Evaluates performance of edge classifier.
Evaluates performance of the edge classification model.
"""
import numpy as np

STATS_LIST = [
METRICS_LIST = [
"precision",
"recall",
"f1",
@@ -18,108 +18,193 @@
]


def run_evaluation(
target_graphs, pred_graphs, y_pred, block_to_idxs, idx_to_edge, blocks
):
stats = dict([(s, []) for s in STATS_LIST])
stats_by_type = {
"simple": dict([(s, []) for s in STATS_LIST]),
"complex": dict([(s, []) for s in STATS_LIST]),
def run_evaluation(neurographs, blocks, pred_edges):
"""
Runs an evaluation on the accuracy of the predictions generated by an edge
classication model.
Parameters
----------
neurographs : list[NeuroGraph]
Predicted neurographs.
y_pred : numpy.ndarray
Binary predictions of edges generated by classifcation model.
blocks_to_idxs : dict
Dictionary that stores which indices in "y_pred" correspond to edges
in "neurographs[block_id]".
idx_to_edge : dict
Dictionary that stores the correspondence between an index from
"y_pred" and edge from "neurographs[block_id]" for some "block_id".
blocks : list[str]
List of block_ids that indicate which predictions to evaluate.
Returns
-------
stats : dict
Dictionary that stores the accuracy of the edge classification model
on all edges (i.e. "Overall"), simple edges, and complex edges. The
metrics contained in this dictionary are identical to "METRICS_LIST"].
"""
stats = {
"Overall": dict([(metric, []) for metric in METRICS_LIST]),
"Simple": dict([(metric, []) for metric in METRICS_LIST]),
"Complex": dict([(metric, []) for metric in METRICS_LIST]),
}
for block_id in blocks:
# Get predicted edges
pred_edges = get_predictions(
block_to_idxs[block_id], idx_to_edge, y_pred
# Compute accuracy
overall_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].mutable_edges,
pred_edges[block_id]
)

# Overall performance
num_fixes, num_mistakes = __reconstruction_stats(
target_graphs[block_id], pred_graphs[block_id], pred_edges
simple_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].get_simple_proposals(),
pred_edges[block_id],
)
stats["# splits fixed"].append(num_fixes)
stats["# merges created"].append(num_mistakes)

# In-depth performance
simple_stats, complex_stats = __reconstruction_type_stats(
target_graphs[block_id], pred_graphs[block_id], pred_edges
complex_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].get_complex_proposals(),
pred_edges[block_id],
)
if True:
print("simple stats:", simple_stats)
print("complex stats:", complex_stats)
print("")
for key in STATS_LIST:
stats_by_type["simple"][key].append(simple_stats[key])
stats_by_type["complex"][key].append(complex_stats[key])
return stats, stats_by_type

# Store results
for metric in METRICS_LIST:
stats["Overall"][metric].append(overall_stats_i[metric])
stats["Simple"][metric].append(simple_stats_i[metric])
stats["Complex"][metric].append(complex_stats_i[metric])

return stats
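# A hypothetical call pattern for run_evaluation (not part of this commit),
# assuming "neurographs" and "pred_edges" are dictionaries keyed by block_id
# and "blocks" lists the block_ids to score:
#
#   stats = run_evaluation(neurographs, blocks, pred_edges)
#   for edge_type in ["Overall", "Simple", "Complex"]:
#       avg = {metric: np.mean(vals) for metric, vals in stats[edge_type].items()}
#       print(edge_type, avg)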


def get_predictions(idxs, idx_to_edge, y_pred):
"""
Gets edges that are predicted to be target edges for some "block_id".
Parameters
----------
idxs : set
Indices of entries in "y_pred" that belong to a given block.
idx_to_edge : dict
Dictionary that stores the correspondence between an index from
"y_pred" and edge from "neurographs[block_id]" for some "block_id".
y_pred : numpy.ndarray
Prediction of edge probabilities generated by classifcation model.
Returns
-------
set
Edges that are predicted to be target edges for some "block_id".
"""
edge_idxs = set(np.where(y_pred > 0)[0]).intersection(idxs)
return set([idx_to_edge[idx] for idx in edge_idxs])
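# A worked toy example for get_predictions (not part of this commit; the
# values are made up for illustration):
#
#   y_pred = np.array([1, 0, 1, 1])
#   idxs = {0, 1, 2}
#   idx_to_edge = {0: frozenset((1, 2)), 1: frozenset((3, 4)),
#                  2: frozenset((5, 6)), 3: frozenset((7, 8))}
#
# The positive entries of "y_pred" are indices {0, 2, 3}; intersecting with
# "idxs" leaves {0, 2}, so the function returns
# {frozenset((1, 2)), frozenset((5, 6))}.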


def __reconstruction_stats(target_graph, pred_graph, pred_edges):
true_positives = 0
false_positives = 0
for edge in pred_edges:
if edge in pred_graph.target_edges:
true_positives += 1
else:
false_positives += 1
return true_positives, false_positives


def __reconstruction_type_stats(target_graph, pred_graph, pred_edges):
simple_stats = dict([(s, 0) for s in STATS_LIST])
complex_stats = dict([(s, 0) for s in STATS_LIST])
for edge in pred_edges:
i, j = tuple(edge)
deg_i = pred_graph.immutable_degree(i)
deg_j = pred_graph.immutable_degree(j)
if edge in pred_graph.target_edges:
if deg_i == 1 and deg_j == 1:
simple_stats["# splits fixed"] += 1
else:
complex_stats["# splits fixed"] += 1
else:
if deg_i == 1 and deg_j == 1:
simple_stats["# merges created"] += 1
else:
complex_stats["# merges created"] += 1

num_simple, num_complex = compute_edge_type(pred_graph)
simple_stats = compute_accuracy(simple_stats, num_simple)
complex_stats = compute_accuracy(complex_stats, num_complex)

if False:
print("# simple edges:", num_simple)
print("% simple edges:", num_simple / (num_complex + num_simple))
print("# complex edges:", num_complex)
print("% complex edges:", num_complex / (num_complex + num_simple))
print("")
return simple_stats, complex_stats


def compute_edge_type(graph):
num_simple = 0
num_complex = 0
for edge in graph.target_edges:
i, j = tuple(edge)
deg_i = graph.immutable_degree(i)
deg_j = graph.immutable_degree(j)
if deg_i == 1 and deg_j == 1:
num_simple += 1
else:
num_complex += 1
return num_simple, num_complex


def compute_accuracy(stats, num_edges):
d = stats["# merges created"] + stats["# splits fixed"]
r = 1 if num_edges == 0 else stats["# splits fixed"] / num_edges
p = 1 if d == 0 else stats["# splits fixed"] / d
stats["f1"] = 0 if r + p == 0 else (2 * r * p) / (r + p)
stats["precision"] = p
stats["recall"] = r
def get_stats(neurograph, proposals, pred_edges):
"""
Accuracy of the predictions generated by an edge classication model on a
given block and "edge_type" (e.g. overall, simple, or complex).
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
y_pred : numpy.ndarray
Binary predictions of edges generated by classifcation model.
Returns
-------
dict
Dictionary containing results of evaluation where the keys are
"METRICS_LIST".
"""
tp, fp, p, r, f1 = get_accuracy(neurograph, proposals, pred_edges)
stats = {
"# splits fixed": tp,
"# merges created": fp,
"precision": p,
"recall": r,
"f1": f1,
}
return stats
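# The dictionary returned by get_stats has the following form (the values
# here are hypothetical, not taken from the commit):
#
#   {"# splits fixed": 8, "# merges created": 2,
#    "precision": 0.8, "recall": 0.667, "f1": 0.727}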


def get_accuracy(neurograph, proposals, pred_edges):
"""
Computes the following metrics for a given set of predicted edges:
(1) true positives, (2) false positive, (3) precision, (4) recall, and
(5) f1-score.
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
y_pred : numpy.ndarray
Prediction of edge probabilities generated by classifcation model.
Returns
-------
tp : float
Number of true positives.
fp : float
Number of false positives.
p : float
Precision.
r : float
Recall.
f1 : float
F1-score.
"""
tp, fp, fn = get_accuracy_counts(neurograph, proposals, pred_edges)
p = 1 if tp + fp == 0 else tp / (tp + fp)
r = 1 if tp + fn == 0 else tp / (tp + fn)
f1 = (2 * r * p) / max(r + p, 1e-3)
return tp, fp, p, r, f1
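# A worked example of the arithmetic in get_accuracy (not part of this
# commit; the counts are made up for illustration). With tp = 8, fp = 2,
# fn = 4:
#
#   p  = 8 / (8 + 2) = 0.8
#   r  = 8 / (8 + 4) ~ 0.667
#   f1 = (2 * 0.667 * 0.8) / max(0.667 + 0.8, 1e-3) ~ 0.727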


def get_accuracy_counts(neurograph, proposals, pred_edges):
"""
Computes the following values: (1) true positives, (2) false positive, and
(3) false negatives.
Parameters
----------
neurograph : NeuroGraph
Predicted neurograph
proposals : set[frozenset]
Set of edge proposals for a given "edge_type".
y_pred : numpy.ndarray
Prediction of edge probabilities generated by classifcation model.
Returns
-------
tp : float
Number of true positives.
fp : float
Number of false positives.
fn : float
Number of false negatives.
"""
tp = 0
fp = 0
fn = 0
for edge in proposals:
if edge in neurograph.target_edges:
if edge in pred_edges:
tp += 1
else:
fn += 1
elif edge in pred_edges:
fp += 1
return tp, fp, fn
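The counting in get_accuracy_counts reduces to set membership over the proposals: a proposal that is a target edge counts as a true positive if it was predicted and as a false negative if it was not, while a predicted proposal that is not a target edge counts as a false positive. A standalone toy check of that logic (not part of the commit; the edge values below are made up for illustration):

    target_edges = {frozenset((1, 2)), frozenset((3, 4))}
    pred_edges = {frozenset((1, 2)), frozenset((7, 8))}
    proposals = {frozenset((1, 2)), frozenset((3, 4)), frozenset((7, 8))}

    tp = sum(1 for e in proposals if e in target_edges and e in pred_edges)      # 1
    fn = sum(1 for e in proposals if e in target_edges and e not in pred_edges)  # 1
    fp = sum(1 for e in proposals if e not in target_edges and e in pred_edges)  # 1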
2 changes: 1 addition & 1 deletion src/deep_neurographs/feature_extraction.py
@@ -80,7 +80,7 @@ def generate_img_chunk_features(
xyz_j = utils.world_to_img(neurograph, j)

# Extract chunks
midpoint = geometry_utils.compute_midpoint(xyz_i, xyz_j).astype(int)
midpoint = geometry_utils.get_midpoint(xyz_i, xyz_j).astype(int)
img_chunk = utils.get_chunk(img, midpoint, CHUNK_SIZE)
labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE)

20 changes: 2 additions & 18 deletions src/deep_neurographs/geometry_utils.py
@@ -13,7 +13,7 @@ def get_directional(neurograph, i, proposal_tangent, window=5):
# Compute principal axes
directionals = []
d = neurograph.optimize_depth
for branch in get_branches(neurograph, i):
for branch in neurograph.get_branches(i):
if branch.shape[0] >= window + d:
xyz = deepcopy(branch[d:, :])
elif branch.shape[0] <= d:
@@ -35,22 +35,6 @@
return directionals[arg_max]


def get_branches(neurograph, i):
branches = []
ref_xyz = deepcopy(neurograph.nodes[i]["xyz"])
for j in neurograph.neighbors(i):
if frozenset((i, j)) in neurograph.immutable_edges:
branches.append(orient(neurograph.edges[i, j]["xyz"], ref_xyz))
return branches


def orient(xyz, ref_xyz):
if (xyz[0, :] == ref_xyz).all():
return xyz
else:
return np.flip(xyz, axis=0)


def compute_svd(xyz):
xyz = xyz - np.mean(xyz, axis=0)
return svd(xyz)
@@ -72,7 +72,7 @@ def compute_normal(xyz):
return normal / np.linalg.norm(normal)


def compute_midpoint(xyz1, xyz2):
def get_midpoint(xyz1, xyz2):
return np.mean([xyz1, xyz2], axis=0)
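# A quick sanity check for the renamed get_midpoint (not part of this
# commit):
#
#   get_midpoint(np.array([0, 0, 0]), np.array([2, 4, 6]))
#   # -> array([1., 2., 3.])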

