diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index be8859c..5c4709a 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -22,8 +22,23 @@ concurrency: cancel-in-progress: true jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.9" + cache: pip + cache-dependency-path: pyproject.toml + + - uses: pre-commit/action@v3.0.0 + with: + extra_args: --files solvation_analysis/* + test: - name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} + name: pip install on ${{ matrix.os }}, Python ${{ matrix.python-version }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -34,43 +49,17 @@ jobs: steps: - uses: actions/checkout@v3 - - name: Additional info about the build - shell: bash - run: | - uname -a - df -h - ulimit -a - - - # More info on options: https://github.com/conda-incubator/setup-miniconda - - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - environment-file: devtools/conda-envs/test_env.yaml - - channels: conda-forge,defaults - - activate-environment: test - auto-update-conda: false - auto-activate-base: false - show-channel-urls: true - name: Install package - - # conda setup requires this special shell - shell: bash -l {0} run: | - python -m pip install . --no-deps - conda list - + python -m pip install . - name: Run tests - - # conda setup requires this special shell - shell: bash -l {0} - run: | - pytest -v --cov=solvation_analysis --cov-report=xml --color=yes solvation_analysis/tests/ + pytest -v --color=yes solvation_analysis/tests/ - name: CodeCov uses: codecov/codecov-action@v1 @@ -78,27 +67,3 @@ jobs: file: ./coverage.xml flags: unittests name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} - - test_pip_install: - name: pip (PEP517) install on ${{ matrix.os }}, Python ${{ matrix.python-version }} - runs-on: ${{ matrix.os }} - strategy: - fail-fast: false - matrix: - os: [macOS-latest, ubuntu-latest, windows-latest] - python-version: [3.9, "3.10", 3.11, 3.12] - - steps: - - uses: actions/checkout@v3 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - - name: Install package - run: | - python -m pip install . 
- - - name: Run tests - run: | - pytest -v --color=yes solvation_analysis/tests/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f9d887e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +default_language_version: + python: python3 +repos: +- repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.4.2 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: fix-encoding-pragma + args: [--remove] + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.10.0 + hooks: + - id: python-use-type-annotations + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + files: ^src/ + additional_dependencies: + - tokenize-rt==4.1.0 + - types-paramiko +- repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + stages: [commit, commit-msg] + args: [--ignore-words-list, 'titel,statics,ba,nd,te,atomate'] + types_or: [python, rst, markdown] diff --git a/README.md b/README.md index f07c65d..e210f3c 100644 --- a/README.md +++ b/README.md @@ -37,9 +37,12 @@ conda install -c conda-forge solvation_analysis ### Contributing -Contributions, both issues and PRs, are welcome. If you'd like to contribute, we ask that you +Contributions, both issues and PRs, are welcome. If you'd like to contribute, we ask that you follow the community guidelines outlined in the [MDAnalysis Code of Conduct](https://www.mdanalysis.org/pages/conduct/). +Solvation Analysis uses [pre-commit](https://pre-commit.com/) for linting. Make sure to install +the pre-commit hooks if you are working on a contribution. + ### Citation This work is described in [JOSS](https://doi.org/10.21105/joss.05183), please cite it if you make @@ -47,7 +50,7 @@ use of this package in published work. --- -Project based on the +Project based on the [Computational Molecular Science Python Cookiecutter](https://github.com/molssi/cookiecutter-cms) version 1.5. 
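The new `lint` job runs pre-commit only against `solvation_analysis/*`, matching the `extra_args` passed to `pre-commit/action` above. For contributors, a minimal local setup that mirrors the CI invocation (assuming an ordinary pip environment) looks like:

```bash
python -m pip install pre-commit
pre-commit install                            # run the hooks automatically on every commit
pre-commit run --files solvation_analysis/*   # one-off run with the same file filter as CI
```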
diff --git a/docs/requirements.yaml b/docs/requirements.yaml index 06a2d1c..9472208 100644 --- a/docs/requirements.yaml +++ b/docs/requirements.yaml @@ -14,8 +14,8 @@ dependencies: - ipython - plotly - sphinx_rtd_theme + - scipy==1.12.0 # Pip-only installs #- pip: - diff --git a/docs/tutorials/basics_tutorial.ipynb b/docs/tutorials/basics_tutorial.ipynb index f235d3b..492e1e2 100644 --- a/docs/tutorials/basics_tutorial.ipynb +++ b/docs/tutorials/basics_tutorial.ipynb @@ -80,7 +80,7 @@ "outputs": [], "source": [ "# instantiate solute\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import Solute\n", "\n", "solute = Solute.from_atoms(li_atoms, {'PF6': PF6, 'BN': BN, 'FEC': FEC}, solute_name=\"Li\")" ] @@ -239,14 +239,16 @@ " li_atoms,\n", " {'PF6': PF6, 'BN': BN, 'FEC': FEC},\n", " solute_name=\"Li\",\n", - " radii={\"PF6\": 2.6}\n", + " radii={\"PF6\": 2.6},\n", + " kernel_kwargs={\"default\": 3.0}\n", ")\n", "\n", "solute.run()" ], "metadata": { "collapsed": false - } + }, + "id": "e7a6367c65eaded3" }, { "cell_type": "code", diff --git a/docs/tutorials/clustering_and_residence_tutorial.ipynb b/docs/tutorials/clustering_and_residence_tutorial.ipynb index 84eec72..5f6f5dc 100644 --- a/docs/tutorials/clustering_and_residence_tutorial.ipynb +++ b/docs/tutorials/clustering_and_residence_tutorial.ipynb @@ -35,7 +35,7 @@ "source": [ "# imports\n", "import MDAnalysis as mda\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import Solute\n", "\n", "# we will use a trajectory supplied by the package\n", "from solvation_analysis.tests import datafiles\n", diff --git a/docs/tutorials/multi_atom_solutes.ipynb b/docs/tutorials/multi_atom_solutes.ipynb index fc1d2e8..f24e43d 100644 --- a/docs/tutorials/multi_atom_solutes.ipynb +++ b/docs/tutorials/multi_atom_solutes.ipynb @@ -27,7 +27,7 @@ "outputs": [], "source": [ "import MDAnalysis as mda\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import Solute\n", "\n", "from solvation_analysis.tests.datafiles import iba_data, iba_dcd\n", "\n", diff --git a/docs/tutorials/plotting_tutorial.ipynb b/docs/tutorials/plotting_tutorial.ipynb index 3748834..2ab73b7 100644 --- a/docs/tutorials/plotting_tutorial.ipynb +++ b/docs/tutorials/plotting_tutorial.ipynb @@ -21,7 +21,7 @@ "\n", "import plotly.io\n", "plotly.io.renderers.default = 'svg'\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import Solute\n", "# this is a dict of dicts, {solute_name: {group_name: atom_group}}\n", "from setup_eax_solutes import u_eax_atom_groups" ] diff --git a/docs/tutorials/rdf_fitting_demo.ipynb b/docs/tutorials/rdf_fitting_demo.ipynb index 3ccb814..431fd78 100644 --- a/docs/tutorials/rdf_fitting_demo.ipynb +++ b/docs/tutorials/rdf_fitting_demo.ipynb @@ -24,7 +24,7 @@ "\n", "from solvation_analysis.rdf_parser import plot_scipy_find_peaks_troughs, identify_cutoff_scipy\n", "from solvation_analysis.tests import datafiles\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import Solute\n", "\n", "from scipy.signal import find_peaks" ] diff --git a/docs/tutorials/visualization_tutorial.ipynb b/docs/tutorials/visualization_tutorial.ipynb index 1a4c724..38b100d 100644 --- a/docs/tutorials/visualization_tutorial.ipynb +++ b/docs/tutorials/visualization_tutorial.ipynb @@ -21,7 +21,7 @@ "source": [ "# imports\n", "import MDAnalysis as mda\n", - "from solvation_analysis.solute import Solute\n", + "from solvation_analysis import 
Solute\n", "from solvation_analysis.tests import datafiles\n", "\n", "from IPython.display import Image\n", diff --git a/pyproject.toml b/pyproject.toml index 3b3bdd8..1122fea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ 'pytest', 'matplotlib', 'setuptools', - 'scipy', + 'scipy==1.12.0', 'statsmodels', 'plotly', 'rdkit' diff --git a/requirements.txt b/requirements.txt index cc22965..044e8d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ setuptools numpy>=1.20.0 pandas>=2.2 -mdanalysis>=2.0.0 +mdanalysis>=2.7.0 pytest pathlib matplotlib -scipy +scipy==1.12.0 statsmodels plotly rdkit diff --git a/setup.py b/setup.py index 1222451..0cf0fa8 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ SolvationAnalysis An MDAnalysis rmodule for solvation analysis. """ + import sys from setuptools import setup, find_packages import versioneer @@ -9,50 +10,46 @@ short_description = __doc__.split("\n") # from https://github.com/pytest-dev/pytest-runner#conditional-requirement -needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) -pytest_runner = ['pytest-runner'] if needs_pytest else [] +needs_pytest = {"pytest", "test", "ptr"}.intersection(sys.argv) +pytest_runner = ["pytest-runner"] if needs_pytest else [] try: with open("README.md", "r") as handle: long_description = handle.read() -except: +except: # noqa long_description = "\n".join(short_description[2:]) setup( # Self-descriptive entries which should always be present - name='solvation_analysis', - author='Orion Cohen', - author_email='orioncohen@gmail.com', + name="solvation_analysis", + author="Orion Cohen", + author_email="orioncohen@gmail.com", description=short_description[0], long_description=long_description, long_description_content_type="text/markdown", version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), - license='GNU Public License v3', - + license="GNU Public License v3", # Which Python importable modules should be included when your package is installed # Handled automatically by setuptools. Use 'exclude' to prevent some specific # subpackage(s) from being added, if needed packages=find_packages(), - # Optional include package data to ship with your package # Customize MANIFEST.in if the general case does not suit your needs # Comment out this line to prevent the files from being packaged with your software include_package_data=True, - # Allows `setup.py test` to work correctly with pytest setup_requires=[] + pytest_runner, - install_requires=[ - 'numpy>=1.20.0', - 'mdanalysis>=2.7.0', - 'pandas', - 'matplotlib', - 'scipy', - 'statsmodels', - 'plotly', - 'rdkit' + "numpy>=1.20.0", + "mdanalysis>=2.7.0", + "pandas", + "matplotlib", + "scipy==1.12.0", + "statsmodels", + "plotly", + "rdkit", ], # Additional entries you may want simply uncomment the lines you want and fill in the data # url='http://www.my_package.com', # Website @@ -62,8 +59,6 @@ # 'Unix', # 'Windows'], # Valid platforms your code works on, adjust to your flavor # python_requires=">=3.5", # Python version restrictions - # Manual control if final package is compressible or not, set False to prevent the .egg from being made # zip_safe=False, - ) diff --git a/solvation_analysis/__init__.py b/solvation_analysis/__init__.py index c63d0b3..315d3f0 100644 --- a/solvation_analysis/__init__.py +++ b/solvation_analysis/__init__.py @@ -3,12 +3,17 @@ An MDAnalysis rmodule for solvation analysis. """ +from . 
import _version +from solvation_analysis.solute import Solute + # Handle versioneer from ._version import get_versions + versions = get_versions() -__version__ = versions['version'] -__git_revision__ = versions['full-revisionid'] +__version__ = versions["version"] +__git_revision__ = versions["full-revisionid"] del get_versions, versions -from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] +__all__ = ["Solute"] diff --git a/solvation_analysis/_column_names.py b/solvation_analysis/_column_names.py index af72b44..b637d5d 100644 --- a/solvation_analysis/_column_names.py +++ b/solvation_analysis/_column_names.py @@ -8,7 +8,6 @@ 1. change the variable name in all files """ - # for solvation_data FRAME = "frame" SOLUTE_IX = "solute_ix" diff --git a/solvation_analysis/_utils.py b/solvation_analysis/_utils.py index 10cb058..ab3f733 100644 --- a/solvation_analysis/_utils.py +++ b/solvation_analysis/_utils.py @@ -1,21 +1,24 @@ -import numpy as np from collections import defaultdict from functools import reduce +from typing import Union + +import numpy as np +import pandas as pd import MDAnalysis as mda from MDAnalysis.analysis import distances -from solvation_analysis._column_names import * +from solvation_analysis._column_names import FRAME, SOLUTE_IX, SOLVENT_IX, DISTANCE -def verify_solute_atoms(solute_atom_group): +def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomGroup]: # we presume that the solute_atoms has the same number of atoms on each residue # and that they all have the same indices on those residues # and that the residues are all the same length # then this should work all_res_len = np.array([res.atoms.n_atoms for res in solute_atom_group.residues]) - assert np.all(all_res_len[0] == all_res_len), ( - "All residues must be the same length." - ) + assert np.all( + all_res_len[0] == all_res_len + ), "All residues must be the same length." res_atom_local_ix = defaultdict(list) res_atom_ix = defaultdict(list) @@ -23,36 +26,38 @@ def verify_solute_atoms(solute_atom_group): res_atom_local_ix[atom.resindex].append(atom.ix - atom.residue.atoms[0].ix) res_atom_ix[atom.resindex].append(atom.index) res_occupancy = np.array([len(ix) for ix in res_atom_local_ix.values()]) - assert np.all(res_occupancy[0] == res_occupancy), ( - "All residues must have the same number of solute_atoms atoms on them." - ) + assert np.all( + res_occupancy[0] == res_occupancy + ), "All residues must have the same number of solute_atoms atoms on them." res_atom_array = np.array(list(res_atom_local_ix.values())) - assert np.all(res_atom_array[0] == res_atom_array), ( - "All residues must have the same solute_atoms atoms on them." - ) + assert np.all( + res_atom_array[0] == res_atom_array + ), "All residues must have the same solute_atoms atoms on them." 
res_atom_ix_array = np.array(list(res_atom_ix.values())) solute_atom_group_dict = {} for i in range(0, res_atom_ix_array.shape[1]): - solute_atom_group_dict[i] = solute_atom_group.universe.atoms[res_atom_ix_array[:, i]] + solute_atom_group_dict[i] = solute_atom_group.universe.atoms[ + res_atom_ix_array[:, i] + ] return solute_atom_group_dict -def verify_solute_atoms_dict(solute_atoms_dict): +def verify_solute_atoms_dict( + solute_atoms_dict: dict[str, mda.AtomGroup], +) -> mda.AtomGroup: # first we verify the input format atom_group_lengths = [] for solute_name, solute_atom_group in solute_atoms_dict.items(): - assert isinstance(solute_name, str), ( - "The keys of solutes_dict must be strings." - ) + assert isinstance(solute_name, str), "The keys of solutes_dict must be strings." assert isinstance(solute_atom_group, mda.AtomGroup), ( f"The values of solutes_dict must be MDAnalysis.AtomGroups. But the value" f"for {solute_name} is a {type(solute_atom_group)}." ) - assert len(solute_atom_group) == len(solute_atom_group.residues), ( - "The solute_atom_group must have a single atom on each residue." - ) + assert len(solute_atom_group) == len( + solute_atom_group.residues + ), "The solute_atom_group must have a single atom on each residue." atom_group_lengths.append(len(solute_atom_group)) assert np.all(np.array(atom_group_lengths) == atom_group_lengths[0]), ( "AtomGroups in solutes_dict must have the same length because there should be" @@ -60,16 +65,25 @@ def verify_solute_atoms_dict(solute_atoms_dict): ) # verify that the solute_atom_groups have no overlap - solute_atom_group = reduce(lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()]) - assert solute_atom_group.n_atoms == sum([atoms.n_atoms for atoms in solute_atoms_dict.values()]), ( - "The solute_atom_groups must not overlap." + solute_atom_group = reduce( + lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()] ) + assert solute_atom_group.n_atoms == sum( + [atoms.n_atoms for atoms in solute_atoms_dict.values()] + ), "The solute_atom_groups must not overlap." verify_solute_atoms(solute_atom_group) return solute_atom_group -def get_atom_group(selection): +def get_atom_group( + selection: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], +) -> mda.AtomGroup: """ Cast an MDAnalysis.Atom, MDAnalysis.Residue, or MDAnalysis.ResidueGroup to AtomGroup. @@ -101,12 +115,21 @@ def get_atom_group(selection): def get_closest_n_mol( - central_species, - n_mol, - guess_radius=3, - return_ordered_resix=False, - return_radii=False, -): + central_species: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], + n_mol: int, + guess_radius: Union[float, int] = 3, + return_ordered_resix: bool = False, + return_radii: bool = False, +) -> Union[ + mda.AtomGroup, + tuple[mda.AtomGroup, np.ndarray], + tuple[mda.AtomGroup, np.ndarray, np.ndarray], +]: """ Returns the closest n molecules to the central species, an array of their resix, and an array of the distance of the closest atom in each molecule. 
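For reference, a usage sketch of the newly annotated `get_closest_n_mol`; the universe, file names, and central atom here are hypothetical stand-ins, and the tuple ordering with both flags is inferred from the new `Union[...]` return annotation:

```python
import MDAnalysis as mda
from solvation_analysis._utils import get_closest_n_mol

u = mda.Universe("topology.data", "trajectory.dcd")  # hypothetical files
li = u.atoms[0]                                      # hypothetical central species

# plain call: an AtomGroup containing the closest n_mol molecules
shell = get_closest_n_mol(li, n_mol=5)

# with the optional flags the return widens to a tuple, which is what
# the Union[AtomGroup, tuple[...], tuple[...]] annotation captures
shell, resix, radii = get_closest_n_mol(
    li, n_mol=5, return_ordered_resix=True, return_radii=True
)
```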
@@ -148,7 +171,7 @@ def get_closest_n_mol( n_mol, guess_radius + 1, return_ordered_resix=return_ordered_resix, - return_radii=return_radii + return_radii=return_radii, ) radii = distances.distance_array(coords, partial_shell.positions, box=u.dimensions)[ 0 @@ -156,7 +179,7 @@ def get_closest_n_mol( ordering = np.argsort(radii) ordered_resix = shell_resix[ordering] closest_n_resix = np.sort(np.unique(ordered_resix, return_index=True)[1])[ - 0: n_mol + 1 + 0 : n_mol + 1 ] str_resix = " ".join(str(resix) for resix in ordered_resix[closest_n_resix]) full_shell = u.select_atoms(f"resindex {str_resix}") @@ -174,7 +197,15 @@ def get_closest_n_mol( return full_shell -def get_radial_shell(central_species, radius): +def get_radial_shell( + central_species: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], + radius: Union[float, int], +) -> mda.AtomGroup: """ Returns all molecules with atoms within the radius of the central species. (specifically, within the radius of the COM of central species). @@ -199,7 +230,7 @@ def get_radial_shell(central_species, radius): return full_shell -def calculate_adjacency_dataframe(solvation_data): +def calculate_adjacency_dataframe(solvation_data: pd.DataFrame) -> pd.DataFrame: """ Calculate a frame-by-frame adjacency matrix from the solvation data. diff --git a/solvation_analysis/_version.py b/solvation_analysis/_version.py index 16cd611..6bd04f5 100644 --- a/solvation_analysis/_version.py +++ b/solvation_analysis/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -58,18 +57,19 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate # pylint:disable=too-many-arguments,consider-using-with # noqa -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -77,10 +77,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -115,15 +118,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started 
with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -182,7 +191,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -191,7 +200,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -199,24 +208,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -231,8 +247,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -240,10 +255,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -258,8 +282,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved 
later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -299,17 +322,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + # unparsable. Maybe git-describe is misbehaving? + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -318,10 +340,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -370,8 +394,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -400,8 +423,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -544,11 +566,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -572,9 +596,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -588,8 +616,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -598,13 +625,16 @@ def get_versions(): # versionfile_source is 
the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -618,6 +648,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/solvation_analysis/coordination.py b/solvation_analysis/coordination.py index 226a479..0f49008 100644 --- a/solvation_analysis/coordination.py +++ b/solvation_analysis/coordination.py @@ -17,8 +17,18 @@ """ import pandas as pd +import MDAnalysis as mda +import solvation_analysis -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + SOLVENT_ATOM_IX, + ATOM_TYPE, + FRACTION, +) class Coordination: @@ -62,7 +72,14 @@ class Coordination: """ - def __init__(self, solvation_data, n_frames, n_solutes, solvent_counts, atom_group): + def __init__( + self, + solvation_data: pd.DataFrame, + n_frames: int, + n_solutes: int, + solvent_counts: dict[str, int], + atom_group: mda.core.groups.AtomGroup, + ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames self.n_solutes = n_solutes @@ -73,7 +90,7 @@ def __init__(self, solvation_data, n_frames, n_solutes, solvent_counts, atom_gro self._coordination_vs_random = self._calculate_coordination_vs_random() @staticmethod - def from_solute(solute): + def from_solute(solute: "solvation_analysis.Solute") -> "Coordination": """ Generate a Coordination object from a solute. 
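The coordination-number math in `_mean_cn` (below) is a short pandas chain. A toy illustration with hypothetical data, using the literal string values behind the `_column_names` constants:

```python
import pandas as pd

# hypothetical solvation_data: one row per solute-solvent contact
solvation_data = pd.DataFrame({
    "frame":      [0, 0, 0, 1, 1],
    "solute_ix":  [0, 0, 1, 0, 1],
    "solvent":    ["BN", "BN", "PF6", "BN", "BN"],
    "solvent_ix": [3, 4, 9, 3, 5],
}).set_index(["frame", "solute_ix", "solvent"])

n_solutes, n_frames = 2, 2
# contacts per (frame, solute, solvent), then averaged over solutes and frames
counts = solvation_data.groupby(["frame", "solute_ix", "solvent"]).count()["solvent_ix"]
cn_series = counts.groupby(["solvent", "frame"]).sum() / (n_solutes * n_frames)
cn_dict = cn_series.groupby(["solvent"]).sum().to_dict()
# cn_dict == {'BN': 1.0, 'PF6': 0.25}
```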
@@ -94,44 +111,47 @@ def from_solute(solute): solute.u.atoms, ) - def _mean_cn(self): - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + def _mean_cn(self) -> tuple[dict[str, float], pd.DataFrame]: + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] cn_series = counts.groupby([SOLVENT, FRAME]).sum() / ( - self.n_solutes * self.n_frames + self.n_solutes * self.n_frames ) cn_by_frame = cn_series.unstack() cn_dict = cn_series.groupby([SOLVENT]).sum().to_dict() return cn_dict, cn_by_frame - def _calculate_coordinating_atoms(self, tol=0.005): + def _calculate_coordinating_atoms(self, tol: float = 0.005) -> pd.DataFrame: """ Determine which atom types are actually coordinating return the types of those atoms """ # lookup atom types atom_types = self.solvation_data.reset_index([SOLVENT_ATOM_IX]) - atom_types[ATOM_TYPE] = self.atom_group[atom_types[SOLVENT_ATOM_IX].values].types + atom_types[ATOM_TYPE] = self.atom_group[ + atom_types[SOLVENT_ATOM_IX].values + ].types # count atom types atoms_by_type = atom_types[[ATOM_TYPE, SOLVENT, SOLVENT_ATOM_IX]] type_counts = atoms_by_type.groupby([SOLVENT, ATOM_TYPE]).count() solvent_counts = type_counts.groupby([SOLVENT]).sum()[SOLVENT_ATOM_IX] # calculate fraction of each solvent_counts_list = [ - solvent_counts[solvent] for solvent in - type_counts.index.get_level_values(SOLVENT) + solvent_counts[solvent] + for solvent in type_counts.index.get_level_values(SOLVENT) ] type_fractions = type_counts[SOLVENT_ATOM_IX] / solvent_counts_list type_fractions.name = FRACTION # change index type type_fractions = ( - type_fractions - .reset_index(ATOM_TYPE) + type_fractions.reset_index(ATOM_TYPE) .astype({ATOM_TYPE: str}) .set_index(ATOM_TYPE, append=True) ) return type_fractions[type_fractions[FRACTION] > tol] - def _calculate_coordination_vs_random(self): + def _calculate_coordination_vs_random(self) -> dict[str, float]: """ Calculate the coordination number relative to random coordination. @@ -150,7 +170,7 @@ def _calculate_coordination_vs_random(self): return coordination_vs_random @property - def coordination_numbers(self): + def coordination_numbers(self) -> dict[str, float]: """ A dictionary where keys are residue names (str) and values are the mean coordination number of that residue (float). @@ -158,21 +178,21 @@ def coordination_numbers(self): return self._cn_dict @property - def coordination_numbers_by_frame(self): + def coordination_numbers_by_frame(self) -> pd.DataFrame: """ A DataFrame of the mean coordination number of in each frame of the trajectory. """ return self._cn_dict_by_frame @property - def coordinating_atoms(self): + def coordinating_atoms(self) -> pd.DataFrame: """ Fraction of each atom_type participating in solvation, calculated for each solvent. """ return self._coordinating_atoms @property - def coordination_vs_random(self): + def coordination_vs_random(self) -> dict[str, float]: """ Coordination number relative to random coordination. diff --git a/solvation_analysis/networking.py b/solvation_analysis/networking.py index b6a2146..19da5c2 100644 --- a/solvation_analysis/networking.py +++ b/solvation_analysis/networking.py @@ -16,13 +16,26 @@ as an attribute of the Solute class. This makes instantiating it and calculating the solvation data a non-issue. 
""" + +from typing import Union + import pandas as pd import numpy as np from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +import solvation_analysis from solvation_analysis._utils import calculate_adjacency_dataframe -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + SOLVENT, + SOLVENT_IX, + SOLUTE_IX, + NETWORK, + PAIRED, + NETWORKED, + ISOLATED, +) class Networking: @@ -61,11 +74,17 @@ class Networking: >>> networking = Networking.from_solute(solute, 'PF6') """ - def __init__(self, solvents, solvation_data, solute_res_ix, res_name_map): + def __init__( + self, + solvents: Union[str, list[str]], + solvation_data: pd.DataFrame, + solute_res_ix: np.ndarray, + res_name_map: pd.Series, + ) -> None: self.solvents = solvents self.solvation_data = solvation_data - solvent_present = np.isin(self.solvents, self.solvation_data[SOLVENT].unique()) # TODO: we need all analysis classes to run when there is no solvation_data + # solvent_present = np.isin(self.solvents, self.solvation_data[SOLVENT].unique()) # if not solvent_present.all(): # raise Exception(f"Solvent(s) {np.array(self.solvents)[~solvent_present]} not found in solvation data.") self.solute_res_ix = solute_res_ix @@ -73,11 +92,15 @@ def __init__(self, solvents, solvation_data, solute_res_ix, res_name_map): self.n_solute = len(solute_res_ix) self._network_df = self._generate_networks() self._network_sizes = self._calculate_network_sizes() - self._solute_status, self._solute_status_by_frame = self._calculate_solute_status() + self._solute_status, self._solute_status_by_frame = ( + self._calculate_solute_status() + ) self._solute_status = self._solute_status.to_dict() @staticmethod - def from_solute(solute, solvents): + def from_solute( + solute: "solvation_analysis.Solute", solvents: Union[str, list[str]] + ) -> "Networking": """ Generate a Networking object from a solute and solvent names. @@ -102,7 +125,7 @@ def from_solute(solute, solvents): ) @staticmethod - def _unwrap_adjacency_dataframe(df): + def _unwrap_adjacency_dataframe(df: pd.DataFrame) -> csr_matrix: # this class will transform the biadjacency matrix into a proper adjacency matrix connections = df.reset_index(FRAME).drop(columns=FRAME) idx = connections.columns.append(connections.index) @@ -111,7 +134,7 @@ def _unwrap_adjacency_dataframe(df): adjacency_matrix = csr_matrix(undirected) return adjacency_matrix - def _generate_networks(self): + def _generate_networks(self) -> pd.DataFrame: """ This function generates a dataframe containing all the solute-solvent networks in every frame of the simulation. The rough approach is as follows: @@ -121,7 +144,9 @@ def _generate_networks(self): 3. 
tabulate the solvent involved in each network and store in a DataFrame """ solvents = [self.solvents] if isinstance(self.solvents, str) else self.solvents - solvation_subset = self.solvation_data[np.isin(self.solvation_data[SOLVENT], solvents)] + solvation_subset = self.solvation_data[ + np.isin(self.solvation_data[SOLVENT], solvents) + ] # create adjacency matrix from solvation_subset graph = calculate_adjacency_dataframe(solvation_subset) network_arrays = [] @@ -135,16 +160,16 @@ def _generate_networks(self): ix_to_res_ix = np.concatenate([solvent_map, solute_map]) adjacency_df = Networking._unwrap_adjacency_dataframe(df) _, network = connected_components( - csgraph=adjacency_df, - directed=False, - return_labels=True + csgraph=adjacency_df, directed=False, return_labels=True ) - network_array = np.vstack([ - np.full(len(network), frame), # frame - network, # network - self.res_name_map[ix_to_res_ix], # res_names - ix_to_res_ix, # res index - ]).T + network_array = np.vstack( + [ + np.full(len(network), frame), # frame + network, # network + self.res_name_map[ix_to_res_ix], # res_names + ix_to_res_ix, # res index + ] + ).T network_arrays.append(network_array) # create and return network dataframe if len(network_arrays) == 0: @@ -153,20 +178,24 @@ def _generate_networks(self): all_clusters = np.concatenate(network_arrays) cluster_df = ( pd.DataFrame(all_clusters, columns=[FRAME, NETWORK, SOLVENT, SOLVENT_IX]) - .set_index([FRAME, NETWORK]) - .sort_values([FRAME, NETWORK]) + .set_index([FRAME, NETWORK]) + .sort_values([FRAME, NETWORK]) ) return cluster_df - def _calculate_network_sizes(self): + def _calculate_network_sizes(self) -> pd.DataFrame: # This utility calculates the network sizes and returns a convenient dataframe. cluster_df = self.network_df cluster_sizes = cluster_df.groupby([FRAME, NETWORK]).count() - size_counts = cluster_sizes.groupby([FRAME, SOLVENT]).count().unstack(fill_value=0) - size_counts.columns = size_counts.columns.droplevel(None) # the column value is None + size_counts = ( + cluster_sizes.groupby([FRAME, SOLVENT]).count().unstack(fill_value=0) + ) + size_counts.columns = size_counts.columns.droplevel( + None + ) # the column value is None return size_counts - def _calculate_solute_status(self): + def _calculate_solute_status(self) -> tuple[pd.Series, pd.DataFrame]: """ This utility calculates the fraction of each solute with a given "status". Namely, whether the solvent is "isolated", "paired" (with a single solvent), or @@ -176,13 +205,15 @@ def _calculate_solute_status(self): status = self.network_sizes.iloc[:, 0:0] status[PAIRED] = self.network_sizes.iloc[:, 0:1].sum(axis=1).astype(int) status[NETWORKED] = self.network_sizes.iloc[:, 1:].sum(axis=1).astype(int) - status[ISOLATED] = self.n_solute - status.loc[:, [PAIRED, NETWORKED]].sum(axis=1) + status[ISOLATED] = self.n_solute - status.loc[:, [PAIRED, NETWORKED]].sum( + axis=1 + ) status = status.loc[:, [ISOLATED, PAIRED, NETWORKED]] solute_status_by_frame = status / self.n_solute solute_status = solute_status_by_frame.mean() return solute_status, solute_status_by_frame - def get_network_res_ix(self, network_index, frame): + def get_network_res_ix(self, network_index: int, frame: int) -> np.ndarray: """ Return the indexes of all residues in a selected network. 
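`_generate_networks` leans on `scipy.sparse.csgraph.connected_components` after symmetrizing the solute-solvent biadjacency matrix. A minimal sketch of that step, with a hypothetical 2-solute by 3-solvent block (1 = coordinated):

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

bi = np.array([[1, 1, 0],
               [0, 0, 1]])
n_solute, n_solvent = bi.shape

# embed the biadjacency block in a full symmetric adjacency matrix,
# analogous to what _unwrap_adjacency_dataframe does with the dataframe
adj = np.zeros((n_solute + n_solvent,) * 2)
adj[:n_solute, n_solute:] = bi
undirected = adj + adj.T

_, labels = connected_components(
    csgraph=csr_matrix(undirected), directed=False, return_labels=True
)
# labels groups solute 0 with solvents 0 and 1, and solute 1 with solvent 2
```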
@@ -213,11 +244,13 @@ def get_network_res_ix(self, network_index, frame): """ - res_ix = self.network_df.loc[pd.IndexSlice[frame, network_index], SOLVENT_IX].values + res_ix = self.network_df.loc[ + pd.IndexSlice[frame, network_index], SOLVENT_IX + ].values return res_ix.astype(int) @property - def network_df(self): + def network_df(self) -> pd.DataFrame: """ The dataframe containing all networking data. the indices are the frame and network index, respectively. the columns are the solvent_name and res_ix. @@ -225,7 +258,7 @@ def network_df(self): return self._network_df @property - def network_sizes(self): + def network_sizes(self) -> pd.DataFrame: """ A dataframe of network sizes. the index is the frame. the column headers are network sizes, or the number of solutes + solvents in the network, so @@ -235,7 +268,7 @@ def network_sizes(self): return self._network_sizes @property - def solute_status(self): + def solute_status(self) -> dict[str, float]: """ A dictionary where the keys are the "status" of the solute and the values are the fraction of solute with that status, averaged over all frames. @@ -250,7 +283,7 @@ def solute_status(self): return self._solute_status @property - def solute_status_by_frame(self): + def solute_status_by_frame(self) -> pd.DataFrame: """ As described above, except organized into a dataframe where each row is a unique frame and the columns are "isolated", "paired", and "networked". diff --git a/solvation_analysis/pairing.py b/solvation_analysis/pairing.py index 3a93d91..42c9b9b 100644 --- a/solvation_analysis/pairing.py +++ b/solvation_analysis/pairing.py @@ -19,7 +19,14 @@ import pandas as pd import numpy as np -from solvation_analysis._column_names import * +import solvation_analysis +from solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + DISTANCE, +) class Pairing: @@ -56,17 +63,27 @@ class Pairing: {'BN': 1.0, 'FEC': 0.210, 'PF6': 0.120} """ - def __init__(self, solvation_data, n_frames, n_solutes, n_solvents): + def __init__( + self, + solvation_data: pd.DataFrame, + n_frames: int, + n_solutes: int, + n_solvents: dict[str, int], + ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames self.n_solutes = n_solutes self.solvent_counts = n_solvents self._solvent_pairing, self._pairing_by_frame = self._fraction_coordinated() self._fraction_free_solvents = self._fraction_free_solvent() - self._diluent_composition, self._diluent_composition_by_frame, self._diluent_counts = self._diluent_composition() + ( + self._diluent_composition, + self._diluent_composition_by_frame, + self._diluent_counts, + ) = self._diluent_composition() @staticmethod - def from_solute(solute): + def from_solute(solute: "solvation_analysis.Solute") -> "Pairing": """ Generate a Pairing object from a solute. 
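A hedged usage sketch: `Pairing` is normally constructed for you when `Solute.run()` executes, but `from_solute` also allows building it directly from a finished solute (the `solute` object here is assumed to have already been run):

```python
from solvation_analysis.pairing import Pairing

pairing = Pairing.from_solute(solute)   # `solute` is a hypothetical, already-run Solute
pairing.solvent_pairing                 # e.g. {'BN': 1.0, 'FEC': 0.210, 'PF6': 0.120}
pairing.fraction_free_solvents          # fraction of each solvent not coordinated
pairing.diluent_composition             # composition of the uncoordinated bulk
```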
@@ -83,12 +100,14 @@ def from_solute(solute): solute.solvation_data, solute.n_frames, solute.n_solutes, - solute.solvent_counts + solute.solvent_counts, ) - def _fraction_coordinated(self): + def _fraction_coordinated(self) -> tuple[dict[str, float], pd.DataFrame]: # calculate the fraction of solute coordinated with each solvent - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] pairing_series = counts.astype(bool).groupby([SOLVENT, FRAME]).sum() / ( self.n_solutes ) # mean coordinated overall @@ -97,16 +116,24 @@ def _fraction_coordinated(self): pairing_dict = pairing_normalized.groupby([SOLVENT]).sum().to_dict() return pairing_dict, pairing_by_frame - def _fraction_free_solvent(self): + def _fraction_free_solvent(self) -> dict[str, float]: # calculate the fraction of each solvent NOT coordinated with the solute - counts = self.solvation_data.groupby([FRAME, SOLVENT_IX, SOLVENT]).count()[DISTANCE] + counts = self.solvation_data.groupby([FRAME, SOLVENT_IX, SOLVENT]).count()[ + DISTANCE + ] totals = counts.groupby([SOLVENT]).count() / self.n_frames - n_solvents = np.array([self.solvent_counts[name] for name in totals.index.values]) + n_solvents = np.array( + [self.solvent_counts[name] for name in totals.index.values] + ) free_solvents = np.ones(len(totals)) - totals / n_solvents return free_solvents.to_dict() - def _diluent_composition(self): - coordinated_solvents = self.solvation_data.groupby([FRAME, SOLVENT]).nunique()[SOLVENT_IX] + def _diluent_composition( + self, + ) -> tuple[dict[str, float], pd.DataFrame, pd.DataFrame]: + coordinated_solvents = self.solvation_data.groupby([FRAME, SOLVENT]).nunique()[ + SOLVENT_IX + ] solvent_counts = pd.Series(self.solvent_counts) total_solvents = solvent_counts.reindex(coordinated_solvents.index, level=1) diluent_solvents = total_solvents - coordinated_solvents @@ -117,7 +144,7 @@ def _diluent_composition(self): return diluent_dict, diluent_by_frame, diluent_counts @property - def solvent_pairing(self): + def solvent_pairing(self) -> dict[str, float]: """ A dictionary where keys are residue names (str) and values are the fraction of solutes that contain that residue (float). @@ -125,14 +152,14 @@ def solvent_pairing(self): return self._solvent_pairing @property - def pairing_by_frame(self): + def pairing_by_frame(self) -> pd.DataFrame: """ A pd.Dataframe tracking the mean fraction of each residue across frames. """ return self._pairing_by_frame @property - def fraction_free_solvents(self): + def fraction_free_solvents(self) -> dict[str, float]: """ A dictionary containing the fraction of each solvent that is free. e.g. not coordinated to a solute. @@ -140,7 +167,7 @@ def fraction_free_solvents(self): return self._fraction_free_solvents @property - def diluent_composition(self): + def diluent_composition(self) -> dict[str, float]: """ The fraction of the diluent constituted by each solvent. The diluent is defined as everything that is not coordinated with the solute. @@ -148,14 +175,14 @@ def diluent_composition(self): return self._diluent_composition @property - def diluent_composition_by_frame(self): + def diluent_composition_by_frame(self) -> pd.DataFrame: """ A DataFrame of the diluent composition in each frame of the trajectory. 
""" return self._diluent_composition_by_frame @property - def diluent_counts(self): + def diluent_counts(self) -> pd.DataFrame: """ A DataFrame of the raw solvent counts in the diluent in each frame of the trajectory. """ diff --git a/solvation_analysis/plotting.py b/solvation_analysis/plotting.py index 6c1788a..22aa587 100644 --- a/solvation_analysis/plotting.py +++ b/solvation_analysis/plotting.py @@ -10,19 +10,20 @@ as their input and generating a Plotly.Figure object. """ -import plotly -import plotly.graph_objects as go -import plotly.express as px -import matplotlib +from typing import Union, Optional, Any, Callable from copy import deepcopy -from solvation_analysis.solute import Solute -import numpy as np +import plotly.graph_objects as go +import plotly.express as px import pandas as pd +from solvation_analysis.solute import Solute +from solvation_analysis.networking import Networking +from solvation_analysis.speciation import Speciation + # single solution -def plot_network_size_histogram(networking): +def plot_network_size_histogram(networking: Union[Networking, Solute]) -> go.Figure: """ Plot a histogram of network sizes. @@ -37,7 +38,7 @@ def plot_network_size_histogram(networking): """ if isinstance(networking, Solute): if not hasattr(networking, "networking"): - raise ValueError(f"Solute networking analysis class must be instantiated.") + raise ValueError("Solute networking analysis class must be instantiated.") networking = networking.networking network_sizes = networking.network_sizes sums = network_sizes.sum(axis=0) @@ -54,7 +55,7 @@ def plot_network_size_histogram(networking): return fig -def plot_shell_composition_by_size(speciation): +def plot_shell_composition_by_size(speciation: Union[Speciation, Solute]) -> go.Figure: """ Plot the composition of shells broken down by shell size. @@ -69,7 +70,7 @@ def plot_shell_composition_by_size(speciation): """ if isinstance(speciation, Solute): if not hasattr(speciation, "speciation"): - raise ValueError(f"Solute speciation analysis class must be instantiated.") + raise ValueError("Solute speciation analysis class must be instantiated.") speciation = speciation.speciation speciation_data = speciation.speciation_data.copy() speciation_data["total"] = speciation_data.sum(axis=1) @@ -90,11 +91,13 @@ def plot_shell_composition_by_size(speciation): return fig -def plot_co_occurrence(speciation, colorscale=None): +def plot_co_occurrence( + speciation: Union[Speciation, Solute], colorscale: Optional[Any] = None +) -> go.Figure: """ Plot the co-occurrence matrix of the solute using Plotly. - Co-occurrence represents the extent to which solvents occur with eachother + Co-occurrence represents the extent to which solvents occur with each other relative to random. Values higher than 1 mean that solvents occur together more often than random and values lower than 1 mean solvents occur together less often than random. 
"Random" is calculated based on the total number of @@ -111,7 +114,7 @@ def plot_co_occurrence(speciation, colorscale=None): """ if isinstance(speciation, Solute): if not hasattr(speciation, "speciation"): - raise ValueError(f"Solute speciation analysis class must be instantiated.") + raise ValueError("Solute speciation analysis class must be instantiated.") speciation = speciation.speciation solvent_names = speciation.speciation_data.columns.values @@ -124,9 +127,9 @@ def plot_co_occurrence(speciation, colorscale=None): range_val = max_val - min_val colorscale = [ - [0, 'rgb(67,147,195)'], + [0, "rgb(67,147,195)"], [(1 - min_val) / range_val, "white"], - [1, 'rgb(214,96,77)'] + [1, "rgb(214,96,77)"], ] # Create a heatmap trace with text annotations @@ -137,24 +140,24 @@ def plot_co_occurrence(speciation, colorscale=None): text=speciation.solvent_co_occurrence.round(2).to_numpy(dtype=str), # Keep the text annotations in the original order hoverinfo="none", - colorscale=colorscale + colorscale=colorscale, ) # Update layout to display tick labels and text annotations layout = go.Layout( title="Solvent Co-Occurrence Matrix", xaxis=dict( - tickmode='array', + tickmode="array", tickvals=list(range(len(solvent_names))), ticktext=solvent_names, tickangle=-30, - side='top' + side="top", ), yaxis=dict( - tickmode='array', + tickmode="array", tickvals=list(range(len(solvent_names))), ticktext=solvent_names, - autorange='reversed' + autorange="reversed", ), margin=dict(l=60, r=60, b=60, t=100, pad=4), annotations=[ @@ -163,7 +166,7 @@ def plot_co_occurrence(speciation, colorscale=None): y=j, text=str(round(speciation.solvent_co_occurrence.iloc[j, i], 2)), font=dict(size=14, color="black"), - showarrow=False + showarrow=False, ) for i in range(len(solvent_names)) for j in range(len(solvent_names)) @@ -176,13 +179,13 @@ def plot_co_occurrence(speciation, colorscale=None): def compare_solvent_dicts( - property_dict, - rename_solvent_dict, - solvents_to_plot, - legend_label, - x_axis="solvent", - series=False, -): + property_dict: dict[str, dict[str, float]], + rename_solvent_dict: dict[str, str], + solvents_to_plot: list[str], + legend_label: str, + x_axis: str = "solvent", + series: bool = False, +) -> go.Figure: """ A generic plotting function that can compare dictionary data between multiple solutes. @@ -226,7 +229,6 @@ def compare_solvent_dicts( set(solution_dict.keys()) for solution_dict in property_dict.values() ] valid_solvents = set.intersection(*all_solvents) - invalid_solvents = set.union(*all_solvents) - valid_solvents if not set(solvents_to_plot).issubset(valid_solvents): raise Exception( f"solvents_to_plot must only include solvents that are " @@ -282,11 +284,11 @@ def compare_solvent_dicts( def _compare_function_generator( - analysis_object, - attribute, - title, - top_level_docstring, -): + analysis_object: str, + attribute: str, + title: str, + top_level_docstring: str, +) -> Callable: def compare_func( solutions, rename_solvent_dict=None, @@ -329,7 +331,7 @@ def compare_func( return fig arguments_docstring = """ - + property_dict : dict of {str: dict} a dictionary with the solution name as keys and a dict of {str: float} as values, where each key is the name of the solvent of each solution and each value is the property of interest diff --git a/solvation_analysis/rdf_parser.py b/solvation_analysis/rdf_parser.py index 962d65b..e530899 100644 --- a/solvation_analysis/rdf_parser.py +++ b/solvation_analysis/rdf_parser.py @@ -10,6 +10,8 @@ from an RDF. 
""" +from typing import Any, Optional, Union + import numpy as np from scipy.interpolate import UnivariateSpline import scipy @@ -17,10 +19,10 @@ import warnings from scipy.signal import find_peaks, gaussian -from solvation_analysis._column_names import * - -def interpolate_rdf(bins, rdf, floor=0.05, cutoff=5): +def interpolate_rdf( + bins: np.ndarray, rdf: np.ndarray, floor: float = 0.05, cutoff: float = 5 +) -> tuple[UnivariateSpline, tuple[float, float]]: """ Fits a sciply.interpolate.UnivariateSpline to the starting region of the RDF. The floor and cutoff control the region of the RDF that the @@ -49,7 +51,7 @@ def interpolate_rdf(bins, rdf, floor=0.05, cutoff=5): return f, bounds -def identify_minima(f): +def identify_minima(f: UnivariateSpline) -> tuple[np.ndarray, np.ndarray]: """ Identifies the extrema of a interpolated polynomial. @@ -75,7 +77,9 @@ def identify_minima(f): return cr_pts, cr_vals -def plot_interpolation_fit(bins, rdf, **kwargs): +def plot_interpolation_fit( + bins: np.ndarray, rdf: np.ndarray, **kwargs: Any +) -> tuple[plt.Figure, plt.Axes]: """ Calls interpolate_rdf and identify_minima to identify the extrema of an RDF. Plots the original rdf, the interpolated spline, and the extrema of the @@ -109,7 +113,9 @@ def plot_interpolation_fit(bins, rdf, **kwargs): return fig, ax -def good_cutoff(cutoff_region, cr_pts, cr_vals): +def good_cutoff( + cutoff_region: tuple[float, float], cr_pts: np.ndarray, cr_vals: np.ndarray +) -> bool: """ Uses several heuristics to determine if the a solvation cutoff is valid solvation cutoff. This fails if there is no solvation shell. @@ -139,13 +145,20 @@ def good_cutoff(cutoff_region, cr_pts, cr_vals): return True -def good_cutoff_scipy(cutoff_region, min_trough_depth, peaks, troughs, rdf, bins): +def good_cutoff_scipy( + cutoff_region: tuple[float, float], + min_trough_depth: float, + peaks: np.ndarray, + troughs: np.ndarray, + rdf: np.ndarray, + bins: np.ndarray, +) -> bool: """ Uses several heuristics to determine if the solvation cutoff is valid solvation cutoff. This fails if there is no solvation shell. Heuristics: - - trough follows peak + - troughs follows peaks - in `Solute.cutoff_region` (specified by kwarg) - normalized peak height > 0.05 @@ -154,7 +167,7 @@ def good_cutoff_scipy(cutoff_region, min_trough_depth, peaks, troughs, rdf, bins cutoff_region : tuple boundaries in which to search for a solvation shell cutoff, i.e. 
(1.5, 4) min_trough_depth : float - the minimum depth of a trough to be considered a valid solvation cutoff + the minimum depth to be considered a valid solvation cutoff peaks : np.array the indices of the peaks in the bins array troughs : np.array @@ -171,16 +184,22 @@ def good_cutoff_scipy(cutoff_region, min_trough_depth, peaks, troughs, rdf, bins # normalize rdf norm_rdf = rdf / np.max(rdf) if ( - len(peaks) == 0 or len(troughs) == 0 # insufficient critical points + len(peaks) == 0 + or len(troughs) == 0 # insufficient critical points or troughs[0] < peaks[0] # not a min and max - or not (cutoff_region[0] < bins[troughs[0]] < cutoff_region[1]) # min not in cutoff - or abs(norm_rdf[peaks[0]] - norm_rdf[troughs[0]]) < min_trough_depth # peak too small + or not ( + cutoff_region[0] < bins[troughs[0]] < cutoff_region[1] + ) # min not in cutoff + or abs(norm_rdf[peaks[0]] - norm_rdf[troughs[0]]) + < min_trough_depth # peak too small ): return False return True -def scipy_find_peaks_troughs(bins, rdf, return_rdf=False, **kwargs): +def scipy_find_peaks_troughs( + bins: np.ndarray, rdf: np.ndarray, return_rdf: bool = False, **kwargs: Any +) -> Union[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray, np.ndarray]]: """ Finds the indices of the peaks and troughs of an RDF. @@ -223,14 +242,14 @@ def scipy_find_peaks_troughs(bins, rdf, return_rdf=False, **kwargs): def identify_cutoff_scipy( - bins, - rdf, - cutoff_region=(1.5, 4), - failure_behavior="warn", - min_trough_depth=0.02, - default=None, - **kwargs -): + bins: np.ndarray, + rdf: np.ndarray, + cutoff_region: tuple[float, float] = (1.5, 4), + failure_behavior: str = "warn", + min_trough_depth: float = 0.02, + default: Optional[float] = None, + **kwargs: Any, +) -> Optional[float]: """ Identifies the solvation cutoff of an RDF. @@ -252,7 +271,7 @@ def identify_cutoff_scipy( default : float, optional the value to return if no solvation shell is found min_trough_depth : float - the minimum depth of a trough to be considered a valid solvation cutoff + the minimum depth of troughs to be considered a valid solvation cutoff kwargs : passed to the scipy.find_peaks function Returns @@ -261,25 +280,29 @@ def identify_cutoff_scipy( the solvation cutoff of the RDF """ peaks, troughs = scipy_find_peaks_troughs(bins, rdf, **kwargs) - if not good_cutoff_scipy(cutoff_region, min_trough_depth, peaks, troughs, rdf, bins): + if not good_cutoff_scipy( + cutoff_region, min_trough_depth, peaks, troughs, rdf, bins + ): if failure_behavior == "silent": return default if failure_behavior == "warn": warnings.warn("No solvation shell detected.") return default if failure_behavior == "exception": - raise RuntimeError("Solute could not identify a solvation radius for at least one solvent. " - "Please enter the missing radii manually by adding them to the radii dict" - "and rerun the analysis.") + raise RuntimeError( + "Solute could not identify a solvation radius for at least one solvent. " + "Please enter the missing radii manually by adding them to the radii dict" + "and rerun the analysis." + ) cutoff = bins[troughs[0]] return cutoff def plot_scipy_find_peaks_troughs( - bins, - rdf, - **kwargs, -): + bins: np.ndarray, + rdf: np.ndarray, + **kwargs: Any, +) -> tuple[plt.Figure, plt.Axes]: """ Plot the original and smoothed RDF with the peaks and troughs located. 
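To see the scipy-based cutoff detection end to end, a sketch on a synthetic RDF. The curve is hypothetical; if the smoothed peak/trough heuristics reject its shape, the function warns and returns `default` instead of a cutoff:

```python
import numpy as np
from solvation_analysis.rdf_parser import identify_cutoff_scipy

bins = np.linspace(0, 8, 200)
# synthetic RDF: first solvation peak near r = 2, trough near r = 3.2
rdf = 1 + 2.0 * np.exp(-((bins - 2.0) ** 2)) - 0.8 * np.exp(-((bins - 3.2) ** 2) / 0.25)

cutoff = identify_cutoff_scipy(bins, rdf, cutoff_region=(1.5, 4), failure_behavior="warn")
```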
@@ -300,7 +323,9 @@ def plot_scipy_find_peaks_troughs( fig, ax : matplotlib pyplot Figure and Axis for the fit """ - peaks, troughs, smooth_rdf = scipy_find_peaks_troughs(bins, rdf, return_rdf=True, **kwargs) + peaks, troughs, smooth_rdf = scipy_find_peaks_troughs( + bins, rdf, return_rdf=True, **kwargs + ) fig, ax = plt.subplots() ax.plot(bins, rdf, "b--", label="rdf") ax.plot(bins, smooth_rdf, "g-", label="smooth_rdf") @@ -314,8 +339,13 @@ def plot_scipy_find_peaks_troughs( def identify_cutoff_poly( - bins, rdf, failure_behavior="warn", cutoff_region=(1.5, 4), floor=0.05, cutoff=5 -): + bins: np.ndarray, + rdf: np.ndarray, + failure_behavior: str = "warn", + cutoff_region: tuple[float, float] = (1.5, 4), + floor: float = 0.05, + cutoff: float = 5, +) -> float: """ Identifies the solvation cutoff of an RDF using a polynomial interpolation. @@ -349,7 +379,9 @@ def identify_cutoff_poly( warnings.warn("No solvation shell detected.") return np.NaN if failure_behavior == "exception": - raise RuntimeError("Solute could not identify a solvation radius for at least one solvent. " "Please enter the missing radii manually by adding them to the radii dict" "and rerun the analysis.") + raise RuntimeError( + "Solute could not identify a solvation radius for at least one solvent. " + "Please enter the missing radii manually by adding them to the radii dict " + "and rerun the analysis." + ) return cr_pts[1] diff --git a/solvation_analysis/residence.py b/solvation_analysis/residence.py index 1ee4bee..04824ab 100644 --- a/solvation_analysis/residence.py +++ b/solvation_analysis/residence.py @@ -15,6 +15,7 @@ as an attribute of the Solute class. This makes instantiating it and calculating the solvation data a non-issue. """ + import math import warnings @@ -24,7 +25,13 @@ from statsmodels.tsa.stattools import acovf from scipy.optimize import curve_fit -from solvation_analysis._column_names import * +import solvation_analysis +from solvation_analysis._column_names import ( + SOLVENT, + SOLUTE_ATOM_IX, + SOLVENT_ATOM_IX, + SOLUTE_IX, +) from solvation_analysis._utils import calculate_adjacency_dataframe @@ -83,17 +90,18 @@ class Residence: {'BN': 4.02, 'FEC': 3.79, 'PF6': 1.15} """ - def __init__(self, solvation_data, step): + def __init__(self, solvation_data: pd.DataFrame, step: int) -> None: self.solvation_data = solvation_data self._auto_covariances = self._calculate_auto_covariance_dict() - self._residence_times_cutoff = self._calculate_residence_times_with_cutoff(self._auto_covariances, step) - self._residence_times_fit, self._fit_parameters = self._calculate_residence_times_with_fit( - self._auto_covariances, - step + self._residence_times_cutoff = self._calculate_residence_times_with_cutoff( + self._auto_covariances, step + ) + self._residence_times_fit, self._fit_parameters = ( + self._calculate_residence_times_with_fit(self._auto_covariances, step) ) @staticmethod - def from_solute(solute): + def from_solute(solute: "solvation_analysis.Solute") -> "Residence": """ Generate a Residence object from a solute.
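As a usage note for the class being annotated here, a hedged sketch; `solute` is a placeholder for any Solute on which .run() has already completed, and the printed values echo the docstring example above.

    from solvation_analysis.residence import Residence

    residence = Residence.from_solute(solute)  # requires solute.has_run
    print(residence.residence_times_cutoff)  # e.g. {'BN': 4.02, 'FEC': 3.79, 'PF6': 1.15}
    print(residence.residence_times_fit)  # exponential-fit estimates, same keys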
@@ -106,15 +114,14 @@ def from_solute(solute): Residence """ assert solute.has_run, "The solute must be run before calling from_solute" - return Residence( - solute.solvation_data, - solute.step - ) + return Residence(solute.solvation_data, solute.step) - def _calculate_auto_covariance_dict(self): + def _calculate_auto_covariance_dict(self) -> dict[str, np.ndarray]: partial_index = self.solvation_data.index.droplevel(SOLVENT_ATOM_IX) unique_indices = np.unique(partial_index) - frame_solute_index = pd.MultiIndex.from_tuples(unique_indices, names=partial_index.names) + frame_solute_index = pd.MultiIndex.from_tuples( + unique_indices, names=partial_index.names + ) auto_covariance_dict = {} for res_name, res_solvation_data in self.solvation_data.groupby([SOLVENT]): if isinstance(res_name, tuple): @@ -128,14 +135,20 @@ def _calculate_auto_covariance_dict(self): return auto_covariance_dict @staticmethod - def _calculate_residence_times_with_cutoff(auto_covariances, step, convergence_cutoff=0.1): + def _calculate_residence_times_with_cutoff( + auto_covariances: dict[str, np.ndarray], + step: int, + convergence_cutoff: float = 0.1, + ) -> dict[str, float]: residence_times = {} for res_name, auto_covariance in auto_covariances.items(): if np.min(auto_covariance) > convergence_cutoff: residence_times[res_name] = np.nan - warnings.warn(f'the autocovariance for {res_name} does not converge to zero ' - 'so a residence time cannot be calculated. A longer simulation ' - 'is required to get a valid estimate of the residence time.') + warnings.warn( + f"the autocovariance for {res_name} does not converge to zero " + "so a residence time cannot be calculated. A longer simulation " + "is required to get a valid estimate of the residence time." + ) unassigned = True for frame, val in enumerate(auto_covariance): if val < 1 / math.e: @@ -147,16 +160,21 @@ def _calculate_residence_times_with_cutoff(auto_covariances, step, convergence_c return residence_times @staticmethod - def _calculate_residence_times_with_fit(auto_covariances, step): + def _calculate_residence_times_with_fit( + auto_covariances: dict[str, np.ndarray], step: int + ) -> tuple[dict[str, float], dict[str, tuple[float, float, float]]]: # calculate the residence times residence_times = {} fit_parameters = {} for res_name, auto_covariance in auto_covariances.items(): res_time, params = Residence._fit_exponential(auto_covariance, res_name) - residence_times[res_name], fit_parameters[res_name] = round(res_time * step, 2), params + residence_times[res_name], fit_parameters[res_name] = ( + round(res_time * step, 2), + params, + ) return residence_times, fit_parameters - def plot_auto_covariance(self, res_name): + def plot_auto_covariance(self, res_name: str) -> tuple[plt.Figure, plt.Axes]: """ Plot the autocovariance of a solvent on the solute. 
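A minimal, self-contained sketch of the 1/e rule implemented in _calculate_residence_times_with_cutoff above; the decaying autocovariance is invented and already normalized, and `step` converts analysis frames back to trajectory frames.

    import math

    import numpy as np

    step = 10
    auto_covariance = np.exp(-np.arange(50) / 7.0)  # invented decay, starts at 1.0

    if np.min(auto_covariance) > 0.1:  # the convergence_cutoff guard above
        residence_time = np.nan  # autocovariance never converges to zero
    else:
        first_below = np.argmax(auto_covariance < 1 / math.e)  # first frame under 1/e
        residence_time = first_below * step  # 8 * 10 = 80 for this synthetic decay

    print(residence_time)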
@@ -175,16 +193,21 @@ def plot_auto_covariance(self, res_name): auto_covariance = self.auto_covariances[res_name] frames = np.arange(len(auto_covariance)) params = self.fit_parameters[res_name] - exp_func = lambda x: self._exponential_decay(x, *params) + + def exp_func(x): + return self._exponential_decay(x, *params) + exp_fit = np.array(list(map(exp_func, frames))) fig, ax = plt.subplots() ax.plot(frames, auto_covariance, "b-", label="auto covariance") try: ax.scatter(frames, exp_fit, label="exponential fit") - except: - warnings.warn(f'The fit for {res_name} failed so the exponential ' f'fit will not be plotted.') - ax.hlines(y=1/math.e, xmin=frames[0], xmax=frames[-1], label='1/e cutoff') + except RuntimeError: + warnings.warn( + f"The fit for {res_name} failed so the exponential " + f"fit will not be plotted." + ) + ax.hlines(y=1 / math.e, xmin=frames[0], xmax=frames[-1], label="1/e cutoff") ax.set_xlabel("Timestep (frames)") ax.set_ylabel("Normalized Autocovariance") ax.set_ylim(0, 1) @@ -192,7 +215,7 @@ def plot_auto_covariance(self, res_name): return fig, ax @staticmethod - def _exponential_decay(x, a, b, c): + def _exponential_decay(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray: """ An exponential decay function. @@ -208,7 +231,9 @@ def _exponential_decay(x, a, b, c): return a * np.exp(-b * x) + c @staticmethod - def _fit_exponential(auto_covariance, res_name): + def _fit_exponential( + auto_covariance: np.ndarray, res_name: str + ) -> tuple[float, tuple[float, float, float]]: auto_covariance_norm = auto_covariance / auto_covariance[0] try: params, param_covariance = curve_fit( @@ -219,28 +244,34 @@ def _fit_exponential(auto_covariance, res_name): ) tau = 1 / params[1] # p except RuntimeError: - warnings.warn(f'The fit for {res_name} failed so its values in' f'residence_time_fits and fit_parameters will be' f'set to np.nan.') + warnings.warn( + f"The fit for {res_name} failed so its values in " + f"residence_time_fits and fit_parameters will be " + f"set to np.nan." + ) tau, params = np.nan, (np.nan, np.nan, np.nan) return tau, params @staticmethod - def _calculate_auto_covariance(adjacency_matrix): + def _calculate_auto_covariance(adjacency_matrix: pd.DataFrame) -> np.ndarray: auto_covariances = [] timesteps = adjacency_matrix.index.levels[0] - for solute_ix, solute_df in adjacency_matrix.groupby([SOLUTE_IX, SOLUTE_ATOM_IX]): + for solute_ix, solute_df in adjacency_matrix.groupby( + [SOLUTE_IX, SOLUTE_ATOM_IX] + ): # this is needed to make sure auto-covariances can be concatenated later - new_solute_df = solute_df.droplevel([SOLUTE_IX, SOLUTE_ATOM_IX]).reindex(timesteps, fill_value=0) + new_solute_df = solute_df.droplevel([SOLUTE_IX, SOLUTE_ATOM_IX]).reindex( + timesteps, fill_value=0 + ) non_zero_cols = new_solute_df.loc[:, (solute_df != 0).any(axis=0)] auto_covariance_df = non_zero_cols.apply( acovf, axis=0, - result_type='expand', + result_type="expand", demean=False, adjusted=True, - fft=True + fft=True, ) # timesteps with no binding get skipped, so we need to make sure to include all timesteps auto_covariances.append(auto_covariance_df.values) @@ -249,7 +280,7 @@ def _calculate_auto_covariance(adjacency_matrix): return auto_covariance @property - def auto_covariances(self): + def auto_covariances(self) -> dict[str, np.ndarray]: """ A dictionary where keys are residue names and values are the autocovariance of that residue on the solute.
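To make the acovf call above concrete, a self-contained sketch on an invented 0/1 contact series; the keyword flags mirror the ones used in _calculate_auto_covariance.

    import numpy as np
    from statsmodels.tsa.stattools import acovf

    # invented solute-solvent adjacency history: 1 = in contact at that frame
    contacts = np.array([1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0], dtype=float)
    auto_cov = acovf(contacts, demean=False, adjusted=True, fft=True)
    print(auto_cov / auto_cov[0])  # normalized to 1 at lag zero, as in _fit_exponential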
@@ -257,7 +288,7 @@ def auto_covariances(self): return self._auto_covariances @property - def residence_times_cutoff(self): + def residence_times_cutoff(self) -> dict[str, float]: """ A dictionary where keys are residue names and values are the residence times of that residue on the solute, calculated @@ -266,7 +297,7 @@ def residence_times_cutoff(self): return self._residence_times_cutoff @property - def residence_times_fit(self): + def residence_times_fit(self) -> dict[str, float]: """ A dictionary where keys are residue names and values are the residence times of that residue on the solute, calculated @@ -275,9 +306,9 @@ def residence_times_fit(self): return self._residence_times_fit @property - def fit_parameters(self): + def fit_parameters(self) -> dict[str, tuple[float, float, float]]: """ A dictionary where keys are residue names and values are the - arameters for the exponential fit to the autocorrelation function. + parameters for the exponential fit to the autocorrelation function. """ return self._fit_parameters diff --git a/solvation_analysis/solute.py b/solvation_analysis/solute.py index 34ef874..0e71d62 100644 --- a/solvation_analysis/solute.py +++ b/solvation_analysis/solute.py @@ -105,8 +105,10 @@ solvation shell, returning an AtomGroup for visualization or further analysis. This is covered in the visualization tutorial. """ + from collections import defaultdict from functools import reduce +from typing import Any, Callable, Optional, Union import matplotlib.pyplot as plt import pandas as pd @@ -118,7 +120,12 @@ from MDAnalysis.lib.distances import capped_distance import numpy as np -from solvation_analysis._utils import verify_solute_atoms, verify_solute_atoms_dict, get_closest_n_mol, get_radial_shell +from solvation_analysis._utils import ( + verify_solute_atoms, + verify_solute_atoms_dict, + get_closest_n_mol, + get_radial_shell, +) from solvation_analysis.rdf_parser import identify_cutoff_scipy from solvation_analysis.coordination import Coordination from solvation_analysis.networking import Networking @@ -126,7 +133,25 @@ from solvation_analysis.residence import Residence from solvation_analysis.speciation import Speciation -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + DISTANCE, + SOLUTE, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + SOLVENT_ATOM_IX, + SOLUTE_ATOM_IX, +) + +try: + import rdkit + from rdkit.Chem import Draw + from rdkit.Chem import rdCoordGen + from rdkit.Chem.Draw.MolDrawing import DrawingOptions + +except ImportError: + rdkit = None class Solute(AnalysisBase): @@ -160,9 +185,10 @@ class Solute(AnalysisBase): rdf_kernel : function, optional this function must take RDF bins and data as input and return a solvation radius as output. e.g. rdf_kernel(bins, data) -> 3.2. By default, - the rdf_kernel is solvation_analysis.rdf_parser.identify_solvation_cutoff. + the rdf_kernel is `solvation_analysis.rdf_parser.identify_cutoff_scipy`. kernel_kwargs : dict, optional - kwargs passed to rdf_kernel + kwargs passed to rdf_kernel. See `identify_cutoff_scipy` for options. This can + be used to set a default fallback radius for all solvents. rdf_init_kwargs : dict, optional kwargs passed to the initialization of the MDAnalysis.InterRDF used to plot the solute-solvent RDFs. By default, ``range`` will be set to (0, 7.5).
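Picking up the kernel_kwargs note above, a hedged construction sketch; `u` and the selection strings are placeholders, and the keyword options are the identify_cutoff_scipy parameters shown earlier in this diff.

    from solvation_analysis.solute import Solute

    # placeholders: `u` is an MDAnalysis Universe containing Li and BN residues
    li_atoms = u.atoms.select_atoms("resname Li")
    bn_atoms = u.atoms.select_atoms("resname BN")

    solute = Solute.from_atoms(
        li_atoms,
        {"BN": bn_atoms},
        # fall back to a 3.0 radius instead of raising when no shell is found
        kernel_kwargs={"failure_behavior": "silent", "default": 3.0},
    )
    solute.run()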
@@ -235,29 +261,31 @@ class Solute(AnalysisBase): """ def __init__( - self, - solute_atoms, - solvents, - atom_solutes=None, - radii=None, - rdf_kernel=None, - kernel_kwargs=None, - rdf_init_kwargs=None, - rdf_run_kwargs=None, - skip_rdf=False, - solute_name="solute_0", - analysis_classes=None, - networking_solvents=None, - verbose=False, - internal_call=False, + self, + solute_atoms: mda.AtomGroup, + solvents: dict[str, mda.AtomGroup], + atom_solutes: Optional[dict[str, "Solute"]] = None, + radii: Optional[dict[str, float]] = None, + rdf_kernel: Optional[Callable[[np.ndarray, np.ndarray], float]] = None, + kernel_kwargs: Optional[dict[str, Any]] = None, + rdf_init_kwargs: Optional[dict[str, Any]] = None, + rdf_run_kwargs: Optional[dict[str, Any]] = None, + skip_rdf: bool = False, + solute_name: str = "solute_0", + analysis_classes: Optional[list[str]] = None, + networking_solvents: Optional[str] = None, + verbose: bool = False, + internal_call: bool = False, ): """ This method is not intended to be called directly. Instead, use ``from_atoms``, ``from_atoms_dict`` or ``from_solute_list`` to create a Solute. """ if not internal_call: - raise RuntimeError("Please use Solute.from_atoms, Solute.from_atoms_dict, or " - "Solute.from_solute_list instead of the default constructor.") + raise RuntimeError( + "Please use Solute.from_atoms, Solute.from_atoms_dict, or " + "Solute.from_solute_list instead of the default constructor." + ) super(Solute, self).__init__(solute_atoms.universe.trajectory, verbose=verbose) self.solute_atoms = solute_atoms # TODO: this shit! @@ -265,7 +293,9 @@ def __init__( if self.atom_solutes is None or len(atom_solutes) <= 1: self.atom_solutes = {solute_name: self} self.radii = radii or {} - self.solvent_counts = {name: atoms.n_residues for name, atoms in solvents.items()} + self.solvent_counts = { + name: atoms.n_residues for name, atoms in solvents.items() + } self.kernel = rdf_kernel or identify_cutoff_scipy self.kernel_kwargs = kernel_kwargs or {} self.rdf_init_kwargs = rdf_init_kwargs or {} @@ -273,17 +303,19 @@ def __init__( self.has_run = False self.u = solute_atoms.universe self.n_solutes = solute_atoms.n_residues - self.solute_res_ix = pd.Series(solute_atoms.atoms.resindices, solute_atoms.atoms.ix) + self.solute_res_ix = pd.Series( + solute_atoms.atoms.resindices, solute_atoms.atoms.ix + ) self.solute_name = solute_name self.solvents = solvents if skip_rdf: - assert set(self.radii.keys()) >= set(self.solvents.keys()), ( - "To skip RDF generation, all solvent radii must be specified." - ) + assert set(self.radii.keys()) >= set( + self.solvents.keys() + ), "To skip RDF generation, all solvent radii must be specified." self.skip_rdf = skip_rdf # instantiate the res_name_map - self.res_name_map = pd.Series(['none'] * len(self.u.residues)) + self.res_name_map = pd.Series(["none"] * len(self.u.residues)) self.res_name_map[solute_atoms.residues.ix] = self.solute_name for name, solvent in solvents.items(): self.res_name_map[solvent.residues.ix] = name @@ -291,8 +323,14 @@ def __init__( # instantiate analysis classes. 
if analysis_classes is None: self.analysis_classes = ["pairing", "coordination", "speciation"] - elif analysis_classes == 'all': - self.analysis_classes = ["pairing", "coordination", "speciation", "residence", "networking"] + elif analysis_classes == "all": + self.analysis_classes = [ + "pairing", + "coordination", + "speciation", + "residence", + "networking", + ] else: self.analysis_classes = [cls.lower() for cls in analysis_classes] if "networking" in self.analysis_classes and networking_solvents is None: @@ -303,7 +341,12 @@ def __init__( self.networking_solvents = networking_solvents @staticmethod - def from_atoms(solute_atoms, solvents, rename_solutes=None, **kwargs): + def from_atoms( + solute_atoms: mda.AtomGroup, + solvents: dict[str, mda.AtomGroup], + rename_solutes: Optional[dict[str, str]] = None, + **kwargs: Any, + ) -> "Solute": """ Create a Solute from a single AtomGroup. The solute_atoms AtomGroup must contain identical residues and identical atoms on each residue. @@ -343,10 +386,16 @@ def from_atoms(solute_atoms, solvents, rename_solutes=None, **kwargs): rename_solutes.get(f"solute_{i}") or f"solute_{i}": atom_group for i, atom_group in solute_atom_group_dict.items() } - return Solute.from_atoms_dict(solute_atom_group_dict_renamed, solvents, **kwargs) + return Solute.from_atoms_dict( + solute_atom_group_dict_renamed, solvents, **kwargs + ) @staticmethod - def from_atoms_dict(solute_atoms_dict, solvents, **kwargs): + def from_atoms_dict( + solute_atoms_dict: dict[str, mda.AtomGroup], + solvents: dict[str, mda.AtomGroup], + **kwargs: Any, + ) -> "Solute": """ Create a Solute object from a dictionary of solute atoms. @@ -368,7 +417,7 @@ def from_atoms_dict(solute_atoms_dict, solvents, **kwargs): """ # all solute AtomGroups in one AtomGroup + verification - assert isinstance(solute_atoms_dict, dict), ("Solute_atoms_dict must be a dict.") + assert isinstance(solute_atoms_dict, dict), "solute_atoms_dict must be a dict." solute_atom_group = verify_solute_atoms_dict(solute_atoms_dict) # create the solutes for each atom @@ -378,7 +427,7 @@ def from_atoms_dict(solute_atoms_dict, solvents, **kwargs): atoms, solvents, internal_call=True, - **{**kwargs, "solute_name": solute_name} + **{**kwargs, "solute_name": solute_name}, ) # create the solute for the whole solute solute = Solute( @@ -386,14 +435,16 @@ def from_atoms_dict(solute_atoms_dict, solvents, **kwargs): solvents, atom_solutes=atom_solutes, internal_call=True, - **kwargs + **kwargs, ) if len(atom_solutes) > 1: - solute.run = solute._run_solute_atoms + solute.run = solute._run_solute_atoms return solute @staticmethod - def from_solute_list(solutes, solvents, **kwargs): + def from_solute_list( + solutes: list["Solute"], solvents: dict[str, mda.AtomGroup], **kwargs: Any + ) -> "Solute": """ Create a Solute from a list of Solutes. All Solutes must have only a single solute atom on each solute residue. Essentially, from_solute_list @@ -420,15 +471,17 @@ def from_solute_list(solutes, solvents, **kwargs): # check types and name uniqueness for solute in solutes: assert type(solute) == Solute, "solutes must be a list of Solute objects." - assert len(solute.solute_atoms.atoms) == len(solute.solute_atoms.atoms.residues), ( - "Each Solute in solutes must have only a single atom per residue." - ) + assert len(solute.solute_atoms.atoms) == len( + solute.solute_atoms.atoms.residues + ), "Each Solute in solutes must have only a single atom per residue."
solute_names = [solute.solute_name for solute in solutes] - assert len(np.unique(solute_names)) == len(solute_names), ( - "The solute_name for each solute must be unique." - ) + assert len(np.unique(solute_names)) == len( + solute_names + ), "The solute_name for each solute must be unique." - solute_atom_group = reduce(lambda x, y: x | y, [solute.solute_atoms for solute in solutes]) + solute_atom_group = reduce( + lambda x, y: x | y, [solute.solute_atoms for solute in solutes] + ) verify_solute_atoms(solute_atom_group) atom_solutes = {solute.solute_name: solute for solute in solutes} @@ -437,13 +490,19 @@ def from_solute_list(solutes, solvents, **kwargs): solvents, atom_solutes=atom_solutes, internal_call=True, - **kwargs + **kwargs, ) if len(atom_solutes) > 1: - solute.run = solute._run_solute_atoms + solute.run = solute._run_solute_atoms return solute - def _run_solute_atoms(self, start=None, stop=None, step=None, verbose=None): + def _run_solute_atoms( + self, + start: Optional[int] = None, + stop: Optional[int] = None, + step: Optional[int] = None, + verbose: Optional[bool] = None, + ): # like prepare atom_solutes = {} rdf_data = {} @@ -458,13 +517,17 @@ def _run_solute_atoms(self, start=None, stop=None, step=None, verbose=None): if not solute.has_run: solute.run(start=start, stop=stop, step=step, verbose=verbose) if (start, stop, step) != (solute.start, solute.stop, solute.step): - warnings.warn(f"The start, stop, or step for {solute.solute_name} do not" f"match the start, stop, or step for the run command so it " f"is being re-run.") + warnings.warn( + f"The start, stop, or step for {solute.solute_name} do not " + f"match the start, stop, or step for the run command so it " + f"is being re-run." + ) solute.run(start=start, stop=stop, step=step, verbose=verbose) if self.solvents != solute.solvents: - warnings.warn(f"The solvents for {solute.solute_name} do not match the " f"solvents for the run command so it is being re-run.") + warnings.warn( + f"The solvents for {solute.solute_name} do not match the " + f"solvents for the run command so it is being re-run."
+ ) solute.run(start=start, stop=stop, step=step, verbose=verbose) atom_solutes[solute.solute_name] = solute rdf_data[solute.solute_name] = solute.rdf_data[solute.solute_name] @@ -478,17 +541,25 @@ def _run_solute_atoms(self, start=None, stop=None, step=None, verbose=None): # like conclude analysis_classes = { - 'speciation': Speciation, - 'pairing': Pairing, - 'coordination': Coordination, - 'residence': Residence, - 'networking': Networking, + "speciation": Speciation, + "pairing": Pairing, + "coordination": Coordination, + "residence": Residence, + "networking": Networking, } for analysis_class in self.analysis_classes: - if analysis_class == 'networking': - setattr(self, 'networking', Networking.from_solute(self, self.networking_solvents)) + if analysis_class == "networking": + setattr( + self, + "networking", + Networking.from_solute(self, self.networking_solvents), + ) else: - setattr(self, analysis_class, analysis_classes[analysis_class].from_solute(self)) + setattr( + self, + analysis_class, + analysis_classes[analysis_class].from_solute(self), + ) def _prepare(self): """ @@ -503,14 +574,22 @@ def _prepare(self): self.rdf_data = None break # set kwargs with defaults - self.rdf_init_kwargs["range"] = self.rdf_init_kwargs.get("range") or (0, 7.5) + self.rdf_init_kwargs["range"] = self.rdf_init_kwargs.get("range") or ( + 0, + 7.5, + ) self.rdf_init_kwargs["norm"] = self.rdf_init_kwargs.get("norm") or "density" self.rdf_run_kwargs["stop"] = self.rdf_run_kwargs.get("stop") or self.stop self.rdf_run_kwargs["step"] = self.rdf_run_kwargs.get("step") or self.step - self.rdf_run_kwargs["start"] = self.rdf_run_kwargs.get("start") or self.start + self.rdf_run_kwargs["start"] = ( + self.rdf_run_kwargs.get("start") or self.start + ) # generate and save RDFs rdf = InterRDF( - self.solute_atoms, solvent, **self.rdf_init_kwargs, exclude_same="residue" + self.solute_atoms, + solvent, + **self.rdf_init_kwargs, + exclude_same="residue", ) rdf.run(**self.rdf_run_kwargs) bins, data = rdf.results.bins, rdf.results.rdf @@ -518,10 +597,11 @@ def _prepare(self): # calculate the solvation radius if it was not provided if name not in self.radii.keys(): self.radii[name] = self.kernel(bins, data, **self.kernel_kwargs) - calculated_radii = set([name for name, radius in self.radii.items() - if not np.isnan(radius)]) + calculated_radii = set( + [name for name, radius in self.radii.items() if not np.isnan(radius)] + ) missing_solvents = set(self.solvents.keys()) - calculated_radii - missing_solvents_str = ' '.join([str(i) for i in missing_solvents]) + missing_solvents_str = " ".join([str(i) for i in missing_solvents]) assert len(missing_solvents) == 0, ( f"Solute could not identify a solvation radius for " f"{missing_solvents_str}.
Please manually enter missing radii " @@ -545,7 +625,10 @@ def _single_frame(self): box=self.u.dimensions, ) # make sure pairs don't include intra-molecular interactions - filter = self.solute_atoms.resindices[pairs[:, 0]] == solvent.resindices[pairs[:, 1]] + filter = ( + self.solute_atoms.resindices[pairs[:, 0]] + == solvent.resindices[pairs[:, 1]] + ) pairs = pairs[~filter] dist = dist[~filter] # replace local ids with absolute ids @@ -573,7 +656,7 @@ def _single_frame(self): dist_array, solute_res_name_array, solvent_res_name_array, - solvent_res_ix_array + solvent_res_ix_array, ) ) # add the current frame to the growing list of solvation arrays @@ -595,35 +678,60 @@ def _conclude(self): DISTANCE, SOLUTE, SOLVENT, - SOLVENT_IX - ] + SOLVENT_IX, + ], ) # clean up solvation_data df - for column in [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX, DISTANCE, SOLVENT_IX]: + for column in [ + FRAME, + SOLUTE_IX, + SOLUTE_ATOM_IX, + SOLVENT_ATOM_IX, + DISTANCE, + SOLVENT_IX, + ]: solvation_data_df[column] = pd.to_numeric(solvation_data_df[column]) - solvation_data_df = solvation_data_df.sort_values([FRAME, SOLUTE_ATOM_IX, DISTANCE]) - solvation_data_duplicates = solvation_data_df.duplicated(subset=[FRAME, SOLUTE_ATOM_IX, SOLVENT_IX]) + solvation_data_df = solvation_data_df.sort_values( + [FRAME, SOLUTE_ATOM_IX, DISTANCE] + ) + solvation_data_duplicates = solvation_data_df.duplicated( + subset=[FRAME, SOLUTE_ATOM_IX, SOLVENT_IX] + ) solvation_data = solvation_data_df[~solvation_data_duplicates] - self.solvation_data = solvation_data.set_index([FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX]) + self.solvation_data = solvation_data.set_index( + [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX] + ) duplicates = solvation_data_df[solvation_data_duplicates] - self.solvation_data_duplicates = duplicates.set_index([FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX]) + self.solvation_data_duplicates = duplicates.set_index( + [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX] + ) # instantiate analysis classes self.has_run = True analysis_classes = { - 'speciation': Speciation, - 'pairing': Pairing, - 'coordination': Coordination, - 'residence': Residence, - 'networking': Networking, + "speciation": Speciation, + "pairing": Pairing, + "coordination": Coordination, + "residence": Residence, + "networking": Networking, } for analysis_class in self.analysis_classes: - if analysis_class == 'networking': - setattr(self, 'networking', Networking.from_solute(self, self.networking_solvents)) + if analysis_class == "networking": + setattr( + self, + "networking", + Networking.from_solute(self, self.networking_solvents), + ) else: - setattr(self, analysis_class, analysis_classes[analysis_class].from_solute(self)) + setattr( + self, + analysis_class, + analysis_classes[analysis_class].from_solute(self), + ) @staticmethod - def _plot_solvation_radius(bins, data, radius): + def _plot_solvation_radius( + bins: np.ndarray, data: np.ndarray, radius: float + ) -> tuple[plt.Figure, plt.Axes]: """ Plot a solvation radius on an RDF. 
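A condensed sketch of the intra-molecular filtering step in _single_frame above; `solute_atoms`, `solvent`, `radius`, and `u` stand in for the instance attributes used there.

    from MDAnalysis.lib.distances import capped_distance

    # candidate (solute_atom, solvent_atom) pairs within `radius`
    pairs, dist = capped_distance(
        solute_atoms.positions, solvent.positions, radius, box=u.dimensions
    )
    # drop pairs whose two atoms belong to the same residue
    same_residue = (
        solute_atoms.resindices[pairs[:, 0]] == solvent.resindices[pairs[:, 1]]
    )
    pairs, dist = pairs[~same_residue], dist[~same_residue]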
@@ -651,7 +759,9 @@ def _plot_solvation_radius(bins, data, radius): ax.legend() return fig, ax - def plot_solvation_radius(self, solute_name, solvent_name): + def plot_solvation_radius( + self, solute_name: str, solvent_name: str + ) -> tuple[plt.Figure, plt.Axes]: """ Plot the RDF of a solvent molecule. @@ -678,7 +788,11 @@ def plot_solvation_radius(self, solute_name, solvent_name): ax.set_title(f"{self.solute_name} solvation distance for {solvent_name}") return fig, ax - def draw_molecule(self, residue, filename=None): + def draw_molecule( + self, + residue: Union[str, mda.core.groups.Residue], + filename: Optional[str] = None, + ) -> "rdkit.Chem.rdchem.Mol": """ Returns @@ -692,17 +806,22 @@ def draw_molecule(self, residue, filename=None): RDKit.Chem.rdchem.Mol """ - from rdkit.Chem import Draw - from rdkit.Chem import rdCoordGen - from rdkit.Chem.Draw.MolDrawing import DrawingOptions + if rdkit is None: + raise ImportError( + "The RDKit package is required to use this function. " + "Please install RDKit with `conda install -c conda-forge rdkit`." + ) + DrawingOptions.atomLabelFontSize = 100 if isinstance(residue, str): if residue in [self.solute_name, "solute"]: mol = self.solute_atoms.residues[0].atoms.convert_to("RDKIT") mol_mda_ix = self.solute_atoms.residues[0].atoms.ix - solute_atoms_ix0 = {solute.solute_atoms.atoms.ix[0]: solute_name - for solute_name, solute in self.atom_solutes.items()} + solute_atoms_ix0 = { + solute.solute_atoms.atoms.ix[0]: solute_name + for solute_name, solute in self.atom_solutes.items() + } for i, atom in enumerate(mol.GetAtoms()): atom_name = solute_atoms_ix0.get(mol_mda_ix[i]) label = f"{i}, " + atom_name if atom_name else str(i) @@ -712,8 +831,10 @@ def draw_molecule(self, residue, filename=None): for i, atom in enumerate(mol.GetAtoms()): atom.SetProp("atomNote", str(i)) else: - raise ValueError("If the residue is a string, it must be the name of a solute, " - "the name of a solvent, or 'solute'.") + raise ValueError( + "If the residue is a string, it must be the name of a solute, " + "the name of a solvent, or 'solute'." + ) else: assert isinstance(residue, mda.core.groups.Residue) mol = residue.atoms.convert_to("RDKIT") @@ -724,7 +845,14 @@ def draw_molecule(self, residue, filename=None): rdCoordGen.AddCoords(mol) return mol - def get_shell(self, solute_index, frame, as_df=False, remove_mols=None, closest_n_only=None): + def get_shell( + self, + solute_index: int, + frame: int, + as_df: bool = False, + remove_mols: Optional[dict[str, int]] = None, + closest_n_only: Optional[int] = None, + ) -> Union[mda.AtomGroup, pd.DataFrame]: """ Select the solvation shell of the solute. @@ -761,8 +889,9 @@ def get_shell(self, solute_index, frame, as_df=False, remove_mols=None, closest_ """ assert self.has_run, "Solute.run() must be called first." - assert frame in self.frames, ("The requested frame must be one " - "of an analyzed frames in self.frames.") + assert frame in self.frames, ( + "The requested frame must be one of the analyzed frames in self.frames."
+ ) remove_mols = {} if remove_mols is None else remove_mols # select shell of interest shell = self.solvation_data.xs((frame, solute_index), level=(FRAME, SOLUTE_IX)) @@ -774,7 +903,7 @@ def get_shell(self, solute_index, frame, as_df=False, remove_mols=None, closest_ mol_count = len(res_ix) n_remove = min(mol_count, n_remove) # then truncate resnames to remove mols - remove_ix = res_ix[(mol_count - n_remove):] + remove_ix = res_ix[(mol_count - n_remove) :] # then apply to original shell remove = shell[SOLVENT_IX].isin(remove_ix) shell = shell[np.invert(remove)] @@ -782,24 +911,28 @@ def get_shell(self, solute_index, frame, as_df=False, remove_mols=None, closest_ if closest_n_only: assert closest_n_only > 0, "closest_n_only must be at least 1" closest_n_only = min(len(shell), closest_n_only) - shell = shell[0: closest_n_only] + shell = shell[0:closest_n_only] if as_df: return shell else: return self._df_to_atom_group(shell, solute_index=solute_index) def get_closest_n_mol( - self, - solute_atom_ix, - n_mol, - guess_radius=3, - return_ordered_resix=False, - return_radii=False, - ): + self, + solute_atom_ix: int, + n_mol: int, + guess_radius: Union[float, int] = 3, + return_ordered_resix: bool = False, + return_radii: bool = False, + ) -> Union[ + mda.AtomGroup, + tuple[mda.AtomGroup, np.ndarray], + tuple[mda.AtomGroup, np.ndarray, np.ndarray], + ]: """ Select the n closest mols to the solute. - The solute is specified by it's index within solvation_data. + The solute is specified by its index within solvation_data. n is specified with the n_mol argument. Optionally returns an array of their resids and an array of the distance of the closest atom in each molecule. @@ -836,7 +969,9 @@ def get_closest_n_mol( return_radii, ) - def radial_shell(self, solute_atom_ix, radius): + def radial_shell( + self, solute_atom_ix: int, radius: Union[float, int] + ) -> mda.AtomGroup: """ Select all residues with atoms within r of the solute. @@ -857,7 +992,9 @@ def radial_shell(self, solute_atom_ix, radius): """ return get_radial_shell(self.solute_atoms[solute_atom_ix], radius) - def _df_to_atom_group(self, df, solute_index=None): + def _df_to_atom_group( + self, df: pd.DataFrame, solute_index: Optional[int] = None + ) -> mda.AtomGroup: """ Selects an MDAnalysis.AtomGroup from a pandas.DataFrame with solvent. diff --git a/solvation_analysis/speciation.py b/solvation_analysis/speciation.py index cab03d1..6ce3dbe 100644 --- a/solvation_analysis/speciation.py +++ b/solvation_analysis/speciation.py @@ -21,7 +21,14 @@ import pandas as pd -from solvation_analysis._column_names import * +import solvation_analysis +from solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + COUNT, +) class Speciation: @@ -51,7 +58,9 @@ class Speciation: The number of solutes in solvation_data. """ - def __init__(self, solvation_data, n_frames, n_solutes): + def __init__( + self, solvation_data: pd.DataFrame, n_frames: int, n_solutes: int + ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames self.n_solutes = n_solutes @@ -59,7 +68,7 @@ def __init__(self, solvation_data, n_frames, n_solutes): self._solvent_co_occurrence = self._solvent_co_occurrence() @staticmethod - def from_solute(solute): + def from_solute(solute: "solvation_analysis.Solute") -> "Speciation": """ Generate a Speciation object from a solute. 
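As with Residence above, a hedged usage sketch for Speciation; `solute` is a placeholder for a run Solute, and the shell composition queried at the end is invented.

    speciation = solute.speciation  # attached by the default analysis classes
    print(speciation.speciation_fraction.head())  # most common shells first
    print(speciation.calculate_shell_fraction({"BN": 4}))  # fraction of 4-BN shells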
@@ -78,8 +87,10 @@ def from_solute(solute): solute.n_solutes, ) - def _compute_speciation(self): - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + def _compute_speciation(self) -> tuple[pd.DataFrame, pd.DataFrame]: + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] counts_re = counts.reset_index([SOLVENT]) speciation_data = counts_re.pivot(columns=[SOLVENT]).fillna(0).astype(int) res_names = speciation_data.columns.levels[1] @@ -87,15 +98,19 @@ def _compute_speciation(self): sum_series = speciation_data.groupby(speciation_data.columns.to_list()).size() sum_sorted = sum_series.sort_values(ascending=False) speciation_fraction = sum_sorted.reset_index().rename(columns={0: COUNT}) - speciation_fraction[COUNT] = speciation_fraction[COUNT] / (self.n_frames * self.n_solutes) + speciation_fraction[COUNT] = speciation_fraction[COUNT] / ( + self.n_frames * self.n_solutes + ) return speciation_data, speciation_fraction @classmethod - def _mean_speciation(cls, speciation_frames, solute_number, frame_number): + def _mean_speciation( + cls, speciation_frames: pd.DataFrame, solute_number: int, frame_number: int + ) -> pd.Series: means = speciation_frames.sum(axis=1) / (solute_number * frame_number) return means - def calculate_shell_fraction(self, shell_dict): + def calculate_shell_fraction(self, shell_dict: dict[str, int]) -> float: """ Calculate the fraction of shells matching shell_dict. @@ -134,7 +149,7 @@ def calculate_shell_fraction(self, shell_dict): query_counts = self.speciation_fraction.query(query) return query_counts[COUNT].sum() - def get_shells(self, shell_dict): + def get_shells(self, shell_dict: dict[str, int]) -> pd.DataFrame: """ Find all solvation shells that match shell_dict. @@ -161,17 +176,19 @@ def get_shells(self, shell_dict): query_counts = self.speciation_data.query(query) return query_counts - def _solvent_co_occurrence(self): + def _solvent_co_occurrence(self) -> pd.DataFrame: # calculate the co-occurrence of solvent molecules. expected_solvents_list = [] actual_solvents_list = [] for solvent in self.speciation_data.columns.values: # calculate number of available coordinating solvent slots - shells_w_solvent = self.speciation_data.query(f'`{solvent}` > 0') + shells_w_solvent = self.speciation_data.query(f"`{solvent}` > 0") n_solvents = shells_w_solvent.sum() # calculate expected number of coordinating solvents n_coordination_slots = n_solvents.sum() - len(shells_w_solvent) - coordination_fraction = self.speciation_data.sum() / self.speciation_data.sum().sum() + coordination_fraction = ( + self.speciation_data.sum() / self.speciation_data.sum().sum() + ) expected_solvents = coordination_fraction * n_coordination_slots # calculate actual number of coordinating solvents actual_solvents = n_solvents.copy() @@ -192,7 +209,7 @@ def _solvent_co_occurrence(self): return correlation @property - def speciation_data(self): + def speciation_data(self) -> pd.DataFrame: """ A dataframe containing the speciation of every solute at every trajectory frame. Indexed by frame and solute numbers. @@ -202,7 +219,7 @@ def speciation_data(self): return self._speciation_df @property - def speciation_fraction(self): + def speciation_fraction(self) -> pd.DataFrame: """ The fraction of shells of each type. Columns are the solvent molecules and values are the number of solvent molecules in the shell.
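A self-contained pandas sketch of the _compute_speciation pivot above, with invented data and plain-string stand-ins for the _column_names constants.

    import pandas as pd

    # three invented shells: (frame 0, solute 0) has 2 BN; (frame 0, solute 1) has BN + FEC
    solvation_data = pd.DataFrame(
        {
            "frame": [0, 0, 0, 0, 1],
            "solute_ix": [0, 0, 1, 1, 0],
            "solvent": ["BN", "BN", "BN", "FEC", "BN"],
            "solvent_ix": [11, 12, 11, 20, 12],
        }
    ).set_index(["frame", "solute_ix"])

    counts = solvation_data.groupby(["frame", "solute_ix", "solvent"]).count()["solvent_ix"]
    shells = counts.reset_index(["solvent"]).pivot(columns=["solvent"]).fillna(0).astype(int)
    print(shells)  # one row per (frame, solute), one column per solvent name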
@@ -212,7 +229,7 @@ def speciation_fraction(self): return self._speciation_fraction @property - def solvent_co_occurrence(self): + def solvent_co_occurrence(self) -> pd.DataFrame: """ The actual co-occurrence of solvents divided by the expected co-occurrence. In other words, given one molecule of solvent i in the shell, what is the