major upd : feature_extraction
anna-grim committed Dec 21, 2023
1 parent cf3c1e3 commit 726cdd8
Showing 12 changed files with 137 additions and 211 deletions.
5 changes: 1 addition & 4 deletions doc_template/source/conf.py
@@ -41,10 +41,7 @@
 html_theme = "furo"
 html_static_path = ["_static"]
 html_favicon = "_static/favicon.ico"
-html_theme_options = {
-    "light_logo": "light-logo.svg",
-    "dark_logo": "dark-logo.svg",
-}
+html_theme_options = {"light_logo": "light-logo.svg", "dark_logo": "dark-logo.svg"}
 
 # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
 html_show_sphinx = False
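
For reference, a quick check (not from the repository) that the reformatted html_theme_options literal builds the same dictionary as the original multi-line version, so Sphinx behavior is unchanged:

old = {
    "light_logo": "light-logo.svg",
    "dark_logo": "dark-logo.svg",
}
new = {"light_logo": "light-logo.svg", "dark_logo": "dark-logo.svg"}
assert old == new  # formatting-only change; Sphinx sees an identical dict
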
12 changes: 2 additions & 10 deletions src/deep_neurographs/deep_learning/models.py
@@ -127,11 +127,7 @@ def _init_conv_layer(self, in_channels, out_channels):
         """
         conv_layer = nn.Sequential(
             nn.Conv3d(
-                in_channels,
-                out_channels,
-                kernel_size=(3, 3, 3),
-                stride=1,
-                padding=0,
+                in_channels, out_channels, kernel_size=(3, 3, 3), stride=1, padding=0
             ),
             nn.BatchNorm3d(out_channels),
             nn.LeakyReLU(),
@@ -231,11 +227,7 @@ def _init_conv_layer(self, in_channels, out_channels):
         """
         conv_layer = nn.Sequential(
             nn.Conv3d(
-                in_channels,
-                out_channels,
-                kernel_size=(3, 3, 3),
-                stride=1,
-                padding=0,
+                in_channels, out_channels, kernel_size=(3, 3, 3), stride=1, padding=0
            ),
             nn.BatchNorm3d(out_channels),
             nn.LeakyReLU(),
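
A minimal sketch of the convolution block touched by both hunks above, assuming PyTorch as in the surrounding code; the channel counts and input size are illustrative only. It also shows that with padding=0 and a 3x3x3 kernel each spatial dimension shrinks by two:

import torch
import torch.nn as nn

# Same structure as _init_conv_layer's block, with made-up channel counts
conv_layer = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=(3, 3, 3), stride=1, padding=0),
    nn.BatchNorm3d(8),
    nn.LeakyReLU(),
)

x = torch.zeros(2, 1, 16, 16, 16)   # (batch, channels, depth, height, width)
print(conv_layer(x).shape)          # torch.Size([2, 8, 14, 14, 14])
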
14 changes: 6 additions & 8 deletions src/deep_neurographs/deep_learning/train.py
@@ -36,7 +36,7 @@
 BATCH_SIZE = 32
 NUM_WORKERS = 0
 SHUFFLE = True
-SUPPORTED_CLFS = [
+SUPPORTED_MODELS = [
     "AdaBoost",
     "RandomForest",
     "FeedForwardNet",
@@ -45,7 +45,7 @@
 ]
 
 
-# Training
+# -- Cross Validation --
 def get_kfolds(train_data, k):
     folds = []
     samples = set(train_data)
@@ -60,8 +60,9 @@ def get_kfolds(train_data, k):
     return folds
 
 
+# -- Training --
 def get_clf(key, data=None, num_features=None):
-    assert key in SUPPORTED_CLFS
+    assert key in SUPPORTED_MODELS
     if key == "AdaBoost":
         return AdaBoostClassifier()
     elif key == "RandomForest":
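
A hedged usage sketch of the renamed SUPPORTED_MODELS list and get_clf. AdaBoostClassifier appears in the diff; RandomForestClassifier is assumed for the RandomForest branch, whose body is not shown here. The signature is simplified and the toy data is illustrative, not from the repository:

import numpy as np
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier

SUPPORTED_MODELS = ["AdaBoost", "RandomForest", "FeedForwardNet"]

def get_clf(key):
    # Simplified stand-in for train.get_clf(key, data, num_features)
    assert key in SUPPORTED_MODELS
    if key == "AdaBoost":
        return AdaBoostClassifier()
    elif key == "RandomForest":
        return RandomForestClassifier()
    raise NotImplementedError("neural nets need extra constructor arguments")

X = np.random.rand(20, 4)
y = np.array([0, 1] * 10)
clf = get_clf("RandomForest")
clf.fit(X, y)
print(clf.predict(X[:3]))
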
@@ -104,10 +105,7 @@ def train_network(
         shuffle=SHUFFLE,
     )
     valid_loader = DataLoader(
-        valid_set,
-        batch_size=BATCH_SIZE,
-        num_workers=NUM_WORKERS,
-        pin_memory=True,
+        valid_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True
     )
 
     # Configure trainer
@@ -149,7 +147,7 @@ def eval_network(X, model):
     return np.array(y_pred)
 
 
-# Lightning Module
+# -- Lightning Module --
 class LitNeuralNet(pl.LightningModule):
     def __init__(self, net=None, lr=10e-3):
         super().__init__()
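
A minimal, self-contained sketch of the DataLoader configuration reformatted in train_network. BATCH_SIZE, NUM_WORKERS, SHUFFLE, and pin_memory come from this file; the toy TensorDataset stands in for the real feature tensors:

import torch
from torch.utils.data import DataLoader, TensorDataset

BATCH_SIZE = 32
NUM_WORKERS = 0
SHUFFLE = True

# Illustrative data: 128 samples with 10 features each
X = torch.randn(128, 10)
y = torch.randint(0, 2, (128,)).float()
train_set = TensorDataset(X, y)
valid_set = TensorDataset(X, y)

train_loader = DataLoader(
    train_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=SHUFFLE
)
valid_loader = DataLoader(
    valid_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True
)

x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape)  # torch.Size([32, 10])
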
32 changes: 20 additions & 12 deletions src/deep_neurographs/evaluation.py
@@ -9,13 +9,25 @@
 """
 import numpy as np
 
-METRICS_LIST = [
-    "precision",
-    "recall",
-    "f1",
-    "# splits fixed",
-    "# merges created",
-]
+METRICS_LIST = ["precision", "recall", "f1", "# splits fixed", "# merges created"]
+
+
+def init_stats():
+    """
+    Initializes a dictionary that stores stats computed by routines in this
+    module.
+
+    Parameters
+    ----------
+    None
+
+    Returns
+    -------
+    dict
+        Dictionary that stores stats computed by routines in this module.
+
+    """
+    return dict([(metric, []) for metric in METRICS_LIST])
 
 
 def run_evaluation(neurographs, blocks, pred_edges):
@@ -40,11 +52,7 @@ def run_evaluation(neurographs, blocks, pred_edges):
         metrics contained in this dictionary are identical to "METRICS_LIST"].
 
     """
-    stats = {
-        "Overall": dict([(metric, []) for metric in METRICS_LIST]),
-        "Simple": dict([(metric, []) for metric in METRICS_LIST]),
-        "Complex": dict([(metric, []) for metric in METRICS_LIST]),
-    }
+    stats = {"Overall": init_stats(), "Simple": init_stats(), "Complex": init_stats()}
     for block_id in blocks:
         # Compute accuracy
         overall_stats_i = get_stats(
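
A short sketch of how the new init_stats helper replaces the three repeated dict comprehensions in run_evaluation; the appended value is made up, just to show the shape of the resulting structure:

METRICS_LIST = ["precision", "recall", "f1", "# splits fixed", "# merges created"]

def init_stats():
    return dict([(metric, []) for metric in METRICS_LIST])

# One stats dict per edge category, each mapping metric name -> per-block values
stats = {"Overall": init_stats(), "Simple": init_stats(), "Complex": init_stats()}
stats["Overall"]["precision"].append(0.95)  # illustrative value for one block
print(stats["Overall"])
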