Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Check if fit has been called prior to predict #166

Merged

Conversation

jhlegarreta
Copy link
Collaborator

@jhlegarreta jhlegarreta commented Apr 13, 2024

Check if fit has been called prior to predict: raise an exception if not.

Resolves #113.

@jhlegarreta
Copy link
Collaborator Author

jhlegarreta commented Apr 13, 2024

For now, only added it for TrivialB0Model to first see how this fits into what was expected in the issue.

A few notes:

  • Following ENH: PET uptake model #112 (comment), almost all of this has been borrowed from scikit-learn. If we keep this approach, maybe a mention to this will need to be made.
  • In scikit-learn:
    • Their base estimator does not have any is_fitted_ variable. It does not have a fit/predict either; they are left for child classes.
    • They rely on a list of attributes (coef_, estimator_) existing for the estimator and provided as a parameter https://github.com/scikit-learn/scikit-learn/blob/872124551/sklearn/utils/validation.py#L1551. If none is provided, the estimator is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore (e.g. could be the is_fitted_ attribute).
  • Their __sklearn_is_fitted__ is defined in child classes, and is used as an alternative to the above two possibilities to check whether an estimator is fit.
  • For now, I've left the new modules in the eddymotion package, but they should probably be moved to an e.g. utils modules.

Useful references:
https://scikit-learn.org/stable/developers/develop.html#developer-api-for-check-is-fitted
https://scikit-learn.org/stable/auto_examples/developing_estimators/sklearn_is_fitted.html#sphx-glr-auto-examples-developing-estimators-sklearn-is-fitted-py

Slightly related:

  • eddymotion models (e.g. TrivialB0Model, etc.) do not inherit from BaseModel. Not sure if that is intended, but it is counter-intuitive or the rationale would need to be documented better.
  • Adds to the confusion having an EddyMotionEstimator class that has a fit but not a predict. Maybe model vs. estimator naming should be reworked.

No tests added for now until the above is figured out.

@jhlegarreta jhlegarreta force-pushed the CheckModelsAreFitBeforePredicting branch from aa1b0c7 to b96cd24 Compare April 13, 2024 17:55
oesteban
oesteban previously approved these changes Apr 15, 2024
Copy link
Member

@oesteban oesteban left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I very much agree with the suggestions. I also agree that fit() was a very poor choice for the estimator object. Probably optimize() or search() would have been much better names. +1 to updating.

At this point, I'm not even sure the estimator should be an object at all.

src/eddymotion/model/base.py Outdated Show resolved Hide resolved
src/eddymotion/validation.py Outdated Show resolved Hide resolved
src/eddymotion/validation.py Outdated Show resolved Hide resolved
Check if `fit` has been called prior to `predict`: raise an exception if
not.

Add the accompanying tests.
@jhlegarreta
Copy link
Collaborator Author

eddymotion models (e.g. TrivialB0Model, etc.) do not inherit from BaseModel. Not sure if that is intended, but it is counter-intuitive or the rationale would need to be documented better.

Should they inherit from BaseModel?

I very much agree with the suggestions. I also agree that fit() was a very poor choice for the estimator object. Probably optimize() or search() would have been much better names. +1 to updating.

Not sure if optimize() or search() are less confusing to me. If I read what its fit returns, compute_transform seems more descriptive and unambiguous wrt a fit/predict paradigm. To be addressed in a separate PR.

Concerning c440406, the test_model tests will not pass on CircleCI since the models do not inherit from BaseModel. If they are not supposed to inherit from it, I would need to relocate the __model_is_fitted__ method or give the attributes to the check_is_fitted calls in TrivialB0Model, and AverageDWModel.

@oesteban
Copy link
Member

Should they inherit from BaseModel?

I believe they do not inherit just for historical reasons (they were written the first, then we created BaseModel to provide a gateway for DIPY models and never did the last step of including them in the hierarchy). So, the answer is yes.

Concerning c440406, the test_model tests will not pass on CircleCI since the models do not inherit from BaseModel. If they are not supposed to inherit from it, I would need to relocate the __model_is_fitted__ method or give the attributes to the check_is_fitted calls in TrivialB0Model, and AverageDWModel.

Yup, I think making them part of the hierarchy makes all sense.

Not sure if optimize() or search() are less confusing to me. If I read what its fit returns, compute_transform seems more descriptive and unambiguous wrt a fit/predict paradigm. To be addressed in a separate PR.

Yes, in a different PR :). compute_transform feels a bit wordy and even confusing (in the sense that we are actually computing a lot of transforms). What about Estimator.estimate()?

@jhlegarreta
Copy link
Collaborator Author

Yes, in a different PR :). compute_transform feels a bit wordy and even confusing (in the sense that we are actually computing a lot of transforms). What about Estimator.estimate()?

PR #173.

Making this a draft as addressing the inheritance issue is more complicated that what it seems.

@jhlegarreta jhlegarreta force-pushed the CheckModelsAreFitBeforePredicting branch from bc2931e to d98f586 Compare June 7, 2024 23:45
@jhlegarreta jhlegarreta marked this pull request as ready for review June 7, 2024 23:46
@jhlegarreta jhlegarreta force-pushed the CheckModelsAreFitBeforePredicting branch 2 times, most recently from 539b45e to aeb503d Compare June 8, 2024 00:19
@jhlegarreta jhlegarreta force-pushed the CheckModelsAreFitBeforePredicting branch from aeb503d to e65c250 Compare June 8, 2024 00:40
@oesteban oesteban merged commit 59600ee into nipreps:main Jun 8, 2024
6 checks passed
@jhlegarreta jhlegarreta deleted the CheckModelsAreFitBeforePredicting branch June 8, 2024 13:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Track model fit formally with an object member variable
2 participants