Skip to content

Commit

Permalink
enh: suggestions to PR (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban authored and jhlegarreta committed Jun 8, 2024
1 parent 908c71d commit e65c250
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 312 deletions.
1 change: 0 additions & 1 deletion docs/developers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,4 @@ Information on specific functions, classes, and methods.
api/eddymotion.math
api/eddymotion.model
api/eddymotion.utils
api/eddymotion.validation
api/eddymotion.viz
6 changes: 4 additions & 2 deletions src/eddymotion/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
#


class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.
class ModelNotFittedError(ValueError, AttributeError):
"""
Exception class to raise if estimator is used before fitting.
This class inherits from both ValueError and AttributeError to help with
exception handling.
"""
47 changes: 36 additions & 11 deletions src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed

from eddymotion.validation import check_is_fitted
from eddymotion.exceptions import ModelNotFittedError


def _exec_fit(model, data, chunk=None):
Expand Down Expand Up @@ -92,12 +92,15 @@ class BaseModel:
"_b_max",
"_models",
"_datashape",
"_is_fitted",
)
_modelargs = ()

def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
"""Base initialization."""

self._is_fitted = False

# Setup B0 map
self._S0 = None
if S0 is not None:
Expand Down Expand Up @@ -138,6 +141,10 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
self._models = None
self._is_fitted = False

@property
def is_fitted(self):
return self._is_fitted

def fit(self, data, n_jobs=None, **kwargs):
"""Fit the model chunk-by-chunk asynchronously"""
n_jobs = n_jobs or 1
Expand Down Expand Up @@ -172,7 +179,9 @@ def fit(self, data, n_jobs=None, **kwargs):

def predict(self, gradient, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""
check_is_fitted(self)

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

if self._b_max is not None:
gradient[-1] = min(gradient[-1], self._b_max)
Expand Down Expand Up @@ -215,10 +224,6 @@ def predict(self, gradient, **kwargs):

return retval

def __model_is_fitted__(self):
"""Check fitted status and return a Boolean value."""
return hasattr(self, "_is_fitted") and self._is_fitted


class TrivialB0Model:
"""A trivial model that returns a *b=0* map always."""
Expand All @@ -232,22 +237,24 @@ def __init__(self, S0=None, **kwargs):

self._S0 = S0

@property
def is_fitted(self):
return True

def fit(self, *args, **kwargs):
"""Do nothing."""
# ToDo
# Does not inherit from BaseModel, so should be defined in __init__ ??
self._is_fitted = True

def predict(self, gradient, **kwargs):
"""Return the *b=0* map."""
check_is_fitted(self)

# No need to check fit (if not fitted, has raised already)
return self._S0


class AverageDWModel:
"""A trivial model that returns an average map."""

__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat")
__slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted")

def __init__(self, **kwargs):
r"""
Expand Down Expand Up @@ -276,6 +283,7 @@ def __init__(self, **kwargs):
self._bias = kwargs.get("bias", True)
self._stat = kwargs.get("stat", "median")
self._data = None
self._is_fitted = False

def fit(self, data, **kwargs):
"""Calculate the average."""
Expand All @@ -301,8 +309,18 @@ def fit(self, data, **kwargs):
# Calculate the average
self._data = avg_func(shells, axis=-1)

self._is_fitted = self._data is not None

@property
def is_fitted(self):
return self._is_fitted

def predict(self, gradient, **kwargs):
"""Return the average map."""

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

return self._data


Expand Down Expand Up @@ -351,6 +369,10 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3,
self._shape = None
self._coeff = None

@property
def is_fitted(self):
return self._coeff is not None

def fit(self, data, *args, **kwargs):
"""Fit the model."""
from scipy.interpolate import BSpline
Expand Down Expand Up @@ -386,6 +408,9 @@ def predict(self, timepoint, **kwargs):
"""Return the *b=0* map."""
from scipy.interpolate import BSpline

if not self._is_fitted:
raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting")

# Project sample timing into B-Spline coordinates
x = (timepoint / self._xlim) * self._n_ctrl
A = BSpline.design_matrix(x, self._t, k=self._order)
Expand Down
178 changes: 0 additions & 178 deletions src/eddymotion/tests/test_validation.py

This file was deleted.

Loading

0 comments on commit e65c250

Please sign in to comment.