From 3b845147aabc155a3e082a873d286035c11b9ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Mon, 7 Oct 2024 11:27:10 +0200 Subject: [PATCH 01/10] update README; conftest for rst; workflow build docs; logging --- .coveragerc | 2 - .github/workflows/docs.yml | 30 ++ README.rst | 5 +- docs/Makefile | 3 + docs/conf.py | 5 + examples/RPCA.md | 270 ------------------ examples/benchmark.md | 2 +- examples/tutorials/plot_tuto_benchmark_TS.py | 4 +- .../tutorials/plot_tuto_diffusion_models.py | 22 +- examples/tutorials/plot_tuto_mean_median.py | 2 +- poetry.lock | 12 +- pyproject.toml | 1 + pytest.ini | 4 +- qolmat/benchmark/comparator.py | 15 +- qolmat/benchmark/metrics.py | 12 +- qolmat/imputations/diffusions/ddpms.py | 13 +- qolmat/imputations/em_sampler.py | 10 +- qolmat/imputations/imputers_pytorch.py | 12 +- qolmat/imputations/rpca/rpca_noisy.py | 8 +- qolmat/imputations/softimpute.py | 11 +- qolmat/utils/data.py | 3 + tests/conftest.py | 128 +++++++++ tests/imputations/test_em_sampler.py | 3 - 23 files changed, 257 insertions(+), 320 deletions(-) delete mode 100644 .coveragerc create mode 100644 .github/workflows/docs.yml delete mode 100644 examples/RPCA.md create mode 100644 tests/conftest.py diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 6d420485..00000000 --- a/.coveragerc +++ /dev/null @@ -1,2 +0,0 @@ -[run] -omit = qolmat/_version.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..a85a3783 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,30 @@ +name: Build Docs + +on: + pull_request: + branches: + - main + - dev + +jobs: + build-docs: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Poetry + uses: snok/install-poetry@v1 + with: + version: 1.8.3 + - name: Lock + run: poetry lock --no-update + - name: Install + run: poetry install + - name: Build Docs + run: | + poetry run sphinx-build -b html docs/ _build/html diff --git a/README.rst b/README.rst index a9980de1..1f292ebf 100644 --- a/README.rst +++ b/README.rst @@ -70,14 +70,13 @@ With just these few lines of code, you can see how easy it is to from qolmat.utils import data # load and prepare csv data - df_data = data.get_data("Beijing") columns = ["TEMP", "PRES", "WSPM"] df_data = df_data[columns] df_with_nan = data.add_holes(df_data, ratio_masked=0.2, mean_size=120) # impute and compare - imputer_mean = imputers.ImputerMean(groups=("station",)) + imputer_mean = imputers.ImputerSimple(strategy="mean", groups=("station",)) imputer_interpol = imputers.ImputerInterpolation(method="linear", groups=("station",)) imputer_var1 = imputers.ImputerEM(model="VAR", groups=("station",), method="mle", max_iter_em=50, n_iter_ou=15, dt=1e-3, p=1) dict_imputers = { @@ -90,7 +89,7 @@ With just these few lines of code, you can see how easy it is to dict_imputers, columns, generator_holes = generator_holes, - metrics = ["mae", "wmape", "KL_columnwise", "ks_test", "energy"], + metrics = ["mae", "wmape", "kl_columnwise", "ks_test", "energy"], ) results = comparison.compare(df_with_nan) results.style.highlight_min(color="lightsteelblue", axis=1) diff --git a/docs/Makefile b/docs/Makefile index 629abfa4..fc5a3657 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -37,6 +37,9 @@ clean: -rm -rf examples/tutorials/* -rm -rf generated/* +doctest: + $(SPHINXBUILD) -b doctest . _build/doctest + html: # These two lines make the build a bit more lengthy, and the # the embedding of images more robust diff --git a/docs/conf.py b/docs/conf.py index 2429e591..4d4517dd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -159,6 +159,11 @@ "reference_url": {"qolmat": None}, } +suppress_warnings = ["autosectionlabel.*"] + +# doctest configuration +doctest_test_doctest_blocks = "default" + html_css_files = [ "custom.css", ] diff --git a/examples/RPCA.md b/examples/RPCA.md deleted file mode 100644 index 0a4fbe8e..00000000 --- a/examples/RPCA.md +++ /dev/null @@ -1,270 +0,0 @@ ---- -jupyter: - jupytext: - formats: ipynb,md - text_representation: - extension: .md - format_name: markdown - format_version: '1.3' - jupytext_version: 1.14.4 - kernelspec: - display_name: env_qolmat_dev - language: python - name: env_qolmat_dev ---- - -```python tags=[] -%reload_ext autoreload -%autoreload 2 - -import numpy as np -# import timesynth as ts # package for generating time series - -import matplotlib.pyplot as plt - -import sys - -from math import pi - -from qolmat.utils import utils, plot, data -from qolmat.imputations.rpca.rpca_pcp import RpcaPcp -from qolmat.imputations.rpca.rpca_noisy import RpcaNoisy -from qolmat.imputations.softimpute import SoftImpute -from qolmat.imputations.rpca import rpca_utils -from qolmat.utils.data import generate_artificial_ts -``` - -```python -from qolmat.imputations.imputers import ImputerRpcaNoisy, ImputerRpcaPcp -``` - -**Generate synthetic data** - -```python tags=[] -n_samples = 10000 -periods = [100, 20] -amp_anomalies = 0.5 -ratio_anomalies = 0.05 -amp_noise = 0.1 - -X_true, A_true, E_true = generate_artificial_ts(n_samples, periods, amp_anomalies, ratio_anomalies, amp_noise) - -signal = X_true + A_true + E_true -signal = 10 + signal * 40 - -# Adding missing data -signal[120:180] = np.nan -signal[:20] = np.nan -for i in range(10): - signal[i::365] = np.nan -# signal[80:220] = np.nan -# mask = np.random.choice(len(signal), round(len(signal) / 20)) -# signal[mask] = np.nan - -``` - -```python -import pandas as pd -df = pd.DataFrame({"signal": signal}) -irn = ImputerRpcaPcp(period=100) -df_imp = irn.fit_transform(df) -``` - -```python -plt.plot(df_imp["signal"]) -plt.plot(df["signal"]) - -plt.xlim(0, 200) -``` - -```python tags=[] -fig = plt.figure(figsize=(15, 8)) -ax = fig.add_subplot(4, 1, 1) -ax.title.set_text("Low-rank signal") -plt.plot(X_true) - -ax = fig.add_subplot(4, 1, 2) -ax.title.set_text("Corruption signal") -plt.plot(A_true) - -ax = fig.add_subplot(4, 1, 3) -ax.title.set_text("Noise") -plt.plot(E_true) - -ax = fig.add_subplot(4, 1, 4) -ax.title.set_text("Corrupted signal") -plt.plot(signal) - -plt.show() -``` - - -# Fit RPCA Noisy - - -```python tags=[] -rpca_noisy = RpcaNoisy(tau=1, lam=.4, rank=1, norm="L2") -``` - -```python tags=[] -period = 100 -D = utils.prepare_data(signal, period) -Omega = ~np.isnan(D) -D = utils.linear_interpolation(D) -``` - -```python tags=[] -M, A, L, Q = rpca_noisy.decompose_with_basis(D, Omega) -M2, A2 = rpca_noisy.decompose_on_basis(D, Omega, Q) -``` - -```python tags=[] -M_final = utils.get_shape_original(M, signal.shape) -A_final = utils.get_shape_original(A, signal.shape) -D_final = utils.get_shape_original(D, signal.shape) -signal_imputed = M_final + A_final -``` - -```python tags=[] -fig = plt.figure(figsize=(12, 4)) - -plt.plot(signal_imputed, label="Imputed signal with anomalies") -plt.plot(M_final, label="Imputed signal without anomalies") -plt.plot(A_final, label="Anomalies") -# plt.plot(D_final, label="D") -plt.plot(signal, color="black", label="Original signal") -plt.xlim(0, 400) -plt.legend() -plt.show() -``` - -## PCP RPCA - -```python tags=[] -rpca_pcp = RpcaPcp(max_iterations=1000, lam=.1) -``` - -```python tags=[] -period = 100 -D = utils.prepare_data(signal, period) -Omega = ~np.isnan(D) -D = utils.linear_interpolation(D) -``` - -```python tags=[] -M, A = rpca_pcp.decompose(D, Omega) -``` - -```python tags=[] -M_final = utils.get_shape_original(M, signal.shape) -A_final = utils.get_shape_original(A, signal.shape) -D_final = utils.get_shape_original(D, signal.shape) -# Y_final = utils.get_shape_original(Y, signal.shape) -signal_imputed = M_final + A_final -``` - -```python tags=[] -fig = plt.figure(figsize=(12, 4)) - -plt.plot(signal_imputed, label="Imputed signal with anomalies") -plt.plot(M_final, label="Imputed signal without anomalies") -plt.plot(A_final, label="Anomalies") - -plt.plot(signal, color="black", label="Original signal") -plt.xlim(0, 400) -# plt.gca().twinx() -# plt.plot(Y_final, label="Y") -plt.legend() -plt.show() -``` - -## Soft Impute - -```python tags=[] -imputer = SoftImpute(max_iterations=1000, tau=.1) -``` - -```python tags=[] -period = 100 -D = utils.prepare_data(signal, period) -Omega = ~np.isnan(D) -D = utils.linear_interpolation(D) -``` - -```python tags=[] -M, A = imputer.decompose(D, Omega) -``` - -```python tags=[] -M_final = utils.get_shape_original(M, signal.shape) -A_final = utils.get_shape_original(A, signal.shape) -D_final = utils.get_shape_original(D, signal.shape) -# Y_final = utils.get_shape_original(Y, signal.shape) -signal_imputed = M_final + A_final -``` - -```python tags=[] -fig = plt.figure(figsize=(12, 4)) - -plt.plot(signal_imputed, label="Imputed signal with anomalies") -plt.plot(M_final, label="Imputed signal without anomalies") -plt.plot(A_final, label="Anomalies") - -plt.plot(signal, color="black", label="Original signal") -plt.xlim(0, 400) -plt.legend() -plt.show() -``` - -## Temporal RPCA - -```python -%%time -rpca_noisy = RpcaNoisy(tau=1, lam=0.4, rank=2, norm="L2") -M, A = rpca_noisy.decompose(D, Omega) -# imputed = X -``` - -```python tags=[] -fig = plt.figure(figsize=(12, 4)) - -plt.plot(signal_imputed, label="Imputed signal with anomalies") -plt.plot(M_final, label="Imputed signal without anomalies") -plt.plot(A_final, label="Anomalies") - -plt.plot(signal, color="black", label="Original signal") -plt.xlim(0, 400) -# plt.gca().twinx() -# plt.plot(Y_final, label="Y") -plt.legend() -plt.show() -``` - -# EM VAR(p) - -```python -from qolmat.imputations import em_sampler -``` - -```python -p = 1 -model = em_sampler.VARpEM(method="mle", max_iter_em=10, n_iter_ou=512, dt=1e-1, p=p) -``` - -```python -D = signal.reshape(-1, 1) -M_final = model.fit_transform(D) -``` - -```python -fig = plt.figure(figsize=(12, 4)) -plt.plot(signal_imputed, label="Imputed signal with anomalies") -plt.plot(M_final, label="Imputed signal without anomalies") -plt.xlim(0, 400) -plt.legend() -plt.show() -``` - -```python - -``` diff --git a/examples/benchmark.md b/examples/benchmark.md index e22b8991..b131bb00 100644 --- a/examples/benchmark.md +++ b/examples/benchmark.md @@ -185,7 +185,7 @@ Concretely, the comparator takes as input a dataframe to impute, a proportion of Note these metrics compute reconstruction errors; it tells nothing about the distances between the "true" and "imputed" distributions. ```python tags=[] -metrics = ["mae", "wmape", "KL_columnwise", "frechet"] +metrics = ["mae", "wmape", "kl_columnwise", "frechet"] comparison = comparator.Comparator( dict_imputers, cols_to_impute, diff --git a/examples/tutorials/plot_tuto_benchmark_TS.py b/examples/tutorials/plot_tuto_benchmark_TS.py index a92398cb..da3600cd 100644 --- a/examples/tutorials/plot_tuto_benchmark_TS.py +++ b/examples/tutorials/plot_tuto_benchmark_TS.py @@ -124,7 +124,7 @@ dict_imputers, cols_to_impute, generator_holes=generator_holes, - metrics=["mae", "wmape", "KL_columnwise", "wasserstein_columnwise"], + metrics=["mae", "wmape", "kl_columnwise", "wasserstein_columnwise"], max_evals=10, ) results = comparison.compare(df) @@ -133,7 +133,7 @@ # %% # We have considered four metrics for comparison. # ``mae`` and ``wmape`` are point-wise metrics, -# while ``KL_columnwise`` and ``wasserstein_columnwise`` are metrics +# while ``kl_columnwise`` and ``wasserstein_columnwise`` are metrics # that compare distributions. # Since we treat time series with strong seasonal patterns, imputation # on residuals works very well. diff --git a/examples/tutorials/plot_tuto_diffusion_models.py b/examples/tutorials/plot_tuto_diffusion_models.py index 8275f3bb..6715aae1 100644 --- a/examples/tutorials/plot_tuto_diffusion_models.py +++ b/examples/tutorials/plot_tuto_diffusion_models.py @@ -5,7 +5,7 @@ In this tutorial, we show how to use :class:`~qolmat.imputations.diffusions.ddpms.TabDDPM` and :class:`~qolmat.imputations.diffusions.ddpms.TsDDPM` classes. """ - +import logging import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -15,6 +15,12 @@ from qolmat.imputations.imputers_pytorch import ImputerDiffusion from qolmat.utils import data +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + # %% # 1. Time-series data # --------------------------------------------------------------- @@ -31,8 +37,7 @@ [df_data.index.levels[0], pd.to_datetime(df_data.index.levels[1])] ) -print("Number of nan at each column:") -print(df_data.isna().sum()) +logging.info(f"Number of nan at each column: {df_data.isna().sum()}") # %% # 2. Hyperparameters for the wapper ImputerDiffusion @@ -77,14 +82,15 @@ # %% # We can see the architecture of the TabDDPM with ``get_summary_architecture()`` -print(tabddpm.get_summary_architecture()) +logging.info(tabddpm.get_summary_architecture()) # %% # We also get the summary of the training progress with ``get_summary_training()`` summary = tabddpm.get_summary_training() -print(f"Performance metrics: {list(summary.keys())}") +logging.info(f"Performance metrics: {list(summary.keys())}") + metric = "mean_absolute_error" metric_scores = summary[metric] @@ -151,7 +157,7 @@ # * ``dim_embedding``: dimension of hidden layers in residual blocks (``int = 128``) # # Let see an example below. We can observe that a large ``num_sampling`` generally improves -# reconstruction errors (mae) but increases distribution distance (KL_columnwise). +# reconstruction errors (mae) but increases distribution distance (kl_columnwise). dict_imputers = { "num_sampling=5": ImputerDiffusion( @@ -166,7 +172,7 @@ dict_imputers, selected_columns=df_data.columns, generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2), - metrics=["mae", "KL_columnwise"], + metrics=["mae", "kl_columnwise"], ) results = comparison.compare(df_data) @@ -220,7 +226,7 @@ dict_imputers, selected_columns=df_data.columns, generator_holes=missing_patterns.UniformHoleGenerator(n_splits=2), - metrics=["mae", "KL_columnwise"], + metrics=["mae", "kl_columnwise"], ) results = comparison.compare(df_data) diff --git a/examples/tutorials/plot_tuto_mean_median.py b/examples/tutorials/plot_tuto_mean_median.py index 33c36db2..5d59bd6d 100644 --- a/examples/tutorials/plot_tuto_mean_median.py +++ b/examples/tutorials/plot_tuto_mean_median.py @@ -77,7 +77,7 @@ imputer_median = imputers.ImputerSimple(strategy="median") dict_imputers = {"mean": imputer_mean, "median": imputer_median} -metrics = ["mae", "wmape", "KL_columnwise"] +metrics = ["mae", "wmape", "kl_columnwise"] # %% # Concretely, the comparator takes as input a dataframe to impute, diff --git a/poetry.lock b/poetry.lock index 32d2d7e9..efa6e0cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1144,13 +1144,13 @@ test = ["flufl.flake8", "importlib-resources (>=1.3)", "jaraco.test (>=5.4)", "p [[package]] name = "importlib-resources" -version = "6.4.4" +version = "6.4.5" description = "Read resources from Python packages" optional = false python-versions = ">=3.8" files = [ - {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, - {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, + {file = "importlib_resources-6.4.5-py3-none-any.whl", hash = "sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717"}, + {file = "importlib_resources-6.4.5.tar.gz", hash = "sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065"}, ] [package.dependencies] @@ -1916,7 +1916,7 @@ files = [ name = "markdown" version = "3.7" description = "Python implementation of John Gruber's Markdown." -optional = true +optional = false python-versions = ">=3.8" files = [ {file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"}, @@ -4014,7 +4014,7 @@ sphinx = ">=1.8.3" name = "sphinx-markdown-tables" version = "0.0.17" description = "A Sphinx extension for rendering tables written in markdown" -optional = true +optional = false python-versions = "*" files = [ {file = "sphinx-markdown-tables-0.0.17.tar.gz", hash = "sha256:6bc6d3d400eaccfeebd288446bc08dd83083367c58b85d40fe6c12d77ef592f1"}, @@ -4697,4 +4697,4 @@ tests = ["typed-ast"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "e1b2b86ae087b58e7334855fd8a3f74f331e7046f949135ff591bf2851f28fce" +content-hash = "009b72db9c810af43afc035cfbb380f122dbf69a37700ae017a06ed46cc36e4e" diff --git a/pyproject.toml b/pyproject.toml index 3c2bc7aa..fa42eaf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,7 @@ numpydoc = "1.1.0" sphinx = "4.3.2" sphinx-gallery = "0.10.1" sphinx_rtd_theme = "1.0.0" +sphinx_markdown_tables = "0.0.17" [tool.poetry.extras] tests = ["typed-ast"] diff --git a/pytest.ini b/pytest.ini index 2e3719d8..cbf59873 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,4 @@ [pytest] -addopts = --cov=qolmat +addopts = --cov=qolmat --doctest-glob="*.rst" --doctest-modules +testpaths = tests +norecursedirs = _build diff --git a/qolmat/benchmark/comparator.py b/qolmat/benchmark/comparator.py index 4fed2e9e..413e944b 100644 --- a/qolmat/benchmark/comparator.py +++ b/qolmat/benchmark/comparator.py @@ -1,5 +1,6 @@ """Script for comparator.""" +import logging from typing import Any, Dict, List, Optional import numpy as np @@ -8,6 +9,12 @@ from qolmat.benchmark import hyperparameters, metrics from qolmat.benchmark.missing_patterns import _HoleGenerator +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + class Comparator: """Comparator class. @@ -39,7 +46,7 @@ def __init__( dict_models: Dict[str, Any], selected_columns: List[str], generator_holes: _HoleGenerator, - metrics: List = ["mae", "wmape", "KL_columnwise"], + metrics: List = ["mae", "wmape", "kl_columnwise"], dict_config_opti: Optional[Dict[str, Any]] = {}, metric_optim: str = "mse", max_evals: int = 10, @@ -167,13 +174,13 @@ def compare( dict_config_opti_imputer = self.dict_config_opti.get(name, {}) try: - print(f"Testing model: {name}...", end="") + logging.info(f"Testing model: {name}...") dict_errors[name] = self.evaluate_errors_sample( imputer, df, dict_config_opti_imputer, self.metric_optim ) - print("done.") + logging.info("done.") except Exception as excp: - print( + logging.info( f"Error while testing {name} of type " f"{type(imputer).__name__}!" ) diff --git a/qolmat/benchmark/metrics.py b/qolmat/benchmark/metrics.py index b8af8667..ef499f32 100644 --- a/qolmat/benchmark/metrics.py +++ b/qolmat/benchmark/metrics.py @@ -876,10 +876,10 @@ def frechet_distance( """Compute Frechet distance computed using a pattern decomposition. Several variant are implemented: - - the `single` method relies on a single estimation of the means and + i) the `single` method relies on a single estimation of the means and covariance matrix. It is relevent for MCAR data. - - the `pattern`method relies on the aggregation of the estimated distance - between each pattern. It is relevent for MAR data. + ii) the `pattern` method relies on the aggregation of the estimated + distance between each pattern. It is relevent for MAR data. Parameters ---------- @@ -1200,9 +1200,9 @@ def get_metric( "wmape": weighted_mean_absolute_percentage_error, "accuracy": accuracy, "wasserstein_columnwise": dist_wasserstein, - "KL_columnwise": partial(kl_divergence, method="columnwise"), - "KL_gaussian": partial(kl_divergence, method="gaussian"), - "KS_test": kolmogorov_smirnov_test, + "kl_columnwise": partial(kl_divergence, method="columnwise"), + "kl_gaussian": partial(kl_divergence, method="gaussian"), + "ks_test": kolmogorov_smirnov_test, "correlation_diff": ( mean_difference_correlation_matrix_numerical_features ), diff --git a/qolmat/imputations/diffusions/ddpms.py b/qolmat/imputations/diffusions/ddpms.py index 4f8728e9..3063b9a9 100644 --- a/qolmat/imputations/diffusions/ddpms.py +++ b/qolmat/imputations/diffusions/ddpms.py @@ -1,5 +1,6 @@ """Script for DDPM classes.""" +import logging import time from datetime import timedelta from typing import Callable, Dict, List, Tuple, Union @@ -21,6 +22,12 @@ ) from qolmat.imputations.diffusions.utils import get_num_params +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + class TabDDPM: """Tab DDPM. @@ -184,7 +191,7 @@ def _print_valid(self, epoch: int, time_duration: float) -> None: self.time_durations.append(time_duration) print_step = 1 if int(self.epochs / 10) == 0 else int(self.epochs / 10) if self.print_valid and epoch == 0: - print( + logging.info( f"Num params of {self.__class__.__name__}: {self.num_params}" ) if self.print_valid and epoch % print_step == 0: @@ -200,7 +207,7 @@ def _print_valid(self, epoch: int, time_duration: float) -> None: string_valid += ( f" | remaining {timedelta(seconds=remaining_duration)}" ) - print(string_valid) + logging.info(string_valid) def _impute(self, x: np.ndarray, x_mask_obs: np.ndarray) -> np.ndarray: """Impute data array. @@ -763,7 +770,7 @@ def _process_data( if is_training: if self.is_rolling: if self.print_valid: - print( + logging.info( "Preprocessing data with sliding window " "(pandas.DataFrame.rolling) " "can require more times than usual. " diff --git a/qolmat/imputations/em_sampler.py b/qolmat/imputations/em_sampler.py index eba85062..337a1507 100644 --- a/qolmat/imputations/em_sampler.py +++ b/qolmat/imputations/em_sampler.py @@ -1,5 +1,6 @@ """Script for EM imputation.""" +import logging import warnings from abc import abstractmethod from typing import Dict, List, Literal, Tuple, Union @@ -11,9 +12,14 @@ from sklearn import utils as sku from sklearn.base import BaseEstimator, TransformerMixin -# from typing_extensions import Self from qolmat.utils import utils +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + def _conjugate_gradient(A: NDArray, X: NDArray, mask: NDArray) -> NDArray: """Compute conjugate gradient. @@ -436,7 +442,7 @@ def fit_X(self, X: NDArray) -> None: self.update_criteria_stop(X) if self._check_convergence(): if self.verbose: - print(f"EM converged after {iter_em} iterations.") + logging.info(f"EM converged after {iter_em} iterations.") break self.dict_criteria_stop = {key: [] for key in self.dict_criteria_stop} diff --git a/qolmat/imputations/imputers_pytorch.py b/qolmat/imputations/imputers_pytorch.py index aff2b32f..3b8b0ec7 100644 --- a/qolmat/imputations/imputers_pytorch.py +++ b/qolmat/imputations/imputers_pytorch.py @@ -1,5 +1,6 @@ """Script for pytroch imputers.""" +import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -24,6 +25,13 @@ raise PyTorchExtraNotInstalled +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + + class ImputerRegressorPyTorch(ImputerRegressor): """Imputer regressor based on PyTorch. @@ -109,7 +117,7 @@ def _fit_estimator( loss.backward() optimizer.step() if (epoch + 1) % 10 == 0: - print( + logging.info( f"Epoch [{epoch + 1}/{self.epochs}], " f"Loss: {loss.item():.4f}" ) @@ -231,7 +239,7 @@ def fit(self, X: NDArray, y: NDArray) -> "Autoencoder": loss.backward() optimizer.step() if (epoch + 1) % 10 == 0: - print( + logging.info( f"Epoch [{epoch + 1}/{self.epochs}], " f"Loss: {loss.item():.4f}" ) diff --git a/qolmat/imputations/rpca/rpca_noisy.py b/qolmat/imputations/rpca/rpca_noisy.py index ae59ae0a..5a87958b 100644 --- a/qolmat/imputations/rpca/rpca_noisy.py +++ b/qolmat/imputations/rpca/rpca_noisy.py @@ -260,13 +260,13 @@ def minimise_loss( Tuple A tuple containing the following elements: - M : np.ndarray - Low-rank signal matrix of shape (m, n). + Low-rank signal matrix of shape (m, n). - A : np.ndarray - Anomalies matrix of shape (m, n). + Anomalies matrix of shape (m, n). - L : np.ndarray - Basis unitary array of shape (m, rank). + Basis unitary array of shape (m, rank). - Q : np.ndarray - Basis unitary array of shape (rank, n). + Basis unitary array of shape (rank, n). Raises ------ diff --git a/qolmat/imputations/softimpute.py b/qolmat/imputations/softimpute.py index 72d3a8c4..62de3d26 100644 --- a/qolmat/imputations/softimpute.py +++ b/qolmat/imputations/softimpute.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging import warnings from typing import Optional, Tuple, Union @@ -13,6 +14,12 @@ from qolmat.imputations.rpca import rpca_utils from qolmat.utils import utils +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + class SoftImpute(BaseEstimator, TransformerMixin): """Class for the Rank Restricted Soft SVD algorithm. @@ -170,9 +177,9 @@ def decompose(self, X: NDArray, Omega: NDArray) -> Tuple[NDArray, NDArray]: # Step 4 : Stopping upon convergence ratio = SoftImpute._check_convergence(U_old, D_old, V_old, U, D, V) if self.verbose: - print(f"Iteration {iter_}: ratio = {round(ratio, 4)}") + logging.info(f"Iteration {iter_}: ratio = {round(ratio, 4)}") if ratio < self.tolerance: - print( + logging.info( f"Convergence reached at iteration {iter_} " f"with ratio = {round(ratio, 4)}" ) diff --git a/qolmat/utils/data.py b/qolmat/utils/data.py index eeef323a..a96aef32 100644 --- a/qolmat/utils/data.py +++ b/qolmat/utils/data.py @@ -250,6 +250,9 @@ def get_data( ) df = pd.read_csv(csv_url, index_col=0) return df + elif name_data == "conductor": + df = read_csv_local("conductors") + return df elif name_data == "Monach_weather": urllink = os.path.join( url_zenodo, "4654822/files/weather_dataset.zip?download=1" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ce7a935b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,128 @@ +"""Conftest for pytest-rst.""" + +import ast +import logging +from pathlib import Path +from typing import Any, List + +import docutils.core # type: ignore +import docutils.nodes # type: ignore +import pytest + +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + + +RST_FILES_TO_TEST = [ + Path(__file__).parent.parent / "README.rst", +] + + +def extract_python_blocks(content) -> List[str]: + """Extract all Python code blocks from the RST content. + + This function parses the provided RST content and extracts all the + Python code blocks marked by `.. code-block:: python` + or `.. code:: python`. The extracted code blocks are returned as a + list of strings. For isntance, given an RST content with a + Python code block: + .. code-block:: python + + print("Hello, world!") + + This function would return: ["print('Hello, world!')"] + + Parameters + ---------- + content : str + The reStructuredText (RST) content to be parsed and searched for + Python code blocks. + + Returns + ------- + List[str] + A list of strings, where each string is a Python code block extracted + from the RST content. + + """ + document = docutils.core.publish_doctree(content) + code_blocks = [ + node.astext() + for node in document.traverse(docutils.nodes.literal_block) + if "python" in node.get("classes", []) + ] + return code_blocks + + +@pytest.hookimpl(tryfirst=True) +def pytest_sessionstart(session: Any) -> None: + """Run tests (hook) on specified .rst files at pytest session start. + + This function reads through a list of predefined .rst files, + extracts Python code blocks, and ensures that the code is syntactically + valid and that all necessary imports work correctly. + The function will scan each file listed in `RST_FILES_TO_TEST`. + + This function is invoked automatically by pytest when the session starts. + No manual invocation is needed. + + Parameters + ---------- + session : Any + The pytest session object. This hook is automatically called by pytest + at the start of the session. It is not used in this function but is + required by the pytest hook mechanism. + + Raises + ------ + pytest.fail + Raised if there is a syntax error in any of the Python code blocks, + or if there is an import failure for any of the modules used + in the code blocks. + + """ + for rst_file in RST_FILES_TO_TEST: + if rst_file.exists(): + logging.info(f"Testing Python code in {rst_file}.") + with open(rst_file) as f: + content = f.read() + code_blocks = extract_python_blocks(content) + + for i, code_block in enumerate(code_blocks): + try: + tree = ast.parse(code_block) + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for name in node.names: + try: + __import__(name.name) + except ImportError as e: + pytest.fail( + f"Cannot import {name.name} in " + f"{rst_file}: {str(e)}" + ) + elif isinstance(node, ast.ImportFrom): + if node.module: + try: + __import__(node.module) # noqa + except ImportError as e: + pytest.fail( + f"Cannot import {node.module} in " + f"{rst_file}: {str(e)}." + ) + else: + pytest.fail( + f"Module name is None in {rst_file} " + "in ImportFrom statement." + ) + except SyntaxError as e: + pytest.fail( + "Invalid Python syntax in code block " + f"{i + 1} in {rst_file}: " + f"\n{code_block}\nError: {str(e)}" + ) + else: + logging.info(f"File {rst_file} does not exist, skippin...") diff --git a/tests/imputations/test_em_sampler.py b/tests/imputations/test_em_sampler.py index 21e2ffd0..188bb692 100644 --- a/tests/imputations/test_em_sampler.py +++ b/tests/imputations/test_em_sampler.py @@ -242,7 +242,6 @@ def test_sample_ou_2d(model): alpha = 0.01 q_alpha = scipy.stats.norm.ppf(1 - alpha / 2) - print(mean_est, "vs", mean_theo) assert abs(mean_est - mean_theo) < np.sqrt(var_theo / n_samples) * q_alpha ratio_inf = scipy.stats.chi2.ppf(alpha / 2, n_samples) / (n_samples - 1) @@ -252,8 +251,6 @@ def test_sample_ou_2d(model): ratio = var_est / var_theo - print(var_est, "vs", var_theo) - print(ratio_inf, "<", ratio, "<", ratio_sup) assert ratio_inf <= ratio assert ratio <= ratio_sup From d2b157a2ebdacbd6fa14c5ad83144bf9104237ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Mon, 7 Oct 2024 16:00:40 +0200 Subject: [PATCH 02/10] UTF-8 encoding --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ce7a935b..719eaaa3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -87,7 +87,7 @@ def pytest_sessionstart(session: Any) -> None: for rst_file in RST_FILES_TO_TEST: if rst_file.exists(): logging.info(f"Testing Python code in {rst_file}.") - with open(rst_file) as f: + with open(rst_file, encoding="utf-8") as f: content = f.read() code_blocks = extract_python_blocks(content) From b8308740f771160dcf077993ec636e7cf7182f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:15:26 +0200 Subject: [PATCH 03/10] execute code in rst files and improve unique workflow for unit tests and build docs --- .github/workflows/docs.yml | 30 ----------------- .github/workflows/test.yml | 68 ++++++++++++++++++++++++++++++++++++-- tests/conftest.py | 47 +++++++++----------------- 3 files changed, 80 insertions(+), 65 deletions(-) delete mode 100644 .github/workflows/docs.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml deleted file mode 100644 index a85a3783..00000000 --- a/.github/workflows/docs.yml +++ /dev/null @@ -1,30 +0,0 @@ -name: Build Docs - -on: - pull_request: - branches: - - main - - dev - -jobs: - build-docs: - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: Python - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Poetry - uses: snok/install-poetry@v1 - with: - version: 1.8.3 - - name: Lock - run: poetry lock --no-update - - name: Install - run: poetry install - - name: Build Docs - run: | - poetry run sphinx-build -b html docs/ _build/html diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d737fc7f..c4f6746d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,32 +1,55 @@ -name: Unit tests +name: Unit tests and docs generation on: push: branches: - "**" pull_request: + branches: + - "**" types: [opened, synchronize, reopened, ready_for_review] workflow_dispatch: jobs: check: if: github.event.pull_request.draft == false - runs-on: ${{matrix.os}} + runs-on: ubuntu-latest strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: [3.8, 3.9, 3.10, 3.11] + include: + - os: ubuntu-latest + python-version: 3.11 defaults: run: shell: bash -l {0} steps: + - name: Set OS and Python version + id: set-vars + run: | + if [[ "${{ github.ref }}" == "refs/heads/main" || "${{ github.ref }}" == "refs/heads/dev" ]]; then + echo "os-matrix=ubuntu-latest,windows-latest" >> $GITHUB_ENV + echo "python-matrix=3.8,3.9,3.10,3.11" >> $GITHUB_ENV + else + echo "os-matrix=ubuntu-latest" >> $GITHUB_ENV + echo "python-matrix=3.11" >> $GITHUB_ENV - name: Checkout uses: actions/checkout@v3 - name: Python uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} + - name: Cache Poetry + uses: actions/cache@v3 + with: + path: | + ~/.cache/pypoetry + ~/.cache/pip + key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} + restore-keys: | + ${{ runner.os }}-poetry-${{ matrix.python-version }}- - name: Poetry uses: snok/install-poetry@v1 with: @@ -41,3 +64,42 @@ jobs: uses: codecov/codecov-action@v3 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + docs: + runs-on: ubuntu-latest + needs: check + if: github.event_name == 'push' || github.event_name == 'pull_request' + + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Cache Poetry + uses: actions/cache@v3 + with: + path: | + ~/.cache/pypoetry + ~/.cache/pip + key: ${{ runner.os }}-poetry-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }} + restore-keys: | + ${{ runner.os }}-poetry-${{ matrix.python-version }}- + - name: Poetry + uses: snok/install-poetry@v1 + with: + version: 1.8.3 + - name: Lock + run: poetry lock --no-update + - name: Install + run: poetry install + - name: Check Changed Files + id: changed-files + run: | + git fetch origin ${{ github.base_ref }} --depth=1 + git diff --name-only origin/${{ github.base_ref }} > changed_files.txt + - name: Build Docs + if: contains(fromJSON('["docs/", ".rst"]').join(','), fromJSON('["${{ steps.changed-files.outputs.files }}"]').join(',')) + run: | + poetry run sphinx-build -b html docs/ _build/html diff --git a/tests/conftest.py b/tests/conftest.py index 719eaaa3..61ada23d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,8 +62,7 @@ def pytest_sessionstart(session: Any) -> None: """Run tests (hook) on specified .rst files at pytest session start. This function reads through a list of predefined .rst files, - extracts Python code blocks, and ensures that the code is syntactically - valid and that all necessary imports work correctly. + extracts Python code blocks, concatenates them, and executes the code. The function will scan each file listed in `RST_FILES_TO_TEST`. This function is invoked automatically by pytest when the session starts. @@ -90,39 +89,23 @@ def pytest_sessionstart(session: Any) -> None: with open(rst_file, encoding="utf-8") as f: content = f.read() code_blocks = extract_python_blocks(content) + code = "\n".join(code_blocks) - for i, code_block in enumerate(code_blocks): + if code: try: - tree = ast.parse(code_block) - for node in ast.walk(tree): - if isinstance(node, ast.Import): - for name in node.names: - try: - __import__(name.name) - except ImportError as e: - pytest.fail( - f"Cannot import {name.name} in " - f"{rst_file}: {str(e)}" - ) - elif isinstance(node, ast.ImportFrom): - if node.module: - try: - __import__(node.module) # noqa - except ImportError as e: - pytest.fail( - f"Cannot import {node.module} in " - f"{rst_file}: {str(e)}." - ) - else: - pytest.fail( - f"Module name is None in {rst_file} " - "in ImportFrom statement." - ) + tree = ast.parse(code) + exec(code, globals(), locals()) except SyntaxError as e: pytest.fail( - "Invalid Python syntax in code block " - f"{i + 1} in {rst_file}: " - f"\n{code_block}\nError: {str(e)}" + "Syntax error in code block " + f"in {rst_file}: \n{code}\nError: {str(e)}." ) + except Exception as e: + pytest.fail( + f"Error while executing the code in {rst_file}: " + f"{str(e)}." + ) + else: + logging.info(f"No Python code blocks found in {rst_file}.") else: - logging.info(f"File {rst_file} does not exist, skippin...") + logging.info(f"File {rst_file} does not exist, skipping.") From f2b350a6408ab87a6e29cd00acf9dd08aafe47f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:20:29 +0200 Subject: [PATCH 04/10] fix syntax GitHub Actions --- .github/workflows/test.yml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c4f6746d..52097473 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,9 +18,6 @@ jobs: matrix: os: [ubuntu-latest, windows-latest] python-version: [3.8, 3.9, 3.10, 3.11] - include: - - os: ubuntu-latest - python-version: 3.11 defaults: run: shell: bash -l {0} @@ -29,7 +26,7 @@ jobs: - name: Set OS and Python version id: set-vars run: | - if [[ "${{ github.ref }}" == "refs/heads/main" || "${{ github.ref }}" == "refs/heads/dev" ]]; then + if [[ "${GITHUB_REF}" == "refs/heads/main" || "${GITHUB_REF}" == "refs/heads/dev" ]]; then echo "os-matrix=ubuntu-latest,windows-latest" >> $GITHUB_ENV echo "python-matrix=3.8,3.9,3.10,3.11" >> $GITHUB_ENV else From 7c16a8075625af6b897f42689a381bb4e9e09f7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:23:07 +0200 Subject: [PATCH 05/10] fix syntax GitHub Actions --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 52097473..5a94ae9b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -32,6 +32,7 @@ jobs: else echo "os-matrix=ubuntu-latest" >> $GITHUB_ENV echo "python-matrix=3.11" >> $GITHUB_ENV + fi - name: Checkout uses: actions/checkout@v3 - name: Python From 509aa866e2ca6efcd2bb7361e517c7dd2aaf7601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:27:59 +0200 Subject: [PATCH 06/10] python-version in matrix --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5a94ae9b..735a450e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.8, 3.9, 3.10, 3.11] + python-version: ["3.8", "3.9", "3.10", "3.11"] defaults: run: shell: bash -l {0} From ead88c2832b4c02afe60153bcd59a22428ee7d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:52:07 +0200 Subject: [PATCH 07/10] fix ambiguous argument 'origin/': --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 735a450e..ddf65ef1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -95,7 +95,7 @@ jobs: - name: Check Changed Files id: changed-files run: | - git fetch origin ${{ github.base_ref }} --depth=1 + git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} --depth=1 git diff --name-only origin/${{ github.base_ref }} > changed_files.txt - name: Build Docs if: contains(fromJSON('["docs/", ".rst"]').join(','), fromJSON('["${{ steps.changed-files.outputs.files }}"]').join(',')) From dce6883b3bd499760cf06c85fcdcebf9d7d2b8c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Tue, 8 Oct 2024 15:54:16 +0200 Subject: [PATCH 08/10] fix reference for get diff --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ddf65ef1..dc6f18b2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -96,7 +96,7 @@ jobs: id: changed-files run: | git fetch origin ${{ github.base_ref }}:${{ github.base_ref }} --depth=1 - git diff --name-only origin/${{ github.base_ref }} > changed_files.txt + git diff --name-only ${{ github.base_ref }} > changed_files.txt - name: Build Docs if: contains(fromJSON('["docs/", ".rst"]').join(','), fromJSON('["${{ steps.changed-files.outputs.files }}"]').join(',')) run: | From c145710a79e4467962f97ee794ed919572ac3948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Wed, 9 Oct 2024 13:19:30 +0200 Subject: [PATCH 09/10] multiple OS and remove redundant conditions --- .github/workflows/test.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dc6f18b2..8688747e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -13,7 +13,7 @@ on: jobs: check: if: github.event.pull_request.draft == false - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest, windows-latest] @@ -66,7 +66,6 @@ jobs: docs: runs-on: ubuntu-latest needs: check - if: github.event_name == 'push' || github.event_name == 'pull_request' steps: - name: Checkout From bc9b4c8da0c9b1fc86130c09a9e552888b87ae43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=B4ng-Lan=20Botterman?= Date: Mon, 14 Oct 2024 18:31:13 +0200 Subject: [PATCH 10/10] set mypy and pytest in .toml --- examples/tutorials/plot_tuto_mcar.py | 6 +++++- mypy.ini | 20 -------------------- pyproject.toml | 26 +++++++++++++++++++++++++- pytest.ini | 4 ---- 4 files changed, 30 insertions(+), 26 deletions(-) delete mode 100644 mypy.ini delete mode 100644 pytest.ini diff --git a/examples/tutorials/plot_tuto_mcar.py b/examples/tutorials/plot_tuto_mcar.py index fbbd0587..0b47336f 100644 --- a/examples/tutorials/plot_tuto_mcar.py +++ b/examples/tutorials/plot_tuto_mcar.py @@ -376,7 +376,11 @@ # %% pklm_test = PKLMTest(random_state=rng, compute_partial_p_values=True) -p_value, partial_p_values = pklm_test.test(df_nan) +result = pklm_test.test(df_nan) +if isinstance(result, tuple): + p_value, partial_p_values = result +else: + p_value = result print(f"The p-value of the PKLM test is: {p_value:.2%}") # %% diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index 3e2b7028..00000000 --- a/mypy.ini +++ /dev/null @@ -1,20 +0,0 @@ -[mypy] -python_version = 3.10 -ignore_missing_imports = True -disable_error_code = union-attr -#|qolmat/benchmark/missing_patterns.py - - -[mypy-sklearn.*] -ignore_errors = True - -[mypy-doc.*] -#ignore_errors = True - -[mypy-matplotlib.*] -ignore_missing_imports = True - -[mypy-numpy.*] -ignore_missing_imports = True -ignore_errors = True -# disable_error_code = attr-defined diff --git a/pyproject.toml b/pyproject.toml index fa42eaf8..a1f2481f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,8 +102,32 @@ targets = ["qolmat"] [tool.mypy] pretty = true strict = false -python_version = ">=3.8.1,<3.12" +python_version = "3.10" ignore_missing_imports = true +disable_error_code = ["union-attr"] +exclude = "docs" + +[[tool.mypy.overrides]] +module = "sklearn.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "matplotlib.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "numpy.*" +ignore_missing_imports = true +ignore_errors = true + +[[tool.mypy.overrides]] +module = "yaml" +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--cov=qolmat --doctest-glob=*.rst --doctest-modules" +testpaths = ["tests"] +norecursedirs = ["_build"] [tool.ruff] line-length = 79 diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index cbf59873..00000000 --- a/pytest.ini +++ /dev/null @@ -1,4 +0,0 @@ -[pytest] -addopts = --cov=qolmat --doctest-glob="*.rst" --doctest-modules -testpaths = tests -norecursedirs = _build