From 8198f58b5ca7d3f7ca675582f996e91c2c3584e3 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Thu, 7 Dec 2023 20:33:52 +0000
Subject: [PATCH] major upd : evaluation pipeline

---
 src/deep_neurographs/deep_learning/train.py |   4 +-
 src/deep_neurographs/evaluation.py          | 269 +++++++++++++-------
 src/deep_neurographs/feature_extraction.py  |   2 +-
 src/deep_neurographs/geometry_utils.py      |  20 +-
 src/deep_neurographs/neurograph.py          |  69 +++--
 5 files changed, 228 insertions(+), 136 deletions(-)

diff --git a/src/deep_neurographs/deep_learning/train.py b/src/deep_neurographs/deep_learning/train.py
index 4ec9725..35d1dec 100644
--- a/src/deep_neurographs/deep_learning/train.py
+++ b/src/deep_neurographs/deep_learning/train.py
@@ -141,12 +141,12 @@ def random_split(train_set, train_ratio=0.85):
     return torch_data.random_split(train_set, [train_set_size, valid_set_size])


-def eval_network(X, model, threshold=0.5):
+def eval_network(X, model):
     model.eval()
     X = torch.tensor(X, dtype=torch.float32)
     with torch.no_grad():
         y_pred = sigmoid(model.net(X))
-    return np.array(y_pred > threshold, dtype=int)
+    return np.array(y_pred)


 # Lightning Module
diff --git a/src/deep_neurographs/evaluation.py b/src/deep_neurographs/evaluation.py
index d7d66cf..44c78a6 100644
--- a/src/deep_neurographs/evaluation.py
+++ b/src/deep_neurographs/evaluation.py
@@ -4,12 +4,12 @@
 @author: Anna Grim
 @email: anna.grim@alleninstitute.org

-Evaluates performance of edge classifier.
+Evaluates performance of edge classification model.

 """

 import numpy as np

-STATS_LIST = [
+METRICS_LIST = [
     "precision",
     "recall",
     "f1",
@@ -18,108 +18,193 @@
 ]


-def run_evaluation(
-    target_graphs, pred_graphs, y_pred, block_to_idxs, idx_to_edge, blocks
-):
-    stats = dict([(s, []) for s in STATS_LIST])
-    stats_by_type = {
-        "simple": dict([(s, []) for s in STATS_LIST]),
-        "complex": dict([(s, []) for s in STATS_LIST]),
+def run_evaluation(neurographs, blocks, pred_edges):
+    """
+    Runs an evaluation on the accuracy of the predictions generated by an edge
+    classification model.
+
+    Parameters
+    ----------
+    neurographs : dict
+        Dictionary that maps a block_id to the corresponding predicted
+        neurograph.
+    blocks : list[str]
+        List of block_ids that indicate which predictions to evaluate.
+    pred_edges : dict
+        Dictionary that maps a block_id to the set of edges predicted to be
+        target edges by the classification model.
+
+    Returns
+    -------
+    stats : dict
+        Dictionary that stores the accuracy of the edge classification model
+        on all edges (i.e. "Overall"), simple edges, and complex edges. The
+        metrics contained in this dictionary are identical to "METRICS_LIST".
+
+    """
+    stats = {
+        "Overall": dict([(metric, []) for metric in METRICS_LIST]),
+        "Simple": dict([(metric, []) for metric in METRICS_LIST]),
+        "Complex": dict([(metric, []) for metric in METRICS_LIST]),
     }
-    print(blocks)
     for block_id in blocks:
-        # Get predicted edges
-        pred_edges = get_predictions(
-            block_to_idxs[block_id], idx_to_edge, y_pred
+        # Compute accuracy
+        overall_stats_i = get_stats(
+            neurographs[block_id],
+            neurographs[block_id].mutable_edges,
+            pred_edges[block_id],
         )
-        # Overall performance
-        num_fixes, num_mistakes = __reconstruction_stats(
-            target_graphs[block_id], pred_graphs[block_id], pred_edges
+        simple_stats_i = get_stats(
+            neurographs[block_id],
+            neurographs[block_id].get_simple_proposals(),
+            pred_edges[block_id],
         )
-        stats["# splits fixed"].append(num_fixes)
-        stats["# merges created"].append(num_mistakes)
-        # In-depth performance
-        simple_stats, complex_stats = __reconstruction_type_stats(
-            target_graphs[block_id], pred_graphs[block_id], pred_edges
+        complex_stats_i = get_stats(
+            neurographs[block_id],
+            neurographs[block_id].get_complex_proposals(),
+            pred_edges[block_id],
         )
-        if True:
-            print("simple stats:", simple_stats)
-            print("complex stats:", complex_stats)
-            print("")
-        for key in STATS_LIST:
-            stats_by_type["simple"][key].append(simple_stats[key])
-            stats_by_type["complex"][key].append(complex_stats[key])
-    return stats, stats_by_type
+
+        # Store results
+        for metric in METRICS_LIST:
+            stats["Overall"][metric].append(overall_stats_i[metric])
+            stats["Simple"][metric].append(simple_stats_i[metric])
+            stats["Complex"][metric].append(complex_stats_i[metric])
+
+    return stats


 def get_predictions(idxs, idx_to_edge, y_pred):
+    """
+    Gets the edges that are predicted to be target edges for a given block.
+
+    Parameters
+    ----------
+    idxs : set
+        Indices of entries in "y_pred" that belong to a given block.
+    idx_to_edge : dict
+        Dictionary that stores the correspondence between an index from
+        "y_pred" and edge from "neurographs[block_id]" for some "block_id".
+    y_pred : numpy.ndarray
+        Binary predictions of edges generated by the classification model.
+
+    Returns
+    -------
+    set
+        Edges that are predicted to be target edges for the given block.
+
+    """
     edge_idxs = set(np.where(y_pred > 0)[0]).intersection(idxs)
     return set([idx_to_edge[idx] for idx in edge_idxs])


-def __reconstruction_stats(target_graph, pred_graph, pred_edges):
-    true_positives = 0
-    false_positives = 0
-    for edge in pred_edges:
-        if edge in pred_graph.target_edges:
-            true_positives += 1
-        else:
-            false_positives += 1
-    return true_positives, false_positives
-
-
-def __reconstruction_type_stats(target_graph, pred_graph, pred_edges):
-    simple_stats = dict([(s, 0) for s in STATS_LIST])
-    complex_stats = dict([(s, 0) for s in STATS_LIST])
-    for edge in pred_edges:
-        i, j = tuple(edge)
-        deg_i = pred_graph.immutable_degree(i)
-        deg_j = pred_graph.immutable_degree(j)
-        if edge in pred_graph.target_edges:
-            if deg_i == 1 and deg_j == 1:
-                simple_stats["# splits fixed"] += 1
-            else:
-                complex_stats["# splits fixed"] += 1
-        else:
-            if deg_i == 1 and deg_j == 1:
-                simple_stats["# merges created"] += 1
-            else:
-                complex_stats["# merges created"] += 1
-
-    num_simple, num_complex = compute_edge_type(pred_graph)
-    simple_stats = compute_accuracy(simple_stats, num_simple)
-    complex_stats = compute_accuracy(complex_stats, num_complex)
-
-    if False:
-        print("# simple edges:", num_simple)
-        print("% simple edges:", num_simple / (num_complex + num_simple))
-        print("# complex edges:", num_complex)
-        print("% complex edges:", num_complex / (num_complex + num_simple))
-        print("")
-    return simple_stats, complex_stats
-
-
-def compute_edge_type(graph):
-    num_simple = 0
-    num_complex = 0
-    for edge in graph.target_edges:
-        i, j = tuple(edge)
-        deg_i = graph.immutable_degree(i)
-        deg_j = graph.immutable_degree(j)
-        if deg_i == 1 and deg_j == 1:
-            num_simple += 1
-        else:
-            num_complex += 1
-    return num_simple, num_complex
-
-
-def compute_accuracy(stats, num_edges):
-    d = stats["# merges created"] + stats["# splits fixed"]
-    r = 1 if num_edges == 0 else stats["# splits fixed"] / num_edges
-    p = 1 if d == 0 else stats["# splits fixed"] / d
-    stats["f1"] = 0 if r + p == 0 else (2 * r * p) / (r + p)
-    stats["precision"] = p
-    stats["recall"] = r
+def get_stats(neurograph, proposals, pred_edges):
+    """
+    Computes the accuracy of the predictions generated by an edge
+    classification model on a given block and edge type (i.e. overall,
+    simple, or complex).
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Predicted neurograph.
+    proposals : set[frozenset]
+        Set of edge proposals for a given edge type.
+    pred_edges : set
+        Set of edges predicted to be target edges by the classification
+        model.
+
+    Returns
+    -------
+    dict
+        Dictionary containing the results of the evaluation, where the keys
+        are the metrics in "METRICS_LIST".
+
+    """
+    tp, fp, p, r, f1 = get_accuracy(neurograph, proposals, pred_edges)
+    stats = {
+        "# splits fixed": tp,
+        "# merges created": fp,
+        "precision": p,
+        "recall": r,
+        "f1": f1,
+    }
     return stats
+
+
+def get_accuracy(neurograph, proposals, pred_edges):
+    """
+    Computes the following metrics for a given set of predicted edges:
+    (1) true positives, (2) false positives, (3) precision, (4) recall, and
+    (5) F1-score.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Predicted neurograph.
+    proposals : set[frozenset]
+        Set of edge proposals for a given edge type.
+    pred_edges : set
+        Set of edges predicted to be target edges by the classification
+        model.
+
+    Returns
+    -------
+    tp : int
+        Number of true positives.
+    fp : int
+        Number of false positives.
+    p : float
+        Precision.
+    r : float
+        Recall.
+    f1 : float
+        F1-score.
+
+    """
+    tp, fp, fn = get_accuracy_counts(neurograph, proposals, pred_edges)
+    p = 1 if tp + fp == 0 else tp / (tp + fp)
+    r = 1 if tp + fn == 0 else tp / (tp + fn)
+    f1 = (2 * r * p) / max(r + p, 1e-3)
+    return tp, fp, p, r, f1
+
+
+def get_accuracy_counts(neurograph, proposals, pred_edges):
+    """
+    Computes the following values: (1) true positives, (2) false positives,
+    and (3) false negatives.
+
+    Parameters
+    ----------
+    neurograph : NeuroGraph
+        Predicted neurograph.
+    proposals : set[frozenset]
+        Set of edge proposals for a given edge type.
+    pred_edges : set
+        Set of edges predicted to be target edges by the classification
+        model.
+
+    Returns
+    -------
+    tp : int
+        Number of true positives.
+    fp : int
+        Number of false positives.
+    fn : int
+        Number of false negatives.
+
+    """
+    tp = 0
+    fp = 0
+    fn = 0
+    for edge in proposals:
+        if edge in neurograph.target_edges:
+            if edge in pred_edges:
+                tp += 1
+            else:
+                fn += 1
+        elif edge in pred_edges:
+            fp += 1
+    return tp, fp, fn
diff --git a/src/deep_neurographs/feature_extraction.py b/src/deep_neurographs/feature_extraction.py
index 17d2b69..4ceffd4 100644
--- a/src/deep_neurographs/feature_extraction.py
+++ b/src/deep_neurographs/feature_extraction.py
@@ -80,7 +80,7 @@ def generate_img_chunk_features(
         xyz_j = utils.world_to_img(neurograph, j)

         # Extract chunks
-        midpoint = geometry_utils.compute_midpoint(xyz_i, xyz_j).astype(int)
+        midpoint = geometry_utils.get_midpoint(xyz_i, xyz_j).astype(int)
         img_chunk = utils.get_chunk(img, midpoint, CHUNK_SIZE)
         labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE)

diff --git a/src/deep_neurographs/geometry_utils.py b/src/deep_neurographs/geometry_utils.py
index 3279fca..a2dd7d9 100644
--- a/src/deep_neurographs/geometry_utils.py
+++ b/src/deep_neurographs/geometry_utils.py
@@ -13,7 +13,7 @@ def get_directional(neurograph, i, proposal_tangent, window=5):
     # Compute principle axes
     directionals = []
     d = neurograph.optimize_depth
-    for branch in get_branches(neurograph, i):
+    for branch in neurograph.get_branches(i):
         if branch.shape[0] >= window + d:
             xyz = deepcopy(branch[d:, :])
         elif branch.shape[0] <= d:
@@ -35,22 +35,6 @@ def get_directional(neurograph, i, proposal_tangent, window=5):
     return directionals[arg_max]


-def get_branches(neurograph, i):
-    branches = []
-    ref_xyz = deepcopy(neurograph.nodes[i]["xyz"])
-    for j in neurograph.neighbors(i):
-        if frozenset((i, j)) in neurograph.immutable_edges:
-            branches.append(orient(neurograph.edges[i, j]["xyz"], ref_xyz))
-    return branches
-
-
-def orient(xyz, ref_xyz):
-    if (xyz[0, :] == ref_xyz).all():
-        return xyz
-    else:
-        return np.flip(xyz, axis=0)
-
-
 def compute_svd(xyz):
     xyz = xyz - np.mean(xyz, axis=0)
     return svd(xyz)
@@ -72,7 +56,7 @@ def compute_normal(xyz):
     return normal / np.linalg.norm(normal)


-def compute_midpoint(xyz1, xyz2):
+def get_midpoint(xyz1, xyz2):
     return np.mean([xyz1, xyz2], axis=0)


diff --git a/src/deep_neurographs/neurograph.py b/src/deep_neurographs/neurograph.py
index 2951cd9..5f7d405 100644
--- a/src/deep_neurographs/neurograph.py
+++ b/src/deep_neurographs/neurograph.py
@@ -39,7 +39,7 @@ def __init__(
         swc_path,
         img_path=None,
         label_mask=None,
-        optimize_depth=5,
+        optimize_depth=10,
         optimize_proposals=False,
         origin=None,
         shape=None,
@@ -81,6 +81,18 @@ def __init__(
         else:
             self.bbox = None

+    def init_immutable_graph(self):
+        immutable_graph = nx.Graph()
+        immutable_graph.add_nodes_from(self)
+        immutable_graph.add_edges_from(self.immutable_edges)
+        return immutable_graph
+
+    def init_predicted_graph(self):
+        self.predicted_graph = self.init_immutable_graph()
+
+    def init_densegraph(self):
+        self.densegraph = DenseGraph(self.path)
+
     # --- Add nodes or edges ---
     def generate_immutables(
         self, swc_id, swc_dict, prune=True, prune_depth=16
@@ -145,12 +157,6 @@ def generate_immutables(
         # Build kdtree
         self._init_kdtree()

-    def init_immutable_graph(self):
-        immutable_graph = nx.Graph()
-        immutable_graph.add_nodes_from(self)
-        immutable_graph.add_edges_from(self.immutable_edges)
-        return immutable_graph
-
     # --- Proposal Generation ---
     def generate_proposals(self, num_proposals=3, search_radius=25.0):
         """
@@ -313,7 +319,6 @@ def run_optimization(self):
         img = utils.get_superchunk(
             self.img_path, "zarr", origin, self.shape, from_center=False
         )
-        img = utils.normalize_img(img)
         simple_edges = self.get_simple_proposals()
         for edge in self.mutable_edges:
             if edge in simple_edges:
@@ -332,9 +337,7 @@ def optimize_simple_edge(self, img, edge):
         hat_xyz_i = self.to_img(branch_i[depth])
         hat_xyz_j = self.to_img(branch_j[depth])
         patch_dims = geometry_utils.get_optimal_patch(hat_xyz_i, hat_xyz_j)
-        center = geometry_utils.compute_midpoint(hat_xyz_i, hat_xyz_j).astype(
-            int
-        )
+        center = geometry_utils.get_midpoint(hat_xyz_i, hat_xyz_j).astype(int)
         img_chunk = utils.get_chunk(img, center, patch_dims)

         # Optimize
@@ -349,6 +352,18 @@ def optimize_simple_edge(self, img, edge):
             [branch_i[depth], path, branch_j[depth]]
         )

+    def optimize_complex_edge(self, img, edge):
+        # Extract Branches
+        i, j = tuple(edge)
+        leaf = i if self.immutable_degree(i) == 1 else j
+        i = j if leaf == i else i
+        branches = self.get_branches(i)
+        depth = self.optimize_depth
+
+        # Search for best anchor
+        #if len(branches) == 2:
+
+
     def get_branch(self, xyz):
         edge = self.xyz_to_edge[tuple(xyz)]
         branch = self.edges[edge]["xyz"]
@@ -357,15 +372,25 @@ def get_branch(self, xyz):
         else:
             return branch

-    def optimize_complex_edge(self, img_superchunk, edge):
-        pass
+    def get_branches(self, i):
+        branches = []
+        for j in self.neighbors(i):
+            if frozenset((i, j)) in self.immutable_edges:
+                branches.append(self.orient_edge((i, j), i))
+        return branches
+
+    def orient_edge(self, edge, i):
+        if (self.edges[edge]["xyz"][0, :] == self.nodes[i]["xyz"]).all():
+            return self.edges[edge]["xyz"]
+        else:
+            return np.flip(self.edges[edge]["xyz"], axis=0)

     # --- Ground Truth Generation ---
     def init_targets(self, target_neurograph):
         # Initializations
         self.target_edges = set()
-        self.groundtruth_graph = self.init_immutable_graph()
-        target_neurograph.densegraph = DenseGraph(target_neurograph.path)
+        self.init_predicted_graph()
+        target_neurograph.init_densegraph()

         # Add best simple edges
         remaining_proposals = []
@@ -431,7 +456,7 @@ def is_target(
     ):
         # Check for cycle
         i, j = tuple(edge)
-        if self.check_cycle((i, j)):
+        if self.creates_cycle((i, j)):
             return False

         # Check projection distance
@@ -644,13 +669,13 @@ def is_contained(self, node_or_xyz):
     def is_leaf(self, i):
         return True if self.immutable_degree(i) == 1 else False

-    def check_cycle(self, edge):
-        self.groundtruth_graph.add_edges_from([edge])
+    def creates_cycle(self, edge):
+        self.predicted_graph.add_edges_from([edge])
         try:
-            nx.find_cycle(self.groundtruth_graph)
+            nx.find_cycle(self.predicted_graph)
         except:
             return False
-        self.groundtruth_graph.remove_edges_from([edge])
+        self.predicted_graph.remove_edges_from([edge])
         return True

     def get_edge_attr(self, key, i, j):
@@ -659,9 +684,7 @@ def get_edge_attr(self, key, i, j):
         return attr_1, attr_2

     def get_center(self):
-        return geometry_utils.compute_midpoint(
self.bbox["min"], self.bbox["max"] - ) + return geometry_utils.get_midpoint(self.bbox["min"], self.bbox["max"]) def get_complex_proposals(self): return set([e for e in self.mutable_edges if not self.is_simple(e)])