Skip to content

Commit

Permalink
Merge pull request #177 from decargroup/bugfix/176-fix-scikit-learn-i…
Browse files Browse the repository at this point in the history
…mport-method-resolution-order

Fix `sklearn` method resolution order
  • Loading branch information
sdahdah authored Sep 26, 2024
2 parents fdf5099 + bbd875c commit 21073ad
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pykoop/kernel_approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@


class KernelApproximation(
sklearn.base.BaseEstimator,
sklearn.base.TransformerMixin,
sklearn.base.BaseEstimator,
metaclass=abc.ABCMeta,
):
"""Base class for all kernel approximations.
Expand Down
11 changes: 7 additions & 4 deletions pykoop/koopman_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@


class KoopmanLiftingFn(
sklearn.base.BaseEstimator,
sklearn.base.TransformerMixin,
sklearn.base.BaseEstimator,
metaclass=abc.ABCMeta,
):
"""Base class for Koopman lifting functions.
Expand Down Expand Up @@ -1138,9 +1138,11 @@ def _validate_parameters(self) -> None:
raise NotImplementedError()


class KoopmanRegressor(sklearn.base.BaseEstimator,
sklearn.base.RegressorMixin,
metaclass=abc.ABCMeta):
class KoopmanRegressor(
sklearn.base.RegressorMixin,
sklearn.base.BaseEstimator,
metaclass=abc.ABCMeta,
):
"""Base class for Koopman regressors.
All attributes with a trailing underscore are set by :func:`fit`.
Expand Down Expand Up @@ -1628,6 +1630,7 @@ def _more_tags(self):
return {
'multioutput': True,
'multioutput_only': True,
'requires_y': False,
}


Expand Down
4 changes: 3 additions & 1 deletion pykoop/lmi_regressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def _more_tags(self):
return {
'multioutput': True,
'multioutput_only': True,
'requires_y': False,
'_xfail_checks': {
'check_fit_idempotent': reason,
}
Expand Down Expand Up @@ -2155,7 +2156,7 @@ def _create_problem_b(self, U: np.ndarray) -> picos.Problem:
return problem_b


class LmiHinfZpkMeta(sklearn.base.BaseEstimator, sklearn.base.RegressorMixin):
class LmiHinfZpkMeta(sklearn.base.RegressorMixin, sklearn.base.BaseEstimator):
"""Meta-estimator where H-infinity weight is specified in ZPK format.
H-infinity regularization weights must normally be specified in
Expand Down Expand Up @@ -2375,6 +2376,7 @@ def _more_tags(self):
return {
'multioutput': True,
'multioutput_only': True,
'requires_y': False,
}


Expand Down
2 changes: 2 additions & 0 deletions pykoop/regressors.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def _more_tags(self):
return {
'multioutput': True,
'multioutput_only': True,
'requires_y': False,
'_xfail_checks': {
'check_estimators_dtypes': reason,
'check_fit_score_takes_y': reason,
Expand Down Expand Up @@ -504,6 +505,7 @@ def _more_tags(self):
return {
'multioutput': True,
'multioutput_only': True,
'requires_y': False,
# Allow a bad score since the ``coef_`` matrix will be filled with
# zeros, and we just care to test ``scikit-learn`` API compliance.
'poor_score': True,
Expand Down

0 comments on commit 21073ad

Please sign in to comment.