Skip to content

Commit

Permalink
upd : training routines
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim committed Jan 1, 2024
1 parent 726cdd8 commit 43b0dad
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 105 deletions.
5 changes: 4 additions & 1 deletion doc_template/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@
html_theme = "furo"
html_static_path = ["_static"]
html_favicon = "_static/favicon.ico"
html_theme_options = {"light_logo": "light-logo.svg", "dark_logo": "dark-logo.svg"}
html_theme_options = {
"light_logo": "light-logo.svg",
"dark_logo": "dark-logo.svg",
}

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
html_show_sphinx = False
Expand Down
12 changes: 10 additions & 2 deletions src/deep_neurographs/deep_learning/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ def _init_conv_layer(self, in_channels, out_channels):
"""
conv_layer = nn.Sequential(
nn.Conv3d(
in_channels, out_channels, kernel_size=(3, 3, 3), stride=1, padding=0
in_channels,
out_channels,
kernel_size=(3, 3, 3),
stride=1,
padding=0,
),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(),
Expand Down Expand Up @@ -227,7 +231,11 @@ def _init_conv_layer(self, in_channels, out_channels):
"""
conv_layer = nn.Sequential(
nn.Conv3d(
in_channels, out_channels, kernel_size=(3, 3, 3), stride=1, padding=0
in_channels,
out_channels,
kernel_size=(3, 3, 3),
stride=1,
padding=0,
),
nn.BatchNorm3d(out_channels),
nn.LeakyReLU(),
Expand Down
139 changes: 111 additions & 28 deletions src/deep_neurographs/deep_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
binary_recall,
)

from deep_neurographs import feature_extraction as extracter
from deep_neurographs.deep_learning import datasets as ds
from deep_neurographs.deep_learning import models

Expand All @@ -46,10 +47,26 @@


# -- Cross Validation --
def get_kfolds(train_data, k):
def get_kfolds(filenames, k):
"""
Partitions "filenames" into k-folds to perform cross validation.
Parameters
----------
filenames : list[str]
List of filenames of samples for training.
k : int
Number of folds to be used in k-fold cross validation.
Returns
-------
folds : list[list[str]]
Partition of "filesnames" into k-folds.
"""
folds = []
samples = set(train_data)
num_samples = int(np.floor(len(train_data) / k))
samples = set(filenames)
num_samples = int(np.floor(len(filenames) / k))
assert num_samples > 0, "Sample size is too small for {}-folds".format(k)
for i in range(k):
samples_i = sample(samples, num_samples)
Expand All @@ -61,39 +78,102 @@ def get_kfolds(train_data, k):


# -- Training --
def get_clf(key, data=None, num_features=None):
assert key in SUPPORTED_MODELS
if key == "AdaBoost":
def fit_model(
model_type, X, y, lr=1e-3, logger=False, max_epochs=50, profile=False
):
"""
Fits a model to a training dataset.
Parameters
----------
model_type : str
Indication of type of model. Options are "AdaBoost",
"RandomForest", "FeedForwardNet", "ConvNet", and
"MultiModalNet".
X : numpy.ndarray
Feature matrix.
y : numpy.ndarray
Labels to be learned.
lr : float, optional
Learning rate to be used if model is a neural network. The default is
1e-3.
logger : bool, optional
Indication of whether to log performance stats while neural network
trains. The default is False.
max_epochs : int, optional
Maximum number of epochs used to train neural network. The default is
50.
profile : bool, optional
Indication of whether to profile runtime of training neural network.
The default is False.
Returns
-------
...
"""
if model_type in ["FeedForwardNet", "ConvNet", "MultiModalNet"]:
data = {"inputs": X, "labels": y}
net, dataset = get_model(model_type, data=data)
model = train_network(
net, dataset, logger=logger, lr=lr, max_epochs=max_epochs
)
else:
model = get_model(model_type)
model.fit(X, y)
return model


def evaluate_model():
pass


def get_model(model_type, data=None):
"""
Gets classification model to be fit.
Parameters
----------
model_type : str
Indication of type of model. Options are "AdaBoost",
"RandomForest", "FeedForwardNet", "ConvNet", and
"MultiModalNet".
data : dict, optional
Training data used to fit model. This dictionary must contain the keys
"inputs" and "labels" which correspond to the feature matrix and
target labels to be learned. The default is None.
Returns
-------
...
"""
assert model_type in SUPPORTED_MODELS
if model_type == "AdaBoost":
return AdaBoostClassifier()
elif key == "RandomForest":
elif model_type == "RandomForest":
return RandomForestClassifier()
elif key == "FeedForwardNet":
net = models.FeedForwardNet(num_features)
train_data = ds.ProposalDataset(data["inputs"], data["labels"])
elif key == "ConvNet":
elif model_type == "FeedForwardNet":
n_features = extracter.count_features(model_type)
net = models.FeedForwardNet(n_features)
dataset = ds.ProposalDataset(data["inputs"], data["labels"])
elif model_type == "ConvNet":
net = models.ConvNet()
models.init_weights(net)
train_data = ds.ImgProposalDataset(
dataset = ds.ImgProposalDataset(
data["inputs"], data["labels"], transform=True
)
elif key == "MultiModalNet":
net = models.MultiModalNet(num_features)
elif model_type == "MultiModalNet":
n_features = extracter.count_features(model_type)
net = models.MultiModalNet(n_features)
models.init_weights(net)
train_data = ds.MultiModalDataset(
dataset = ds.MultiModalDataset(
data["inputs"], data["labels"], transform=True
)
return net, train_data
return net, dataset


def train_network(
net,
dataset,
logger=True,
lr=10e-3,
max_epochs=100,
model_summary=True,
profile=False,
progress_bar=True,
net, dataset, logger=True, lr=1e-3, max_epochs=50, profile=False
):
# Load data
train_set, valid_set = random_split(dataset)
Expand All @@ -105,7 +185,10 @@ def train_network(
shuffle=SHUFFLE,
)
valid_loader = DataLoader(
valid_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True
valid_set,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
pin_memory=True,
)

# Configure trainer
Expand All @@ -118,8 +201,8 @@ def train_network(
accelerator="gpu",
callbacks=[ckpt_callback],
devices=1,
enable_model_summary=model_summary,
enable_progress_bar=progress_bar,
enable_model_summary=True,
enable_progress_bar=True,
logger=logger,
log_every_n_steps=1,
max_epochs=max_epochs,
Expand Down Expand Up @@ -149,7 +232,7 @@ def eval_network(X, model):

# -- Lightning Module --
class LitNeuralNet(pl.LightningModule):
def __init__(self, net=None, lr=10e-3):
def __init__(self, net=None, lr=1e-3):
super().__init__()
self.net = net
self.lr = lr
Expand Down
14 changes: 12 additions & 2 deletions src/deep_neurographs/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
"""
import numpy as np

METRICS_LIST = ["precision", "recall", "f1", "# splits fixed", "# merges created"]
METRICS_LIST = [
"precision",
"recall",
"f1",
"# splits fixed",
"# merges created",
]


def init_stats():
Expand Down Expand Up @@ -52,7 +58,11 @@ def run_evaluation(neurographs, blocks, pred_edges):
metrics contained in this dictionary are identical to "METRICS_LIST"].
"""
stats = {"Overall": init_stats(), "Simple": init_stats(), "Complex": init_stats()}
stats = {
"Overall": init_stats(),
"Simple": init_stats(),
"Complex": init_stats(),
}
for block_id in blocks:
# Compute accuracy
overall_stats_i = get_stats(
Expand Down
Loading

0 comments on commit 43b0dad

Please sign in to comment.