Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Fix doc link in HTML representation of estimators #2131

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sklearnex/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import sys
import warnings
from abc import ABC

from daal4py.sklearn._utils import (
PatchingConditionsChain as daal4py_PatchingConditionsChain,
Expand Down Expand Up @@ -113,3 +114,19 @@ def get_hyperparameters(self, op):
return estimator_class

return wrap_class


# This abstract class is meant to generate a clickable doc link for classses
# in sklearnex that are not present in scikit-learn. It should be inherited
# before inheriting from a scikit-learn estimator, otherwise will get overriden
# by the estimator's original.
class NonSKLearnAlgorithm(ABC):
@property
def _doc_link_module(self) -> str:
return "sklearnex"

@property
def _doc_link_template(self) -> str:
module_path = ".".join(self.__class__.__module__.split(".")[:-1])
class_name = self.__class__.__name__
return f"https://intel.github.io/scikit-learn-intelex/latest/non-scikit-algorithms.html#{module_path}.{class_name}"
4 changes: 2 additions & 2 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics

from .._device_offload import dispatch
from .._utils import PatchingConditionsChain
from .._utils import NonSKLearnAlgorithm, PatchingConditionsChain

if sklearn_check_version("1.6"):
from sklearn.utils.validation import validate_data
Expand All @@ -38,7 +38,7 @@


@control_n_jobs(decorated_methods=["fit"])
class BasicStatistics(BaseEstimator):
class BasicStatistics(NonSKLearnAlgorithm, BaseEstimator):
"""
Estimator for basic statistics.
Allows to compute basic statistics for provided data.
Expand Down
5 changes: 3 additions & 2 deletions sklearnex/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)

from .._device_offload import dispatch
from .._utils import PatchingConditionsChain
from .._utils import NonSKLearnAlgorithm, PatchingConditionsChain

if sklearn_check_version("1.2"):
from sklearn.utils._param_validation import Interval, StrOptions
Expand All @@ -41,7 +41,8 @@


@control_n_jobs(decorated_methods=["partial_fit", "_onedal_finalize_fit"])
class IncrementalBasicStatistics(BaseEstimator):
# class IncrementalBasicStatistics(NonSKLearnAlgorithm, BaseEstimator):
class IncrementalBasicStatistics(NonSKLearnAlgorithm, BaseEstimator):
"""
Calculates basic statistics on the given data, allows for computation when the data are split into
batches. The user can use ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide
Expand Down
8 changes: 6 additions & 2 deletions sklearnex/covariance/incremental_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
from sklearnex import config_context

from .._device_offload import dispatch, wrap_output_data
from .._utils import PatchingConditionsChain, register_hyperparameters
from .._utils import (
NonSKLearnAlgorithm,
PatchingConditionsChain,
register_hyperparameters,
)
from ..metrics import pairwise_distances
from ..utils._array_api import get_namespace

Expand All @@ -47,7 +51,7 @@


@control_n_jobs(decorated_methods=["partial_fit", "fit", "_onedal_finalize_fit"])
class IncrementalEmpiricalCovariance(BaseEstimator):
class IncrementalEmpiricalCovariance(NonSKLearnAlgorithm, BaseEstimator):
"""
Maximum likelihood covariance estimator that allows for the estimation when the data are split into
batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide
Expand Down
10 changes: 8 additions & 2 deletions sklearnex/linear_model/incremental_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from onedal.common.hyperparameters import get_hyperparameters

from .._device_offload import dispatch, wrap_output_data
from .._utils import PatchingConditionsChain, register_hyperparameters
from .._utils import (
NonSKLearnAlgorithm,
PatchingConditionsChain,
register_hyperparameters,
)


@register_hyperparameters(
Expand All @@ -52,7 +56,9 @@
@control_n_jobs(
decorated_methods=["fit", "partial_fit", "predict", "score", "_onedal_finalize_fit"]
)
class IncrementalLinearRegression(MultiOutputMixin, RegressorMixin, BaseEstimator):
class IncrementalLinearRegression(
NonSKLearnAlgorithm, MultiOutputMixin, RegressorMixin, BaseEstimator
):
"""
Trains a linear regression model, allows for computation if the data are split into
batches. The user can use the ``partial_fit`` method to provide a single batch of data or use the ``fit`` method to provide
Expand Down
Loading