Skip to content

Commit

Permalink
Merge pull request #471 from scikit-learn-contrib/212-predict_params_…
Browse files Browse the repository at this point in the history
…mapie_without_classification

Add predict_params into Mapie except classification files
  • Loading branch information
LacombeLouis authored Aug 2, 2024
2 parents 2823152 + 44370b7 commit 603b5da
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 32 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ History
0.8.x (2024-xx-xx)
------------------

* Add `**predict_params` in fit and predict methods for Mapie Regression
* Update the ts-changepoint notebook with the tutorial
* Change import related to conformity scores into ts-changepoint notebook
* Replace `assert np.array_equal` by `np.testing.assert_array_equal` in Mapie unit tests
Expand Down
2 changes: 1 addition & 1 deletion mapie/estimator/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def predict(
self,
X: ArrayLike,
agg_scores: Optional[str] = None,
**predict_params
**predict_params,
) -> NDArray:
"""
Predict target from X. It also computes the prediction per train sample
Expand Down
30 changes: 22 additions & 8 deletions mapie/estimator/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def _predict_oof_estimator(
estimator: RegressorMixin,
X: ArrayLike,
val_index: ArrayLike,
**predict_params
) -> Tuple[NDArray, ArrayLike]:
"""
Perform predictions on a single out-of-fold model on a validation set.
Expand All @@ -248,14 +249,17 @@ def _predict_oof_estimator(
val_index: ArrayLike of shape (n_samples_val)
Validation data indices.
**predict_params : dict
Additional predict parameters.
Returns
-------
Tuple[NDArray, ArrayLike]
Predictions of estimator from val_index of X.
"""
X_val = _safe_indexing(X, val_index)
if _num_samples(X_val) > 0:
y_pred = estimator.predict(X_val)
y_pred = estimator.predict(X_val, **predict_params)
else:
y_pred = np.array([])
return y_pred, val_index
Expand Down Expand Up @@ -306,7 +310,7 @@ def _aggregate_with_mask(
else:
raise ValueError("The value of self.agg_function is not correct")

def _pred_multi(self, X: ArrayLike) -> NDArray:
def _pred_multi(self, X: ArrayLike, **predict_params) -> NDArray:
"""
Return a prediction per train sample for each test sample, by
aggregation with matrix ``k_``.
Expand All @@ -316,12 +320,15 @@ def _pred_multi(self, X: ArrayLike) -> NDArray:
X: ArrayLike of shape (n_samples_test, n_features)
Input data
**predict_params : dict
Additional predict parameters.
Returns
-------
NDArray of shape (n_samples_test, n_samples_train)
"""
y_pred_multi = np.column_stack(
[e.predict(X) for e in self.estimators_]
[e.predict(X, **predict_params) for e in self.estimators_]
)
# At this point, y_pred_multi is of shape
# (n_samples_test, n_estimators_). The method
Expand All @@ -334,7 +341,8 @@ def predict_calib(
self,
X: ArrayLike,
y: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None
groups: Optional[ArrayLike] = None,
**predict_params
) -> NDArray:
"""
Perform predictions on X : the calibration set.
Expand All @@ -355,6 +363,9 @@ def predict_calib(
By default ``None``.
**predict_params : dict
Additional predict parameters.
Returns
-------
NDArray of shape (n_samples_test, 1)
Expand All @@ -371,7 +382,7 @@ def predict_calib(
cv = cast(BaseCrossValidator, self.cv)
outputs = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
delayed(self._predict_oof_estimator)(
estimator, X, calib_index,
estimator, X, calib_index, **predict_params
)
for (_, calib_index), estimator in zip(
cv.split(X, y, groups),
Expand Down Expand Up @@ -404,7 +415,7 @@ def fit(
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
**fit_params,
**fit_params
) -> EnsembleRegressor:
"""
Fit the base estimator under the ``single_estimator_`` attribute.
Expand Down Expand Up @@ -526,6 +537,9 @@ def predict(
predictions (3 arrays). If ``False`` the method return the
simple predictions only.
**predict_params : dict
Additional predict parameters.
Returns
-------
Tuple[NDArray, NDArray, NDArray]
Expand All @@ -535,15 +549,15 @@ def predict(
"""
check_is_fitted(self, self.fit_attributes)

y_pred = self.single_estimator_.predict(X)
y_pred = self.single_estimator_.predict(X, **predict_params)
if not return_multi_pred and not ensemble:
return y_pred

if self.method in self.no_agg_methods_ or self.use_split_method_:
y_pred_multi_low = y_pred[:, np.newaxis]
y_pred_multi_up = y_pred[:, np.newaxis]
else:
y_pred_multi = self._pred_multi(X)
y_pred_multi = self._pred_multi(X, **predict_params)

if self.method == "minmax":
y_pred_multi_low = np.min(y_pred_multi, axis=1, keepdims=True)
Expand Down
6 changes: 5 additions & 1 deletion mapie/regression/quantile_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ def predict(
optimize_beta: bool = False,
allow_infinite_bounds: bool = False,
symmetry: Optional[bool] = True,
**predict_params,
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
"""
Predict target on new samples with confidence intervals.
Expand Down Expand Up @@ -676,6 +677,9 @@ def predict(
each residuals separately or to use the maximum of the two
combined.
predict_params : dict
Additional predict parameters.
Returns
-------
Union[NDArray, Tuple[NDArray, NDArray]]
Expand All @@ -699,7 +703,7 @@ def predict(
dtype=float,
)
for i, est in enumerate(self.estimators_):
y_preds[i] = est.predict(X)
y_preds[i] = est.predict(X, **predict_params)
check_lower_upper_bounds(y_preds[0], y_preds[1], y_preds[2])
if symmetry:
quantile = np.full(
Expand Down
30 changes: 23 additions & 7 deletions mapie/regression/regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import Iterable, Optional, Tuple, Union, cast
from typing import Any, Iterable, Optional, Tuple, Union, cast

import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
Expand All @@ -19,7 +19,8 @@
from mapie.utils import (check_alpha, check_alpha_and_n_samples,
check_cv, check_estimator_fit_predict,
check_n_features_in, check_n_jobs, check_null_weight,
check_verbose, get_effective_calibration_samples)
check_verbose, get_effective_calibration_samples,
check_predict_params)


class MapieRegressor(BaseEstimator, RegressorMixin):
Expand Down Expand Up @@ -471,7 +472,7 @@ def fit(
y: ArrayLike,
sample_weight: Optional[ArrayLike] = None,
groups: Optional[ArrayLike] = None,
**fit_params,
**kwargs: Any
) -> MapieRegressor:
"""
Fit estimator and compute conformity scores used for
Expand Down Expand Up @@ -504,14 +505,21 @@ def fit(
train/test set.
By default ``None``.
**fit_params : dict
Additional fit parameters.
kwargs : dict
Additional fit and predict parameters.
Returns
-------
MapieRegressor
The model itself.
"""
fit_params = kwargs.pop('fit_params', {})
predict_params = kwargs.pop('predict_params', {})
if len(predict_params) > 0:
self._predict_params = True
else:
self._predict_params = False

# Checks
(estimator,
self.conformity_score_function_,
Expand All @@ -538,7 +546,9 @@ def fit(
)

# Predict on calibration data
y_pred = self.estimator_.predict_calib(X, y=y, groups=groups)
y_pred = self.estimator_.predict_calib(
X, y=y, groups=groups, **predict_params
)

# Compute the conformity scores (manage jk-ab case)
self.conformity_scores_ = \
Expand All @@ -555,6 +565,7 @@ def predict(
alpha: Optional[Union[float, Iterable[float]]] = None,
optimize_beta: bool = False,
allow_infinite_bounds: bool = False,
**predict_params
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
"""
Predict target on new samples with confidence intervals.
Expand Down Expand Up @@ -604,6 +615,9 @@ def predict(
By default ``False``.
predict_params : dict
Additional predict parameters.
Returns
-------
Union[NDArray, Tuple[NDArray, NDArray]]
Expand All @@ -614,14 +628,16 @@ def predict(
- [:, 1, :]: Upper bound of the prediction interval.
"""
# Checks
if hasattr(self, '_predict_params'):
check_predict_params(self._predict_params, predict_params, self.cv)
check_is_fitted(self, self.fit_attributes)
self._check_ensemble(ensemble)
alpha = cast(Optional[NDArray], check_alpha(alpha))

# If alpha is None, predict the target without confidence intervals
if alpha is None:
y_pred = self.estimator_.predict(
X, ensemble, return_multi_pred=False
X, ensemble, return_multi_pred=False, **predict_params
)
return np.array(y_pred)

Expand Down
9 changes: 7 additions & 2 deletions mapie/regression/time_series_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def predict(
alpha: Optional[Union[float, Iterable[float]]] = None,
optimize_beta: bool = False,
allow_infinite_bounds: bool = False,
**predict_params
) -> Union[NDArray, Tuple[NDArray, NDArray]]:
"""
Predict target on new samples with confidence intervals.
Expand Down Expand Up @@ -441,6 +442,9 @@ def predict(
allow_infinite_bounds: bool
Allow infinite prediction intervals to be produced.
predict_params : dict
Additional predict parameters.
Returns
-------
Union[NDArray, Tuple[NDArray, NDArray]]
Expand All @@ -452,15 +456,16 @@ def predict(
"""
if alpha is None:
super().predict(
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta,
**predict_params
)

if self.method == "aci":
alpha = self._get_alpha(alpha)

return super().predict(
X, ensemble=ensemble, alpha=alpha, optimize_beta=optimize_beta,
allow_infinite_bounds=allow_infinite_bounds
allow_infinite_bounds=allow_infinite_bounds, **predict_params
)

def _more_tags(self):
Expand Down
Loading

0 comments on commit 603b5da

Please sign in to comment.