Skip to content

Commit

Permalink
Merge pull request #167 from decargroup/bugfix/166-pykoop-is-slow-to-…
Browse files Browse the repository at this point in the history
…import

`pykoop` is slow to import
  • Loading branch information
sdahdah authored Jan 12, 2024
2 parents 02ccb15 + c8da7c4 commit 25c42a8
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
31 changes: 17 additions & 14 deletions pykoop/koopman_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas
import sklearn.base
import sklearn.exceptions
import sklearn.metrics
Expand Down Expand Up @@ -3834,14 +3833,13 @@ def _weights_from_data_matrix(
return weights


def _extract_feature_names(
X: Union[np.ndarray, pandas.DataFrame]) -> Optional[np.ndarray]:
def _extract_feature_names(X: Any) -> Optional[np.ndarray]:
"""Extract feature names from input array.
Parameters
----------
X : Union[np.ndarray, pandas.DataFrame]
Input array.
X : Any
Input array, either ``np.ndarray`` or ``pandas.DataFrame``.
Returns
-------
Expand All @@ -3853,13 +3851,18 @@ def _extract_feature_names(
ValueError
If feature names are not strings.
"""
if isinstance(X, pandas.DataFrame):
for name in X.columns:
if not isinstance(name, str):
log.warning(
'Feature names must all be strings. When ``scikit-learn`` '
'v1.2 comes out this will be upgraded to an exception.')
return None
return np.asarray(X.columns, dtype=object)
else:
if isinstance(X, np.ndarray):
return None
else:
try:
for name in X.columns:
if not isinstance(name, str):
log.warning('Feature names must all be strings. When '
'``scikit-learn`` v1.2 comes out this will be '
'upgraded to an exception.')
return None
return np.asarray(X.columns, dtype=object)
except AttributeError:
return None
except TypeError:
return None
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Requirements to install ``pykoop``, run all unit tests, and generate docs
-r requirements.txt

pandas
pytest
pytest-regressions
sphinx
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ numpy>=1.21.0
scipy>=1.7.0
scikit-learn>=1.0.0
PICOS>=2.4.0
pandas>=1.3.1
optht>=0.2.0
Deprecated>=1.2.13
matplotlib>=3.5.1
Expand Down

0 comments on commit 25c42a8

Please sign in to comment.