From d8f1f24aa4874ec10eb82f350ed06eddcd6d1cc3 Mon Sep 17 00:00:00 2001 From: Zhiyi Wu Date: Sun, 24 Dec 2023 19:58:38 +0000 Subject: [PATCH] update --- CHANGES | 4 ++++ environment.yml | 1 + src/alchemlyb/preprocessing/subsampling.py | 9 +++++++++ src/alchemlyb/tests/test_preprocessing.py | 13 +++++++++++++ 4 files changed, 27 insertions(+) diff --git a/CHANGES b/CHANGES index 7a8ed754..cf4ff351 100644 --- a/CHANGES +++ b/CHANGES @@ -21,6 +21,10 @@ The rules for this file: Enhancements - Add a TI estimator using gaussian quadrature to calculate the free energy. (issue #302, PR #304) + - Warning issued when the series is `None` for `statistical_inefficiency` + (issue #337, PR #304) + - ValueError issued when `df` and `series` for `statistical_inefficiency` + doesn't have the same length (issue #337, PR #304) 22/06/2023 xiki-tempula diff --git a/environment.yml b/environment.yml index 4d2d9bda..da0522df 100644 --- a/environment.yml +++ b/environment.yml @@ -10,3 +10,4 @@ dependencies: - scikit-learn - pyarrow - matplotlib +- loguru diff --git a/src/alchemlyb/preprocessing/subsampling.py b/src/alchemlyb/preprocessing/subsampling.py index 4633a87e..1bb671e6 100644 --- a/src/alchemlyb/preprocessing/subsampling.py +++ b/src/alchemlyb/preprocessing/subsampling.py @@ -363,6 +363,15 @@ def _prepare_input(df, series, drop_duplicates, sort): series : Series Formatted Series. """ + if series is None: + logger.warning( + "The series input is `None`, would not subsample according to statistical inefficiency." + ) + + elif len(df) != len(series): + raise ValueError( + f"The length of df ({len(df)}) should be same as the length of series ({len(series)})." + ) if _check_multiple_times(df): if drop_duplicates: df, series = _drop_duplicates(df, series) diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 00e3c030..1f6c8e08 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -544,3 +544,16 @@ def test_statistical_inefficiency(self, caplog, u_nk): assert "Running statistical inefficiency analysis." in caplog.text assert "Statistical inefficiency:" in caplog.text assert "Number of uncorrelated samples:" in caplog.text + + +def test_unequil_input(dHdl): + with pytest.raises(ValueError, match="should be same as the length of series"): + statistical_inefficiency(dHdl, series=dHdl[:10]) + + +def test_series_none(dHdl, caplog): + statistical_inefficiency(dHdl, series=None) + assert ( + "The series input is `None`, would not subsample according to statistical inefficiency." + in caplog.text + )