diff --git a/src/deep_neurographs/machine_learning/feature_generation.py b/src/deep_neurographs/machine_learning/feature_generation.py index 5187eab..dfcd465 100644 --- a/src/deep_neurographs/machine_learning/feature_generation.py +++ b/src/deep_neurographs/machine_learning/feature_generation.py @@ -4,11 +4,11 @@ @author: Anna Grim @email: anna.grim@alleninstitute.org -Generates features for training a model and performing inference. +Generates features for training a machine learning model and performing +inference. Conventions: (1) "xyz" refers to a real world coordinate such as those from an swc file - (2) "voxel" refers to an voxel coordinate in a whole exaspim image. """ @@ -26,7 +26,7 @@ class FeatureGenerator: """ Class that generates features vectors that are used by a graph neural - network to classify proposals. + network (GNN) to classify proposals. """ # Class attributes @@ -54,7 +54,8 @@ def __init__( Path to the segmentation assumed to be stored on a GCS bucket. The default is None. is_multimodal : bool, optional - ... + Indication of whether to generate multimodal features (i.e. image + and label patch for each proposal). The default is False. Returns ------- @@ -118,7 +119,7 @@ def run(self, neurograph, proposals_dict, radius): proposals_dict : dict Dictionary that contains the items (1) "proposals" which are the proposals from "neurograph" that features will be generated and - (2) "graph" which is the computation graph used by the gnn. + (2) "graph" which is the computation graph used by the GNN. radius : float Search radius used to generate proposals. @@ -156,7 +157,7 @@ def run_on_nodes(self, neurograph, computation_graph): neurograph : NeuroGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph - Graph used by gnn to classify proposals. + Graph used by GNN to classify proposals. 
Returns ------- @@ -175,12 +176,12 @@ def run_on_branches(self, neurograph, computation_graph): neurograph : NeuroGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph - Graph used by gnn to classify proposals. + Graph used by GNN to classify proposals. Returns ------- dict - Dictionary that maps an edge id to a feature vector. + Dictionary that maps a branch id to a feature vector. """ return self.branch_skeletal(neurograph, computation_graph) @@ -221,7 +222,7 @@ def node_skeletal(self, neurograph, computation_graph): neurograph : NeuroGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph - Graph used by gnn to classify proposals. + Graph used by GNN to classify proposals. Returns ------- @@ -250,7 +251,7 @@ def branch_skeletal(self, neurograph, computation_graph): neurograph : NeuroGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph - Graph used by gnn to classify proposals. + Graph used by GNN to classify proposals. Returns ------- @@ -313,7 +314,7 @@ def node_profiles(self, neurograph, computation_graph): neurograph : NeuroGraph NeuroGraph generated from a predicted segmentation. computation_graph : networkx.Graph - Graph used by gnn to classify proposals. + Graph used by GNN to classify proposals. Returns ------- @@ -435,7 +436,7 @@ def get_profile(self, xyz_path, profile_id): def get_spec(self, xyz_path): """ Gets image bounding box and voxel coordinates needed to compute an - image intensity profile or extract image chunk for cnn embedding. + image intensity profile or extract image patch.
Parameters ---------- diff --git a/src/deep_neurographs/machine_learning/heterograph_datasets.py b/src/deep_neurographs/machine_learning/heterograph_datasets.py index 4c1a922..361a436 100644 --- a/src/deep_neurographs/machine_learning/heterograph_datasets.py +++ b/src/deep_neurographs/machine_learning/heterograph_datasets.py @@ -64,7 +64,7 @@ def init(neurograph, features, computation_graph): # Build patch matrix is_multimodel = "patches" in features if is_multimodel: - x_dict["patches"] = get_patches_matrix( + x_dict["patch"] = get_patches_matrix( features["patches"], idxs["proposals"]["id_to_idx"] ) @@ -142,17 +142,18 @@ def __init__( ] # Features - self.data = HeteroGraphData() - self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE) - self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE) - self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE) - - # Edges + self.init_nodes(x_dict, y_proposals) self.init_edges() self.check_missing_edge_types() self.init_edge_attrs(x_dict["nodes"]) self.n_edge_attrs = n_edge_features(x_dict["nodes"]) + def init_nodes(self, x_dict, y_proposals): + self.data = HeteroGraphData() + self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE) + self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE) + self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE) + def init_edges(self): """ Initializes edge index for a graph dataset. 
@@ -430,8 +431,49 @@ def __init__( idxs, ) - # Instance attributes - self.data["patches"].x = torch.tensor(x_dict["patches"], dtype=DTYPE) + def init_nodes(self, x_dict, y_proposals): + self.data = HeteroGraphData() + self.data["branch"].x = torch.tensor(x_dict["branches"], dtype=DTYPE) + self.data["proposal"].x = torch.tensor(x_dict["proposals"], dtype=DTYPE) + self.data["proposal"].y = torch.tensor(y_proposals, dtype=DTYPE) + self.data["patch"].x = torch.tensor(x_dict["patch"], dtype=DTYPE) + + def check_missing_edge_types(self): + for node_type in ["branch", "proposal"]: + edge_type = (node_type, "edge", node_type) + if len(self.data[edge_type].edge_index) == 0: + # Add dummy features - nodes + dtype = self.data[node_type].x.dtype + if node_type == "branch": + d = self.n_branch_features() + else: + d = self.n_proposal_features() + + zeros = torch.zeros(2, d, dtype=dtype) + self.data[node_type].x = torch.cat( + (self.data[node_type].x, zeros), dim=0 + ) + + # Add dummy features - patches + if node_type == "proposal": + patch_shape = self.data["patch"].x.size()[1:] + zeros = torch.zeros((2,) + patch_shape, dtype=dtype) + self.data["patch"].x = torch.cat( + (self.data["patch"].x, zeros), dim=0 + ) + + # Update edge_index + n = self.data[node_type]["x"].size(0) + e_1 = frozenset({-1, -2}) + e_2 = frozenset({-2, -3}) + edges = [[n - 1, n - 2], [n - 2, n - 1]] + self.data[edge_type].edge_index = gnn_util.toTensor(edges) + if node_type == "branch": + self.idxs_branches["idx_to_id"][n - 1] = e_1 + self.idxs_branches["idx_to_id"][n - 2] = e_2 + else: + self.idxs_proposals["idx_to_id"][n - 1] = e_1 + self.idxs_proposals["idx_to_id"][n - 2] = e_2 # -- util -- diff --git a/src/deep_neurographs/machine_learning/heterograph_models.py b/src/deep_neurographs/machine_learning/heterograph_models.py index 3272c94..5d861a3 100644 --- a/src/deep_neurographs/machine_learning/heterograph_models.py +++ b/src/deep_neurographs/machine_learning/heterograph_models.py @@ -45,28 +45,29 @@ 
def __init__( """ super().__init__() - # Layer dimensions - hidden_dim_1 = hidden_dim - hidden_dim_2 = hidden_dim_1 * heads_2 - output_dim = hidden_dim * heads_1 * heads_2 - # Nonlinear activation self.dropout = dropout self.dropout_layer = Dropout(dropout) self.leaky_relu = LeakyReLU() - # Linear layers + # Initial Embedding self.input_nodes = nn.ModuleDict() - self.input_edges = dict() for key, d in node_dict.items(): - self.input_nodes[key] = nn.Linear(d, hidden_dim_1, device=device) + self.input_nodes[key] = nn.Linear(d, hidden_dim, device=device) + + self.input_edges = dict() for key, d in edge_dict.items(): - self.input_edges[key] = nn.Linear(d, hidden_dim_1, device=device) - self.output = Linear(output_dim, 1).to(device) + self.input_edges[key] = nn.Linear(d, hidden_dim, device=device) + + # Layer dimensions + hidden_dim_1 = hidden_dim + hidden_dim_2 = hidden_dim_1 * heads_2 + output_dim = hidden_dim_1 * heads_1 * heads_2 # Message passing layers self.gat1 = self.init_gat_layer(hidden_dim_1, hidden_dim_1, heads_1) self.gat2 = self.init_gat_layer(hidden_dim_2, hidden_dim_1, heads_2) + self.output = Linear(output_dim, 1).to(device) # Initialize weights self.init_weights() @@ -190,22 +191,41 @@ def __init__( node_dict, edge_dict, device, - hidden_dim, + hidden_dim * 2, dropout, heads_1, heads_2, ) - # Instance attributes - self.input_patches = ConvNet(hidden_dim) + # Patch Embedding + self.input_patches = ConvNet((48, 48, 48), hidden_dim) + + # Node Embedding + proposal_dim = node_dict["proposal"] + branch_dim = node_dict["branch"] + self.input_nodes = nn.ModuleDict({ + "proposal": nn.Linear(proposal_dim, hidden_dim, device=device), + "branch": nn.Linear(branch_dim, hidden_dim * 2, device=device), + }) + + # Edge Embedding + self.input_edges = dict() + for key, d in edge_dict.items(): + self.input_edges[key] = nn.Linear( + d, hidden_dim * 2, device=device + ) + + # Initialize weights + self.init_weights() def forward(self, x_dict, edge_index_dict, 
edge_attr_dict): # Input - Patches - x_patches = self.input_patches(x_dict["patches"]) - del x_dict["patches"] + x_patch = self.input_patches(x_dict["patch"]) + del x_dict["patch"] # Input - Nodes - x_dict = {key: f(x_dict[key]) for key, f in self.input_nodes.items()} + for key, f in self.input_nodes.items(): + x_dict[key] = f(x_dict[key]) x_dict = self.activation(x_dict) # Input - Edges @@ -214,6 +234,7 @@ def forward(self, x_dict, edge_index_dict, edge_attr_dict): edge_attr_dict = self.activation(edge_attr_dict) # Concatenate multimodal embeddings + x_dict["proposal"] = torch.cat((x_dict["proposal"], x_patch), dim=1) # Message passing layers x_dict = self.gat1( diff --git a/src/deep_neurographs/machine_learning/models.py b/src/deep_neurographs/machine_learning/models.py index 1562c4f..e14ca14 100644 --- a/src/deep_neurographs/machine_learning/models.py +++ b/src/deep_neurographs/machine_learning/models.py @@ -108,7 +108,7 @@ def __init__(self, patch_shape, output_dim): self.conv1 = self._init_conv_layer(2, 32) self.conv2 = self._init_conv_layer(32, 64) self.output = nn.Sequential( - nn.Linear(-1, 64), + nn.Linear(64000, 64), nn.LeakyReLU(), nn.Linear(output_dim, output_dim), )