Skip to content

Commit

Permalink
FIX: coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Damien-Bouet committed Jul 30, 2024
1 parent b31f9da commit 6d94658
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
34 changes: 8 additions & 26 deletions mapie/futur/split/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from mapie._typing import ArrayLike, NDArray
from mapie.calibrators.utils import check_calibrator
from mapie.conformity_scores import BaseClassificationScore, LACConformityScore
from mapie.conformity_scores import BaseClassificationScore
from mapie.conformity_scores.interface import BaseConformityScore
from mapie.conformity_scores.utils import check_classification_conformity_score
from mapie.estimator.classifier import EnsembleClassifier
Expand Down Expand Up @@ -99,7 +99,7 @@ class SplitCPClassifier(SplitCP):
Examples
--------
>>> import numpy as np
>>> from mapie.futur import SplitCPClassifier
>>> from mapie.futur.split import SplitCPClassifier
>>> np.random.seed(1)
>>> X_train = np.arange(0,400,2).reshape(-1, 1)
>>> y_train = np.array([0]*50 + [1]*50 + [2]*50 + [3]*50)
Expand All @@ -108,11 +108,11 @@ class SplitCPClassifier(SplitCP):
>>> y_pred, y_pis = mapie_reg.predict(X_train)
>>> print(np.round(y_pred[[0, 40, 80, 120]], 2))
[0 0 1 2]
>>> print(np.round(y_pis[[0, 40, 80, 120], :, 0], 2))
[[1. 1. 1. 1.]
[1. 0. 0. 0.]
[0. 1. 0. 0.]
[0. 0. 1. 0.]]
>>> print(y_pis[[0, 40, 80, 120], :, 0])
[[ True True True True]
[ True False False False]
[False True False False]
[False False True False]]
"""
def __init__(
self,
Expand Down Expand Up @@ -205,9 +205,7 @@ def _check_estimator_classification(
check_is_fitted(est)
if not hasattr(est, "classes_"):
raise AttributeError(
"Invalid classifier. "
"Fitted classifier does not contain "
"'classes_' attribute."
"Fitted classifier must contain 'classes_' attribute."
)
return est

Expand All @@ -221,22 +219,6 @@ def _check_fit_parameters(self) -> ClassifierMixin:
self.cv)
return predictor

def _check_calib_conformity_score(
self, conformity_score: Optional[BaseClassificationScore], sym: bool
):
if not sym:
raise ValueError("`sym` argument should be set to `True`"
"in classification")
if conformity_score is None:
return LACConformityScore()
elif isinstance(conformity_score, BaseClassificationScore):
return conformity_score
else:
raise ValueError(
"Invalid conformity_score argument.\n"
"Must be None or a BaseClassificationScore instance."
)

def _check_calibrate_parameters(self) -> Tuple[
BaseClassificationScore, BaseCalibrator
]:
Expand Down
24 changes: 24 additions & 0 deletions mapie/tests/test_futur_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,27 @@ def test_check_conformity_scores_error() -> None:
mapie = SplitCPClassifier()
with pytest.raises(ValueError, match="Invalid conformity scores."):
mapie._check_conformity_scores(np.random.rand(200, 5))


def test_invalid_classifier():
"""
Fitted classifier must contain the ``classes_`` attribute
"""
class Custom(ClassifierMixin):
def __init__(self) -> None:
self.fitted_ = True

def fit():
pass

def predict():
pass

def predict_proba():
pass

invalid_cls = Custom()
mapie = SplitCPClassifier(invalid_cls, cv="prefit", alpha=0.1)
with pytest.raises(AttributeError,
match="Fitted classifier must contain 'classes_' attr"):
mapie.fit(X, y)

0 comments on commit 6d94658

Please sign in to comment.