Skip to content

Commit

Permalink
add feature : optimize edge alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Dec 12, 2023
1 parent 363333e commit 21cad5b
Show file tree
Hide file tree
Showing 8 changed files with 341 additions and 122 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run_evaluation(neurographs, blocks, pred_edges):
overall_stats_i = get_stats(
neurographs[block_id],
neurographs[block_id].mutable_edges,
pred_edges[block_id]
pred_edges[block_id],
)

simple_stats_i = get_stats(
Expand Down
11 changes: 4 additions & 7 deletions src/deep_neurographs/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

CHUNK_SIZE = [64, 64, 64]
HALF_CHUNK_SIZE = [CHUNK_SIZE[i] // 2 for i in range(3)]
WINDOW_SIZE = [5, 5, 5]
WINDOW = [5, 5, 5]

NUM_POINTS = 10
NUM_IMG_FEATURES = NUM_POINTS
Expand Down Expand Up @@ -85,7 +85,7 @@ def generate_img_chunk_features(
labels_chunk = utils.get_chunk(labels, midpoint, CHUNK_SIZE)

# Mark path
if neurograph.optimize_proposals:
if neurograph.optimize_alignment:
xyz_list = to_patch_coords(neurograph, edge, midpoint)
path = geometry_utils.sample_path(xyz_list, NUM_POINTS)
else:
Expand Down Expand Up @@ -120,19 +120,16 @@ def generate_img_profile_features(
path, "zarr", origin, neurograph.shape, from_center=False
)
img = utils.normalize_img(img)
simple_edges = neurograph.get_simple_proposals()
for edge in neurograph.mutable_edges:
if neurograph.optimize_proposals and edge in simple_edges:
if neurograph.optimize_alignment:
xyz = to_img_coords(neurograph, edge)
path = geometry_utils.sample_path(xyz, NUM_POINTS)
else:
i, j = tuple(edge)
xyz_i = utils.world_to_img(neurograph, i)
xyz_j = utils.world_to_img(neurograph, j)
path = geometry_utils.make_line(xyz_i, xyz_j, NUM_POINTS)
features[edge] = geometry_utils.get_profile(
img, path, window_size=WINDOW_SIZE
)
features[edge] = geometry_utils.get_profile(img, path, window=WINDOW)
return features


Expand Down
249 changes: 231 additions & 18 deletions src/deep_neurographs/geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

# Directional Vectors
def get_directional(neurograph, i, proposal_tangent, window=5):
# Compute principle axes
directionals = []
d = neurograph.optimize_depth
for branch in neurograph.get_branches(i):
Expand All @@ -20,7 +19,6 @@ def get_directional(neurograph, i, proposal_tangent, window=5):
xyz = deepcopy(branch)
else:
xyz = deepcopy(branch[d : window + d, :])
# print(xyz)
directionals.append(compute_tangent(xyz))

# Determine best
Expand Down Expand Up @@ -72,9 +70,9 @@ def smooth_branch(xyz, s=None):
def fit_spline(xyz, s=None):
s = xyz.shape[0] / 5 if not s else xyz.shape[0] / s
t = np.linspace(0, 1, xyz.shape[0])
spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=1)
spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=1)
spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=1)
spline_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
spline_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
spline_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3)
return spline_x, spline_y, spline_z


Expand All @@ -86,8 +84,8 @@ def sample_path(path, num_points):


# Image feature extraction
def get_profile(img, xyz_arr, window_size=[5, 5, 5]):
return [np.max(utils.get_chunk(img, xyz, window_size)) for xyz in xyz_arr]
def get_profile(img, xyz_arr, window=[5, 5, 5]):
return [np.max(utils.get_chunk(img, xyz, window)) for xyz in xyz_arr]


def fill_path(img, path, val=-1):
Expand All @@ -98,9 +96,224 @@ def fill_path(img, path, val=-1):
return img


# Miscellaneous
# Proposal optimization
def optimize_alignment(neurograph, img, edge, depth=15):
"""
Optimizes alignment of edge proposal between two branches by finding
straight path with the brightest averaged image profile.
Parameters
----------
neurograph : NeuroGraph
Predicted neuron reconstruction to be corrected.
img : numpy.ndarray
Image chunk that the reconstruction is contained in.
edge : frozenset
Edge proposal to be aligned.
depth : int, optional
Maximum depth checked during alignment optimization. The default value
is 15.
Returns
-------
numpy.ndarray, numpy.ndarray
xyz coordinates of aligned edge proposal.
"""
if neurograph.is_simple(edge):
return optimize_simple_alignment(neurograph, img, edge, depth=depth)
else:
return optimize_complex_alignment(neurograph, img, edge, depth=depth)


def optimize_simple_alignment(neurograph, img, edge, depth=15):
"""
Optimizes alignment of edge proposal for simple edges.
Parameters
----------
neurograph : NeuroGraph
Predicted neuron reconstruction to be corrected.
img : numpy.ndarray
Image chunk that the reconstruction is contained in.
edge : frozenset
Edge proposal to be aligned.
depth : int, optional
Maximum depth checked during alignment optimization. The default value
is 15.
Returns
-------
numpy.ndarray, numpy.ndarray
xyz coordinates of aligned edge proposal.
"""
i, j = tuple(edge)
branch_i = neurograph.get_branch(i)
branch_j = neurograph.get_branch(j)
xyz_i, xyz_j, _ = align(neurograph, img, branch_i, branch_j, depth)
return xyz_i, xyz_j


def optimize_complex_alignment(neurograph, img, edge, depth=15):
"""
Optimizes alignment of edge proposal for complex edges.
Parameters
----------
neurograph : NeuroGraph
Predicted neuron reconstruction to be corrected.
img : numpy.ndarray
Image chunk that the reconstruction is contained in.
edge : frozenset
Edge proposal to be aligned.
depth : int, optional
Maximum depth checked during alignment optimization. The default value
is 15.
Returns
-------
numpy.ndarray, numpy.ndarray
xyz coordinates of aligned edge proposal.
"""
i, j = tuple(edge)
branch = neurograph.get_branch(i if neurograph.is_leaf(i) else j)
branches = neurograph.get_branches(j if neurograph.is_leaf(i) else i)
xyz_1, leaf_1, val_1 = align(neurograph, img, branch, branches[0], depth)
xyz_2, leaf_2, val_2 = align(neurograph, img, branch, branches[1], depth)
return (xyz_1, leaf_1) if val_1 > val_2 else (xyz_2, leaf_2)


def align(neurograph, img, branch_1, branch_2, depth):
"""
Finds straight line path between end points of "branch_1" and "branch_2"
that best captures the image signal. This path is determined by checking
the average image intensity of the line drawn from "branch_1[d_1]" and
"branch_2[d_2]" with d_1, d_2 in [0, depth].
Parameters
----------
neurograph : NeuroGraph
Predicted neuron reconstruction to be corrected.
img : numpy.ndarray
Image chunk that the reconstruction is contained in.
branch_1 : np.ndarray
Branch corresponding to some predicted neuron. This branch must be
oriented so that the end points being considered are the coordinates
in rows 0 through "depth".
branch_2 : np.ndarray
Branch corresponding to some predicted neuron. This branch must be
oriented so that the end points being considered are the coordinates
in rows 0 through "depth".
depth : int
Maximum depth of branch that is optimized over.
Returns
-------
best_xyz_1 : np.ndarray
Optimal xyz coordinate from "branch_1".
best_xyz_2 : np.ndarray
Optimal xyz coordinate from "branch_2".
best_score : float
Average brightness of voxels sampled along line between "best_xyz_1"
and "best_xyz_2".
"""
best_xyz_1 = None
best_xyz_2 = None
best_score = 0
for d_1 in range(min(depth, len(branch_1) - 1)):
xyz_1 = neurograph.to_img(branch_1[d_1])
for d_2 in range(min(depth, len(branch_2) - 1)):
xyz_2 = neurograph.to_img(branch_2[d_2])
line = make_line(xyz_1, xyz_2, 10)
score = np.mean(get_profile(img, line, window=[3, 3, 3]))
if score > best_score:
best_score = score
best_xyz_1 = deepcopy(xyz_1)
best_xyz_2 = deepcopy(xyz_2)
return best_xyz_1, best_xyz_2, best_score


def optimize_path(img, origin, xyz_1, xyz_2):
"""
Finds optimal path between "xyz_1" and "xyz_2" that best captures the
image signal. The path is determined by finding the shortest path these
points with respect the cost function f(xyz) = 1 / img[xyz].
Parameters
----------
img : np.ndarray
Image chunk that contains "start" and "end". The centroid of this img
is "origin".
origin : np.ndarray
The xyz-coordinate (in world coordinates) of "img".
xyz_1 : np.ndarray
The xyz-coordinate (in image coordinates) of the start point of the
path.
xyz_2 : np.ndarray
The xyz-coordinate (in image coordinates) of the end point of the
path.
Returns
-------
list[tuple[float]]
Optimal path between "xyz_1" and "xyz_2".
"""
patch_dims = get_optimal_patch(xyz_1, xyz_2, buffer=5)
center = get_midpoint(xyz_1, xyz_2).astype(int)
img_chunk = utils.get_chunk(img, center, patch_dims)
path = shortest_path(
img_chunk,
utils.img_to_patch(xyz_1, center, patch_dims),
utils.img_to_patch(xyz_2, center, patch_dims),
)
return transform_path(path, origin, center, patch_dims)


def shortest_path(img, start, end):
"""
Finds shortest path between "start" and "end" with respect to the image
intensity values.
Parameters
----------
img : np.ndarray
Image chunk that "start" and "end" are contained within and domain of
the shortest path.
start : np.ndarray
Start point of path.
end : np.ndarray
End point of path.
Returns
-------
list[tuple]
Shortest path between "start" and "end".
"""

def is_valid_move(x, y, z):
"""
Determines whether (x, y, z) coordinate is contained in image.
Parameters
----------
x : int
X-coordinate.
y : int
Y-coordinate.
z : int
Z-coordinate.
Returns
-------
bool
Indication of whether coordinate is contained in image.
"""
return (
0 <= x < shape[0]
and 0 <= y < shape[1]
Expand Down Expand Up @@ -170,28 +383,28 @@ def get_optimal_patch(xyz_1, xyz_2, buffer=8):
return [int(abs(xyz_1[i] - xyz_2[i])) + buffer for i in range(3)]


def compare_edges(xyx_i, xyz_j, xyz_k):
dist_ij = dist(xyx_i, xyz_j)
dist_ik = dist(xyx_i, xyz_k)
return dist_ij < dist_ik


def dist(x, y, metric="l2"):
# Miscellaneous
def dist(v_1, v_2, metric="l2"):
"""
Computes distance between "x" and "y".
Computes distance between "v_1" and "v_2".
Parameters
----------
v_1 : np.ndarray
Vector.
v_2 : np.ndarray
Vector.
Returns
-------
float
Distance between "v_1" and "v_2".
"""
if metric == "l1":
return np.linalg.norm(np.subtract(x, y), ord=1)
return np.linalg.norm(np.subtract(v_1, v_2), ord=1)
else:
return np.linalg.norm(np.subtract(x, y), ord=2)
return np.linalg.norm(np.subtract(v_1, v_2), ord=2)


def make_line(xyz_1, xyz_2, num_steps):
Expand Down
16 changes: 12 additions & 4 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def build_neurograph(
search_radius=25.0,
prune=True,
prune_depth=16,
optimize_proposals=False,
optimize_depth=15,
optimize_alignment=True,
optimize_path=False,
origin=None,
shape=None,
smooth=True,
Expand All @@ -38,7 +40,9 @@ def build_neurograph(
neurograph = NeuroGraph(
swc_dir,
img_path=img_path,
optimize_proposals=optimize_proposals,
optimize_depth=optimize_depth,
optimize_alignment=optimize_alignment,
optimize_path=optimize_path,
origin=origin,
shape=shape,
)
Expand Down Expand Up @@ -71,8 +75,12 @@ def init_immutables(

for path in get_paths(neurograph.path):
swc_id = get_id(path)
raw_swc = swc_utils.read_swc(path)
swc_dict = swc_utils.parse(raw_swc, anisotropy=anisotropy)
swc_dict = swc_utils.parse(
swc_utils.read_swc(path),
anisotropy=anisotropy,
bbox=neurograph.bbox,
img_shape=neurograph.shape,
)
if len(swc_dict["xyz"]) < size_threshold:
continue
if smooth:
Expand Down
Loading

0 comments on commit 21cad5b

Please sign in to comment.