feature : edge optimization
anna-grim committed Dec 1, 2023
1 parent 46f5745 commit acd0825
Showing 9 changed files with 289 additions and 191 deletions.
12 changes: 5 additions & 7 deletions src/deep_neurographs/deep_learning/datasets.py
@@ -143,8 +143,7 @@ def __getitem__(self, idx):
         """
         if self.transform:
-            inputs = utils.normalize_img(self.inputs[idx])
-            inputs = self.transform.run(inputs)
+            inputs = self.transform.run(self.inputs[idx])
         else:
             inputs = self.inputs[idx]
         return {"inputs": inputs, "labels": self.labels[idx]}
@@ -217,8 +216,7 @@ def __getitem__(self, idx):
         """
         if self.transform:
-            img_inputs = utils.normalize_img(self.img_inputs[idx])
-            img_inputs = self.transform.run(img_inputs)
+            img_inputs = self.transform.run(self.img_inputs[idx])
         else:
             img_inputs = self.img_inputs[idx]
         inputs = [self.feature_inputs[idx], img_inputs]
@@ -250,9 +248,9 @@ def __init__(self):
                 tio.RandomBlur(std=(0, 0.4)),
                 tio.RandomNoise(std=(0, 0.0125)),
                 tio.RandomFlip(axes=(0, 1, 2)),
-                # tio.RandomAffine(
-                #     degrees=20, scales=(0.8, 1), image_interpolation="nearest"
-                # )
+                tio.RandomAffine(
+                    degrees=20, scales=(0.8, 1), image_interpolation="nearest"
+                )
             ]
         )

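Note (not part of the diff): a minimal sketch of the augmentation pipeline with the now-enabled RandomAffine; the 2-channel 64x64x64 chunk shape is an assumption based on how feature_extraction.py stacks image and label chunks.

    import numpy as np
    import torchio as tio

    # torchio transforms accept channel-first 4D arrays: (channels, x, y, z).
    transform = tio.Compose(
        [
            tio.RandomBlur(std=(0, 0.4)),
            tio.RandomNoise(std=(0, 0.0125)),
            tio.RandomFlip(axes=(0, 1, 2)),
            tio.RandomAffine(degrees=20, scales=(0.8, 1), image_interpolation="nearest"),
        ]
    )
    chunk = np.random.rand(2, 64, 64, 64).astype(np.float32)  # assumed shape
    augmented = transform(chunk)  # same shape, randomly blurred/flipped/warped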
2 changes: 1 addition & 1 deletion src/deep_neurographs/deep_learning/models.py
@@ -135,7 +135,7 @@ def _init_conv_layer(self, in_channels, out_channels):
             ),
             nn.BatchNorm3d(out_channels),
             nn.LeakyReLU(),
-            nn.Dropout(p=0.25),
+            nn.Dropout(p=0.2),
             nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2),
         )
         return conv_layer
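Note (not part of the diff): a self-contained sketch of the convolution block around the changed Dropout line; the Conv3d channel counts, kernel size, and padding are assumptions, since that call sits above this hunk.

    import torch
    import torch.nn as nn

    block = nn.Sequential(
        nn.Conv3d(2, 16, kernel_size=3, padding=1),  # channels/kernel assumed
        nn.BatchNorm3d(16),
        nn.LeakyReLU(),
        nn.Dropout(p=0.2),  # reduced from 0.25 in this commit
        nn.MaxPool3d(kernel_size=(2, 2, 2), stride=2),
    )
    x = torch.randn(1, 2, 64, 64, 64)
    y = block(x)  # -> torch.Size([1, 16, 32, 32, 32])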
25 changes: 18 additions & 7 deletions src/deep_neurographs/deep_learning/train.py
@@ -89,6 +89,7 @@ def train_network(
     net,
     dataset,
     logger=True,
+    lr=10e-3,
     max_epochs=100,
     model_summary=True,
     profile=False,
@@ -111,24 +112,29 @@
     )
 
     # Configure trainer
-    model = LitNeuralNet(net)
-    checkpoint_callback = ModelCheckpoint(
+    model = LitNeuralNet(net=net, lr=lr)
+    ckpt_callback = ModelCheckpoint(
         save_top_k=1, monitor="val_f1", mode="max"
     )
     profiler = PyTorchProfiler() if profile else None
 
     # Fit model
     trainer = pl.Trainer(
         accelerator="gpu",
-        callbacks=[checkpoint_callback],
+        callbacks=[ckpt_callback],
         devices=1,
         enable_model_summary=model_summary,
         enable_progress_bar=progress_bar,
         logger=logger,
+        log_every_n_steps=1,
         max_epochs=max_epochs,
         profiler=profiler,
     )
     trainer.fit(model, train_loader, valid_loader)
+
+    # Return best model
+    ckpt = torch.load(ckpt_callback.best_model_path)
+    model.net.load_state_dict(ckpt["state_dict"])
+    return model


@@ -141,22 +147,24 @@ def random_split(train_set, train_ratio=0.85):
 def eval_network(X, model, threshold=0.5):
     model.eval()
     X = torch.tensor(X, dtype=torch.float32)
-    y_pred = sigmoid(model.net(X))
+    with torch.no_grad():
+        y_pred = sigmoid(model.net(X))
     return np.array(y_pred > threshold, dtype=int)
 
 
 # Lightning Module
 class LitNeuralNet(pl.LightningModule):
-    def __init__(self, net):
+    def __init__(self, net=None, lr=10e-3):
         super().__init__()
         self.net = net
+        self.lr = lr
 
     def forward(self, batch):
         x = self.get_example(batch, "inputs")
         return self.net(x)
 
     def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
         return optimizer
 
     def training_step(self, batch, batch_idx):
@@ -183,3 +191,6 @@ def compute_stats(self, y_hat, y, prefix=""):
 
     def get_example(self, batch, key):
         return batch[key]
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.net.state_dict(destination, prefix + '', keep_vars)
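Note (not part of the diff): a hypothetical end-to-end call of the updated training API; net, dataset, and X_test are placeholders for objects built from models.py and datasets.py.

    # train_network now takes a learning rate, checkpoints on val_f1, and
    # returns the model re-loaded from the best checkpoint.
    model = train_network(net, dataset, lr=10e-3, max_epochs=100)

    # eval_network now runs the forward pass under torch.no_grad().
    y_pred = eval_network(X_test, model, threshold=0.5)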
6 changes: 2 additions & 4 deletions src/deep_neurographs/densegraph.py
@@ -84,12 +84,10 @@ def check_aligned(self, pred_xyz_i, pred_xyz_j):
         pred_xyz_j = np.array(pred_xyz_j)
         pred_dist = dist(pred_xyz_i, pred_xyz_j)
 
-        target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j)
-        target_dist = max(target_dist, 1)
-
         # Check criteria
+        target_path, target_dist = self.connect_nodes(graph_id, xyz_i, xyz_j)
         ratio = min(pred_dist, target_dist) / max(pred_dist, target_dist)
-        if ratio < 0.5 and pred_dist > 15:
+        if ratio < 0.5 and pred_dist > 10:
             return False
         else:
             return True
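Note (not part of the diff): a worked instance of the tightened rejection rule, with illustrative distances.

    # pred_dist = 24, target_dist = 60  ->  ratio = 24 / 60 = 0.4
    # ratio < 0.5 and pred_dist > 10, so check_aligned returns False.
    # The threshold drop from 15 to 10 means a misaligned pair with, say,
    # pred_dist = 12 is now rejected where it previously passed.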
33 changes: 14 additions & 19 deletions src/deep_neurographs/feature_extraction.py
@@ -72,6 +72,8 @@ def generate_mutable_img_chunk_features(
     img, labels = utils.get_superchunks(
         img_path, labels_path, origin, neurograph.shape, from_center=False
     )
+
+    #img = utils.normalize_img(img)
     for edge in neurograph.mutable_edges:
         # Compute image coordinates
         i, j = tuple(edge)
@@ -86,18 +88,22 @@
         # Mark path
         d = int(geometry_utils.dist(xyz_i, xyz_j) + 5)
         img_coords_i = np.round(xyz_i - midpoint + HALF_CHUNK_SIZE).astype(int)
-        img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int)
+        img_coords_j = np.round(xyz_j - midpoint + HALF_CHUNK_SIZE).astype(int)
         path = geometry_utils.make_line(img_coords_i, img_coords_j, d)
 
+        img_chunk = utils.normalize_img(img_chunk)
         labels_chunk[labels_chunk > 0] = 1
-        labels_chunk = geometry_utils.fill_path(labels_chunk, path, val=-1)
+        labels_chunk = geometry_utils.fill_path(labels_chunk, path)
         features[edge] = np.stack([img_chunk, labels_chunk], axis=0)
 
     return features
 
 
-def get_local_img_coords(neurograph, i):
-    global_xyz = deepcopy(neurograph.nodes[i]["xyz"])
+def get_local_img_coords(neurograph, node_or_xyz):
+    if type(node_or_xyz) == int:
+        global_xyz = deepcopy(neurograph.nodes[node_or_xyz]["xyz"])
+    else:
+        global_xyz = node_or_xyz
     local_xyz = utils.apply_anisotropy(
         global_xyz - np.array(neurograph.origin)
     )
@@ -109,16 +115,17 @@ def generate_mutable_img_profile_features(
 ):
     features = dict()
     origin = utils.apply_anisotropy(neurograph.origin, return_int=True)
-    superchunk = utils.get_superchunk(
+    img = utils.get_superchunk(
         path, "zarr", origin, neurograph.shape, from_center=False
     )
+    img = utils.normalize_img(img)
     for edge in neurograph.mutable_edges:
         i, j = tuple(edge)
         xyz_i = get_local_img_coords(neurograph, i)
         xyz_j = get_local_img_coords(neurograph, j)
         line = geometry_utils.make_line(xyz_i, xyz_j, NUM_POINTS)
         features[edge] = geometry_utils.get_profile(
-            superchunk, line, window_size=WINDOW_SIZE
+            img, line, window_size=WINDOW_SIZE
         )
     return features
@@ -132,7 +139,7 @@ def generate_mutable_skel_features(neurograph):
         ddot1, ddot2, ddot3 = get_directionals(neurograph, edge, 10)
         features[edge] = np.concatenate(
             (
-                compute_length(neurograph, edge),
+                neurograph.compute_length(edge),
                 neurograph.immutable_degree(i),
                 neurograph.immutable_degree(j),
                 radius_i,
@@ -295,15 +302,3 @@ def combine_features(features):
                 (combined[edge], features[key][edge])
             )
     return combined
-
-
-"""
-def get_chunk(superchunk, xyz):
-    return deepcopy(
-        superchunk[
-            (xyz[0] - CHUNK_SIZE[0] // 2) : xyz[0] + CHUNK_SIZE[0] // 2,
-            (xyz[1] - CHUNK_SIZE[1] // 2) : xyz[1] + CHUNK_SIZE[1] // 2,
-            (xyz[2] - CHUNK_SIZE[2] // 2) : xyz[2] + CHUNK_SIZE[2] // 2,
-        ]
-    )
-"""
67 changes: 62 additions & 5 deletions src/deep_neurographs/geometry_utils.py
@@ -1,3 +1,5 @@
+import heapq
+import networkx as nx
 import numpy as np
 from scipy.interpolate import UnivariateSpline
 from scipy.linalg import svd
@@ -110,9 +112,9 @@ def smooth_branch(xyz):
 def fit_spline(xyz):
     s = xyz.shape[0] / 5
     t = np.arange(xyz.shape[0])
-    cs_x = UnivariateSpline(t, xyz[:, 0], s=s, k=3)
-    cs_y = UnivariateSpline(t, xyz[:, 1], s=s, k=3)
-    cs_z = UnivariateSpline(t, xyz[:, 2], s=s, k=3)
+    cs_x = UnivariateSpline(t, xyz[:, 0], s=s, k=1)
+    cs_y = UnivariateSpline(t, xyz[:, 1], s=s, k=1)
+    cs_z = UnivariateSpline(t, xyz[:, 2], s=s, k=1)
     return cs_x, cs_y, cs_z


@@ -137,12 +139,67 @@ def get_profile(img, xyz_arr, window_size=[5, 5, 5]):
 def fill_path(img, path, val=-1):
     for xyz in path:
         x, y, z = tuple(np.floor(xyz).astype(int))
-        img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val
-        # img[x,y,z] = val
+        #img[x - 1 : x + 2, y - 1 : y + 2, z - 1 : z + 2] = val
+        img[x,y,z] = val
     return img
 
 
 # Miscellaneous
+def shortest_path(img, start, end):
+    def is_valid_move(x, y, z):
+        return 0 <= x < shape[0] and 0 <= y < shape[1] and 0 <= z < shape[2] and not visited[x, y, z]
+
+    def get_nbs(x, y, z):
+        moves = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
+        return [(x + dx, y + dy, z + dz) for dx, dy, dz in moves if is_valid_move(x + dx, y + dy, z + dz)]
+
+    img = img - np.min(img) + 1
+    start = tuple(start)
+    end = tuple(end)
+
+    shape = img.shape
+    visited = np.zeros(shape, dtype=bool)
+    distances = np.inf * np.ones(shape)
+    distances[start] = 0
+    previous_nodes = {}
+
+    heap = [(0, start)]
+    while heap:
+        current_distance, cur_node = heapq.heappop(heap)
+
+        if cur_node == end:
+            path = []
+            while cur_node != start:
+                path.append(cur_node)
+                cur_node = previous_nodes[cur_node]
+            path.append(start)
+            return path[::-1]
+
+        visited[cur_node] = True
+
+        for nb in get_nbs(*cur_node):
+            if not visited[nb]:
+                new_distance = distances[cur_node] + 1 / img[nb]
+                if new_distance < distances[nb]:
+                    distances[nb] = new_distance
+                    previous_nodes[nb] = cur_node
+                    heapq.heappush(heap, (new_distance, nb))
+    return None
+
+
+def transform_path(path, img_origin, patch_centroid, patch_dims):
+    img_origin = np.array(img_origin)
+    transformed_path = np.zeros((len(path), 3))
+    for i, xyz in enumerate(path):
+        hat_xyz = utils.patch_to_img(xyz, patch_centroid, patch_dims)
+        transformed_path[i, :] = utils.to_world(hat_xyz, shift=-img_origin)
+    return smooth_branch(transformed_path)
+
+
+def get_optimal_patch(xyz_1, xyz_2, buffer=8):
+    return [int(abs(xyz_1[i] - xyz_2[i])) + buffer for i in range(3)]
+
+
 def compare_edges(xyx_i, xyz_j, xyz_k):
     dist_ij = dist(xyx_i, xyz_j)
     dist_ik = dist(xyx_i, xyz_k)
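Note (not part of the diff): a minimal sketch of the new edge-optimization helpers on a toy patch; shapes and coordinates are made up.

    import numpy as np

    patch = np.random.rand(16, 16, 16)  # toy intensity patch
    voxel_path = shortest_path(patch, start=(0, 0, 0), end=(15, 15, 15))
    # voxel_path is a start-to-end list of voxel coordinates, or None if
    # unreachable; step cost is 1 / intensity, so bright voxels are cheap.

    dims = get_optimal_patch(np.array([10, 12, 8]), np.array([22, 30, 19]))
    # -> [20, 26, 19]: per-axis span between endpoints plus the buffer of 8.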
