diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index b9953d7..1ebed68 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -34,6 +34,8 @@ jobs: cache-dependency-path: pyproject.toml - uses: pre-commit/action@v3.0.0 + with: + extra_args: --files solvation_analysis/* test: name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }} @@ -91,7 +93,7 @@ jobs: file: ./coverage.xml flags: unittests name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} - + test_pip_install: name: pip (PEP517) install on ${{ matrix.os }}, Python ${{ matrix.python-version }} runs-on: ${{ matrix.os }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 47ccbe9..f9d887e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,12 +15,6 @@ repos: args: [--remove] - id: end-of-file-fixer - id: trailing-whitespace -#- repo: https://github.com/asottile/blacken-docs -# rev: 1.16.0 -# hooks: -# - id: blacken-docs -# additional_dependencies: [black] -# exclude: README.md - repo: https://github.com/pre-commit/pygrep-hooks rev: v1.10.0 hooks: diff --git a/solvation_analysis/__init__.py b/solvation_analysis/__init__.py index 5a2f248..315d3f0 100644 --- a/solvation_analysis/__init__.py +++ b/solvation_analysis/__init__.py @@ -2,14 +2,18 @@ SolvationAnalysis An MDAnalysis rmodule for solvation analysis. """ + +from . import _version from solvation_analysis.solute import Solute # Handle versioneer from ._version import get_versions + versions = get_versions() -__version__ = versions['version'] -__git_revision__ = versions['full-revisionid'] +__version__ = versions["version"] +__git_revision__ = versions["full-revisionid"] del get_versions, versions -from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] +__all__ = ["Solute"] diff --git a/solvation_analysis/_column_names.py b/solvation_analysis/_column_names.py index af72b44..b637d5d 100644 --- a/solvation_analysis/_column_names.py +++ b/solvation_analysis/_column_names.py @@ -8,7 +8,6 @@ 1. change the variable name in all files """ - # for solvation_data FRAME = "frame" SOLUTE_IX = "solute_ix" diff --git a/solvation_analysis/_utils.py b/solvation_analysis/_utils.py index 14317ab..ab3f733 100644 --- a/solvation_analysis/_utils.py +++ b/solvation_analysis/_utils.py @@ -7,7 +7,7 @@ import MDAnalysis as mda from MDAnalysis.analysis import distances -from solvation_analysis._column_names import * +from solvation_analysis._column_names import FRAME, SOLUTE_IX, SOLVENT_IX, DISTANCE def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomGroup]: @@ -16,9 +16,9 @@ def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomG # and that the residues are all the same length # then this should work all_res_len = np.array([res.atoms.n_atoms for res in solute_atom_group.residues]) - assert np.all(all_res_len[0] == all_res_len), ( - "All residues must be the same length." - ) + assert np.all( + all_res_len[0] == all_res_len + ), "All residues must be the same length." 
res_atom_local_ix = defaultdict(list) res_atom_ix = defaultdict(list) @@ -26,36 +26,38 @@ def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomG res_atom_local_ix[atom.resindex].append(atom.ix - atom.residue.atoms[0].ix) res_atom_ix[atom.resindex].append(atom.index) res_occupancy = np.array([len(ix) for ix in res_atom_local_ix.values()]) - assert np.all(res_occupancy[0] == res_occupancy), ( - "All residues must have the same number of solute_atoms atoms on them." - ) + assert np.all( + res_occupancy[0] == res_occupancy + ), "All residues must have the same number of solute_atoms atoms on them." res_atom_array = np.array(list(res_atom_local_ix.values())) - assert np.all(res_atom_array[0] == res_atom_array), ( - "All residues must have the same solute_atoms atoms on them." - ) + assert np.all( + res_atom_array[0] == res_atom_array + ), "All residues must have the same solute_atoms atoms on them." res_atom_ix_array = np.array(list(res_atom_ix.values())) solute_atom_group_dict = {} for i in range(0, res_atom_ix_array.shape[1]): - solute_atom_group_dict[i] = solute_atom_group.universe.atoms[res_atom_ix_array[:, i]] + solute_atom_group_dict[i] = solute_atom_group.universe.atoms[ + res_atom_ix_array[:, i] + ] return solute_atom_group_dict -def verify_solute_atoms_dict(solute_atoms_dict: dict[str, mda.AtomGroup]) -> mda.AtomGroup: +def verify_solute_atoms_dict( + solute_atoms_dict: dict[str, mda.AtomGroup], +) -> mda.AtomGroup: # first we verify the input format atom_group_lengths = [] for solute_name, solute_atom_group in solute_atoms_dict.items(): - assert isinstance(solute_name, str), ( - "The keys of solutes_dict must be strings." - ) + assert isinstance(solute_name, str), "The keys of solutes_dict must be strings." assert isinstance(solute_atom_group, mda.AtomGroup), ( f"The values of solutes_dict must be MDAnalysis.AtomGroups. But the value" f"for {solute_name} is a {type(solute_atom_group)}." ) - assert len(solute_atom_group) == len(solute_atom_group.residues), ( - "The solute_atom_group must have a single atom on each residue." - ) + assert len(solute_atom_group) == len( + solute_atom_group.residues + ), "The solute_atom_group must have a single atom on each residue." atom_group_lengths.append(len(solute_atom_group)) assert np.all(np.array(atom_group_lengths) == atom_group_lengths[0]), ( "AtomGroups in solutes_dict must have the same length because there should be" @@ -63,16 +65,25 @@ def verify_solute_atoms_dict(solute_atoms_dict: dict[str, mda.AtomGroup]) -> mda ) # verify that the solute_atom_groups have no overlap - solute_atom_group = reduce(lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()]) - assert solute_atom_group.n_atoms == sum([atoms.n_atoms for atoms in solute_atoms_dict.values()]), ( - "The solute_atom_groups must not overlap." + solute_atom_group = reduce( + lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()] ) + assert solute_atom_group.n_atoms == sum( + [atoms.n_atoms for atoms in solute_atoms_dict.values()] + ), "The solute_atom_groups must not overlap." 
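For reference, the input shape these checks enforce: solute_atoms_dict maps a name to an AtomGroup holding exactly one atom on every solute residue, with no overlap between groups. A minimal sketch, assuming a hypothetical universe; the file names and selection strings are illustrative only:

import MDAnalysis as mda
from solvation_analysis._utils import verify_solute_atoms_dict

# hypothetical topology/trajectory; any universe with a multi-atom
# solute residue would work the same way
u = mda.Universe("topology.pdb", "trajectory.dcd")

# one AtomGroup per named solute atom, each with exactly one atom
# on every solute residue -- this is what the asserts above enforce
solute_atoms_dict = {
    "ether_oxygen": u.select_atoms("resname PEO and name O1"),
    "terminal_carbon": u.select_atoms("resname PEO and name C1"),
}

# returns the union AtomGroup after checking key types, equal group
# lengths, one-atom-per-residue, and no overlap between groups
solute_atom_group = verify_solute_atoms_dict(solute_atoms_dict)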
verify_solute_atoms(solute_atom_group) return solute_atom_group -def get_atom_group(selection: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup]) -> mda.AtomGroup: +def get_atom_group( + selection: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], +) -> mda.AtomGroup: """ Cast an MDAnalysis.Atom, MDAnalysis.Residue, or MDAnalysis.ResidueGroup to AtomGroup. @@ -103,14 +114,22 @@ def get_atom_group(selection: Union[mda.core.groups.Residue, mda.core.groups.Res return selection - def get_closest_n_mol( - central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup], + central_species: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], n_mol: int, guess_radius: Union[float, int] = 3, return_ordered_resix: bool = False, return_radii: bool = False, -) -> Union[mda.AtomGroup, tuple[mda.AtomGroup, np.ndarray], tuple[mda.AtomGroup, np.ndarray, np.ndarray]]: +) -> Union[ + mda.AtomGroup, + tuple[mda.AtomGroup, np.ndarray], + tuple[mda.AtomGroup, np.ndarray, np.ndarray], +]: """ Returns the closest n molecules to the central species, an array of their resix, and an array of the distance of the closest atom in each molecule. @@ -152,7 +171,7 @@ def get_closest_n_mol( n_mol, guess_radius + 1, return_ordered_resix=return_ordered_resix, - return_radii=return_radii + return_radii=return_radii, ) radii = distances.distance_array(coords, partial_shell.positions, box=u.dimensions)[ 0 @@ -160,7 +179,7 @@ def get_closest_n_mol( ordering = np.argsort(radii) ordered_resix = shell_resix[ordering] closest_n_resix = np.sort(np.unique(ordered_resix, return_index=True)[1])[ - 0: n_mol + 1 + 0 : n_mol + 1 ] str_resix = " ".join(str(resix) for resix in ordered_resix[closest_n_resix]) full_shell = u.select_atoms(f"resindex {str_resix}") @@ -179,8 +198,13 @@ def get_closest_n_mol( def get_radial_shell( - central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup], - radius: Union[float, int] + central_species: Union[ + mda.core.groups.Residue, + mda.core.groups.ResidueGroup, + mda.core.groups.Atom, + mda.core.groups.AtomGroup, + ], + radius: Union[float, int], ) -> mda.AtomGroup: """ Returns all molecules with atoms within the radius of the central species. diff --git a/solvation_analysis/_version.py b/solvation_analysis/_version.py index 16cd611..6bd04f5 100644 --- a/solvation_analysis/_version.py +++ b/solvation_analysis/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). 
Distribution tarballs (built by setup.py sdist) and build @@ -58,18 +57,19 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate # pylint:disable=too-many-arguments,consider-using-with # noqa -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -77,10 +77,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -115,15 +118,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -182,7 +191,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -191,7 +200,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -199,24 +208,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. 
"2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -231,8 +247,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -240,10 +255,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -258,8 +282,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -299,17 +322,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -318,10 +340,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -370,8 +394,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -400,8 +423,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -544,11 +566,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -572,9 +596,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -588,8 +616,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -598,13 +625,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
- for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -618,6 +648,10 @@ def get_versions(): except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/solvation_analysis/coordination.py b/solvation_analysis/coordination.py index 2fa4a5b..0f49008 100644 --- a/solvation_analysis/coordination.py +++ b/solvation_analysis/coordination.py @@ -18,8 +18,17 @@ import pandas as pd import MDAnalysis as mda +import solvation_analysis -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + SOLVENT_ATOM_IX, + ATOM_TYPE, + FRACTION, +) class Coordination: @@ -69,7 +78,7 @@ def __init__( n_frames: int, n_solutes: int, solvent_counts: dict[str, int], - atom_group: mda.core.groups.AtomGroup + atom_group: mda.core.groups.AtomGroup, ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames @@ -81,7 +90,7 @@ def __init__( self._coordination_vs_random = self._calculate_coordination_vs_random() @staticmethod - def from_solute(solute: 'Solute') -> 'Coordination': + def from_solute(solute: "solvation_analysis.Solute") -> "Coordination": """ Generate a Coordination object from a solute. 
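To make the bookkeeping in _mean_cn below concrete, here is a self-contained sketch of the same groupby arithmetic on a toy stand-in for solvation_data (the column names mirror _column_names; the values are invented):

import pandas as pd

# toy stand-in for solvation_data: one row per solute-solvent contact
data = pd.DataFrame({
    "frame": [0, 0, 0, 1, 1],
    "solute_ix": [0, 0, 1, 0, 1],
    "solvent": ["water", "water", "acetone", "water", "water"],
    "solvent_ix": [10, 11, 20, 10, 12],
})
n_frames, n_solutes = 2, 2

# count contacts per (frame, solute, solvent), then average the
# per-frame totals over all solutes and frames, as _mean_cn does
counts = data.groupby(["frame", "solute_ix", "solvent"]).count()["solvent_ix"]
cn_series = counts.groupby(["solvent", "frame"]).sum() / (n_solutes * n_frames)
cn_dict = cn_series.groupby(["solvent"]).sum().to_dict()
print(cn_dict)  # {'acetone': 0.25, 'water': 1.0}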
@@ -103,9 +112,11 @@ def from_solute(solute: 'Solute') -> 'Coordination': ) def _mean_cn(self) -> tuple[dict[str, float], pd.DataFrame]: - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] cn_series = counts.groupby([SOLVENT, FRAME]).sum() / ( - self.n_solutes * self.n_frames + self.n_solutes * self.n_frames ) cn_by_frame = cn_series.unstack() cn_dict = cn_series.groupby([SOLVENT]).sum().to_dict() @@ -118,22 +129,23 @@ def _calculate_coordinating_atoms(self, tol: float = 0.005) -> pd.DataFrame: """ # lookup atom types atom_types = self.solvation_data.reset_index([SOLVENT_ATOM_IX]) - atom_types[ATOM_TYPE] = self.atom_group[atom_types[SOLVENT_ATOM_IX].values].types + atom_types[ATOM_TYPE] = self.atom_group[ + atom_types[SOLVENT_ATOM_IX].values + ].types # count atom types atoms_by_type = atom_types[[ATOM_TYPE, SOLVENT, SOLVENT_ATOM_IX]] type_counts = atoms_by_type.groupby([SOLVENT, ATOM_TYPE]).count() solvent_counts = type_counts.groupby([SOLVENT]).sum()[SOLVENT_ATOM_IX] # calculate fraction of each solvent_counts_list = [ - solvent_counts[solvent] for solvent in - type_counts.index.get_level_values(SOLVENT) + solvent_counts[solvent] + for solvent in type_counts.index.get_level_values(SOLVENT) ] type_fractions = type_counts[SOLVENT_ATOM_IX] / solvent_counts_list type_fractions.name = FRACTION # change index type type_fractions = ( - type_fractions - .reset_index(ATOM_TYPE) + type_fractions.reset_index(ATOM_TYPE) .astype({ATOM_TYPE: str}) .set_index(ATOM_TYPE, append=True) ) diff --git a/solvation_analysis/networking.py b/solvation_analysis/networking.py index e2ee2ee..19da5c2 100644 --- a/solvation_analysis/networking.py +++ b/solvation_analysis/networking.py @@ -16,6 +16,7 @@ as an attribute of the Solute class. This makes instantiating it and calculating the solvation data a non-issue. 
""" + from typing import Union import pandas as pd @@ -23,8 +24,18 @@ from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components +import solvation_analysis from solvation_analysis._utils import calculate_adjacency_dataframe -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + SOLVENT, + SOLVENT_IX, + SOLUTE_IX, + NETWORK, + PAIRED, + NETWORKED, + ISOLATED, +) class Networking: @@ -68,12 +79,12 @@ def __init__( solvents: Union[str, list[str]], solvation_data: pd.DataFrame, solute_res_ix: np.ndarray, - res_name_map: pd.Series + res_name_map: pd.Series, ) -> None: self.solvents = solvents self.solvation_data = solvation_data - solvent_present = np.isin(self.solvents, self.solvation_data[SOLVENT].unique()) # TODO: we need all analysis classes to run when there is no solvation_data + # solvent_present = np.isin(self.solvents, self.solvation_data[SOLVENT].unique()) # if not solvent_present.all(): # raise Exception(f"Solvent(s) {np.array(self.solvents)[~solvent_present]} not found in solvation data.") self.solute_res_ix = solute_res_ix @@ -81,11 +92,15 @@ def __init__( self.n_solute = len(solute_res_ix) self._network_df = self._generate_networks() self._network_sizes = self._calculate_network_sizes() - self._solute_status, self._solute_status_by_frame = self._calculate_solute_status() + self._solute_status, self._solute_status_by_frame = ( + self._calculate_solute_status() + ) self._solute_status = self._solute_status.to_dict() @staticmethod - def from_solute(solute: 'Solute', solvents: Union[str, list[str]]) -> 'Networking': + def from_solute( + solute: "solvation_analysis.Solute", solvents: Union[str, list[str]] + ) -> "Networking": """ Generate a Networking object from a solute and solvent names. @@ -129,7 +144,9 @@ def _generate_networks(self) -> pd.DataFrame: 3. 
tabulate the solvent involved in each network and store in a DataFrame """ solvents = [self.solvents] if isinstance(self.solvents, str) else self.solvents - solvation_subset = self.solvation_data[np.isin(self.solvation_data[SOLVENT], solvents)] + solvation_subset = self.solvation_data[ + np.isin(self.solvation_data[SOLVENT], solvents) + ] # create adjacency matrix from solvation_subset graph = calculate_adjacency_dataframe(solvation_subset) network_arrays = [] @@ -143,16 +160,16 @@ def _generate_networks(self) -> pd.DataFrame: ix_to_res_ix = np.concatenate([solvent_map, solute_map]) adjacency_df = Networking._unwrap_adjacency_dataframe(df) _, network = connected_components( - csgraph=adjacency_df, - directed=False, - return_labels=True + csgraph=adjacency_df, directed=False, return_labels=True ) - network_array = np.vstack([ - np.full(len(network), frame), # frame - network, # network - self.res_name_map[ix_to_res_ix], # res_names - ix_to_res_ix, # res index - ]).T + network_array = np.vstack( + [ + np.full(len(network), frame), # frame + network, # network + self.res_name_map[ix_to_res_ix], # res_names + ix_to_res_ix, # res index + ] + ).T network_arrays.append(network_array) # create and return network dataframe if len(network_arrays) == 0: @@ -161,8 +178,8 @@ def _generate_networks(self) -> pd.DataFrame: all_clusters = np.concatenate(network_arrays) cluster_df = ( pd.DataFrame(all_clusters, columns=[FRAME, NETWORK, SOLVENT, SOLVENT_IX]) - .set_index([FRAME, NETWORK]) - .sort_values([FRAME, NETWORK]) + .set_index([FRAME, NETWORK]) + .sort_values([FRAME, NETWORK]) ) return cluster_df @@ -170,8 +187,12 @@ def _calculate_network_sizes(self) -> pd.DataFrame: # This utility calculates the network sizes and returns a convenient dataframe. cluster_df = self.network_df cluster_sizes = cluster_df.groupby([FRAME, NETWORK]).count() - size_counts = cluster_sizes.groupby([FRAME, SOLVENT]).count().unstack(fill_value=0) - size_counts.columns = size_counts.columns.droplevel(None) # the column value is None + size_counts = ( + cluster_sizes.groupby([FRAME, SOLVENT]).count().unstack(fill_value=0) + ) + size_counts.columns = size_counts.columns.droplevel( + None + ) # the column value is None return size_counts def _calculate_solute_status(self) -> tuple[pd.Series, pd.DataFrame]: @@ -184,7 +205,9 @@ def _calculate_solute_status(self) -> tuple[pd.Series, pd.DataFrame]: status = self.network_sizes.iloc[:, 0:0] status[PAIRED] = self.network_sizes.iloc[:, 0:1].sum(axis=1).astype(int) status[NETWORKED] = self.network_sizes.iloc[:, 1:].sum(axis=1).astype(int) - status[ISOLATED] = self.n_solute - status.loc[:, [PAIRED, NETWORKED]].sum(axis=1) + status[ISOLATED] = self.n_solute - status.loc[:, [PAIRED, NETWORKED]].sum( + axis=1 + ) status = status.loc[:, [ISOLATED, PAIRED, NETWORKED]] solute_status_by_frame = status / self.n_solute solute_status = solute_status_by_frame.mean() @@ -221,7 +244,9 @@ def get_network_res_ix(self, network_index: int, frame: int) -> np.ndarray: """ - res_ix = self.network_df.loc[pd.IndexSlice[frame, network_index], SOLVENT_IX].values + res_ix = self.network_df.loc[ + pd.IndexSlice[frame, network_index], SOLVENT_IX + ].values return res_ix.astype(int) @property diff --git a/solvation_analysis/pairing.py b/solvation_analysis/pairing.py index 9aa02c4..42c9b9b 100644 --- a/solvation_analysis/pairing.py +++ b/solvation_analysis/pairing.py @@ -19,7 +19,14 @@ import pandas as pd import numpy as np -from solvation_analysis._column_names import * +import solvation_analysis +from 
solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + DISTANCE, +) class Pairing: @@ -61,7 +68,7 @@ def __init__( solvation_data: pd.DataFrame, n_frames: int, n_solutes: int, - n_solvents: dict[str, int] + n_solvents: dict[str, int], ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames @@ -69,10 +76,14 @@ def __init__( self.solvent_counts = n_solvents self._solvent_pairing, self._pairing_by_frame = self._fraction_coordinated() self._fraction_free_solvents = self._fraction_free_solvent() - self._diluent_composition, self._diluent_composition_by_frame, self._diluent_counts = self._diluent_composition() + ( + self._diluent_composition, + self._diluent_composition_by_frame, + self._diluent_counts, + ) = self._diluent_composition() @staticmethod - def from_solute(solute: 'Solute') -> 'Pairing': + def from_solute(solute: "solvation_analysis.Solute") -> "Pairing": """ Generate a Pairing object from a solute. @@ -89,12 +100,14 @@ def from_solute(solute: 'Solute') -> 'Pairing': solute.solvation_data, solute.n_frames, solute.n_solutes, - solute.solvent_counts + solute.solvent_counts, ) def _fraction_coordinated(self) -> tuple[dict[str, float], pd.DataFrame]: # calculate the fraction of solute coordinated with each solvent - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] pairing_series = counts.astype(bool).groupby([SOLVENT, FRAME]).sum() / ( self.n_solutes ) # mean coordinated overall @@ -105,14 +118,22 @@ def _fraction_coordinated(self) -> tuple[dict[str, float], pd.DataFrame]: def _fraction_free_solvent(self) -> dict[str, float]: # calculate the fraction of each solvent NOT coordinated with the solute - counts = self.solvation_data.groupby([FRAME, SOLVENT_IX, SOLVENT]).count()[DISTANCE] + counts = self.solvation_data.groupby([FRAME, SOLVENT_IX, SOLVENT]).count()[ + DISTANCE + ] totals = counts.groupby([SOLVENT]).count() / self.n_frames - n_solvents = np.array([self.solvent_counts[name] for name in totals.index.values]) + n_solvents = np.array( + [self.solvent_counts[name] for name in totals.index.values] + ) free_solvents = np.ones(len(totals)) - totals / n_solvents return free_solvents.to_dict() - def _diluent_composition(self) -> tuple[dict[str, float], pd.DataFrame, pd.DataFrame]: - coordinated_solvents = self.solvation_data.groupby([FRAME, SOLVENT]).nunique()[SOLVENT_IX] + def _diluent_composition( + self, + ) -> tuple[dict[str, float], pd.DataFrame, pd.DataFrame]: + coordinated_solvents = self.solvation_data.groupby([FRAME, SOLVENT]).nunique()[ + SOLVENT_IX + ] solvent_counts = pd.Series(self.solvent_counts) total_solvents = solvent_counts.reindex(coordinated_solvents.index, level=1) diluent_solvents = total_solvents - coordinated_solvents diff --git a/solvation_analysis/plotting.py b/solvation_analysis/plotting.py index 30bbf6c..22aa587 100644 --- a/solvation_analysis/plotting.py +++ b/solvation_analysis/plotting.py @@ -13,7 +13,6 @@ from typing import Union, Optional, Any, Callable from copy import deepcopy -import plotly import plotly.graph_objects as go import plotly.express as px import pandas as pd @@ -39,7 +38,7 @@ def plot_network_size_histogram(networking: Union[Networking, Solute]) -> go.Fig """ if isinstance(networking, Solute): if not hasattr(networking, "networking"): - raise ValueError(f"Solute networking analysis class must be instantiated.") + raise ValueError("Solute networking analysis 
class must be instantiated.") networking = networking.networking network_sizes = networking.network_sizes sums = network_sizes.sum(axis=0) @@ -71,7 +70,7 @@ def plot_shell_composition_by_size(speciation: Union[Speciation, Solute]) -> go. """ if isinstance(speciation, Solute): if not hasattr(speciation, "speciation"): - raise ValueError(f"Solute speciation analysis class must be instantiated.") + raise ValueError("Solute speciation analysis class must be instantiated.") speciation = speciation.speciation speciation_data = speciation.speciation_data.copy() speciation_data["total"] = speciation_data.sum(axis=1) @@ -92,11 +91,13 @@ def plot_shell_composition_by_size(speciation: Union[Speciation, Solute]) -> go. return fig -def plot_co_occurrence(speciation: Union[Speciation, Solute], colorscale: Optional[Any] = None) -> go.Figure: +def plot_co_occurrence( + speciation: Union[Speciation, Solute], colorscale: Optional[Any] = None +) -> go.Figure: """ Plot the co-occurrence matrix of the solute using Plotly. - Co-occurrence represents the extent to which solvents occur with eachother + Co-occurrence represents the extent to which solvents occur with each other relative to random. Values higher than 1 mean that solvents occur together more often than random and values lower than 1 mean solvents occur together less often than random. "Random" is calculated based on the total number of @@ -113,7 +114,7 @@ def plot_co_occurrence(speciation: Union[Speciation, Solute], colorscale: Option """ if isinstance(speciation, Solute): if not hasattr(speciation, "speciation"): - raise ValueError(f"Solute speciation analysis class must be instantiated.") + raise ValueError("Solute speciation analysis class must be instantiated.") speciation = speciation.speciation solvent_names = speciation.speciation_data.columns.values @@ -126,9 +127,9 @@ def plot_co_occurrence(speciation: Union[Speciation, Solute], colorscale: Option range_val = max_val - min_val colorscale = [ - [0, 'rgb(67,147,195)'], + [0, "rgb(67,147,195)"], [(1 - min_val) / range_val, "white"], - [1, 'rgb(214,96,77)'] + [1, "rgb(214,96,77)"], ] # Create a heatmap trace with text annotations @@ -139,24 +140,24 @@ def plot_co_occurrence(speciation: Union[Speciation, Solute], colorscale: Option text=speciation.solvent_co_occurrence.round(2).to_numpy(dtype=str), # Keep the text annotations in the original order hoverinfo="none", - colorscale=colorscale + colorscale=colorscale, ) # Update layout to display tick labels and text annotations layout = go.Layout( title="Solvent Co-Occurrence Matrix", xaxis=dict( - tickmode='array', + tickmode="array", tickvals=list(range(len(solvent_names))), ticktext=solvent_names, tickangle=-30, - side='top' + side="top", ), yaxis=dict( - tickmode='array', + tickmode="array", tickvals=list(range(len(solvent_names))), ticktext=solvent_names, - autorange='reversed' + autorange="reversed", ), margin=dict(l=60, r=60, b=60, t=100, pad=4), annotations=[ @@ -165,7 +166,7 @@ def plot_co_occurrence(speciation: Union[Speciation, Solute], colorscale: Option y=j, text=str(round(speciation.solvent_co_occurrence.iloc[j, i], 2)), font=dict(size=14, color="black"), - showarrow=False + showarrow=False, ) for i in range(len(solvent_names)) for j in range(len(solvent_names)) @@ -228,7 +229,6 @@ def compare_solvent_dicts( set(solution_dict.keys()) for solution_dict in property_dict.values() ] valid_solvents = set.intersection(*all_solvents) - invalid_solvents = set.union(*all_solvents) - valid_solvents if not 
set(solvents_to_plot).issubset(valid_solvents): raise Exception( f"solvents_to_plot must only include solvents that are " @@ -331,7 +331,7 @@ def compare_func( return fig arguments_docstring = """ - + property_dict : dict of {str: dict} a dictionary with the solution name as keys and a dict of {str: float} as values, where each key is the name of the solvent of each solution and each value is the property of interest diff --git a/solvation_analysis/rdf_parser.py b/solvation_analysis/rdf_parser.py index ebb8f41..e530899 100644 --- a/solvation_analysis/rdf_parser.py +++ b/solvation_analysis/rdf_parser.py @@ -19,10 +19,10 @@ import warnings from scipy.signal import find_peaks, gaussian -from solvation_analysis._column_names import * - -def interpolate_rdf(bins: np.ndarray, rdf: np.ndarray, floor: float = 0.05, cutoff: float = 5) -> tuple[UnivariateSpline, tuple[float, float]]: +def interpolate_rdf( + bins: np.ndarray, rdf: np.ndarray, floor: float = 0.05, cutoff: float = 5 +) -> tuple[UnivariateSpline, tuple[float, float]]: """ Fits a sciply.interpolate.UnivariateSpline to the starting region of the RDF. The floor and cutoff control the region of the RDF that the @@ -77,7 +77,9 @@ def identify_minima(f: UnivariateSpline) -> tuple[np.ndarray, np.ndarray]: return cr_pts, cr_vals -def plot_interpolation_fit(bins: np.ndarray, rdf: np.ndarray, **kwargs: Any) -> tuple[plt.Figure, plt.Axes]: +def plot_interpolation_fit( + bins: np.ndarray, rdf: np.ndarray, **kwargs: Any +) -> tuple[plt.Figure, plt.Axes]: """ Calls interpolate_rdf and identify_minima to identify the extrema of an RDF. Plots the original rdf, the interpolated spline, and the extrema of the @@ -111,7 +113,9 @@ def plot_interpolation_fit(bins: np.ndarray, rdf: np.ndarray, **kwargs: Any) -> return fig, ax -def good_cutoff(cutoff_region: tuple[float, float], cr_pts: np.ndarray, cr_vals: np.ndarray) -> bool: +def good_cutoff( + cutoff_region: tuple[float, float], cr_pts: np.ndarray, cr_vals: np.ndarray +) -> bool: """ Uses several heuristics to determine if the a solvation cutoff is valid solvation cutoff. This fails if there is no solvation shell. @@ -141,13 +145,20 @@ def good_cutoff(cutoff_region: tuple[float, float], cr_pts: np.ndarray, cr_vals: return True -def good_cutoff_scipy(cutoff_region: tuple[float, float], min_trough_depth: float, peaks: np.ndarray, troughs: np.ndarray, rdf: np.ndarray, bins: np.ndarray) -> bool: +def good_cutoff_scipy( + cutoff_region: tuple[float, float], + min_trough_depth: float, + peaks: np.ndarray, + troughs: np.ndarray, + rdf: np.ndarray, + bins: np.ndarray, +) -> bool: """ Uses several heuristics to determine if the solvation cutoff is valid solvation cutoff. This fails if there is no solvation shell. Heuristics: - - trough follows peak + - troughs follows peaks - in `Solute.cutoff_region` (specified by kwarg) - normalized peak height > 0.05 @@ -156,7 +167,7 @@ def good_cutoff_scipy(cutoff_region: tuple[float, float], min_trough_depth: floa cutoff_region : tuple boundaries in which to search for a solvation shell cutoff, i.e. 
 (1.5, 4)
     min_trough_depth : float
-        the minimum depth of a trough to be considered a valid solvation cutoff
+        the minimum trough depth to be considered a valid solvation cutoff
     peaks : np.array
         the indices of the peaks in the bins array
     troughs : np.array
@@ -173,16 +184,22 @@ def good_cutoff_scipy(cutoff_region: tuple[float, float], min_trough_depth: floa
     # normalize rdf
     norm_rdf = rdf / np.max(rdf)
     if (
-        len(peaks) == 0 or len(troughs) == 0  # insufficient critical points
+        len(peaks) == 0
+        or len(troughs) == 0  # insufficient critical points
         or troughs[0] < peaks[0]  # not a min and max
-        or not (cutoff_region[0] < bins[troughs[0]] < cutoff_region[1])  # min not in cutoff
-        or abs(norm_rdf[peaks[0]] - norm_rdf[troughs[0]]) < min_trough_depth  # peak too small
+        or not (
+            cutoff_region[0] < bins[troughs[0]] < cutoff_region[1]
+        )  # min not in cutoff
+        or abs(norm_rdf[peaks[0]] - norm_rdf[troughs[0]])
+        < min_trough_depth  # peak too small
     ):
         return False
     return True


-def scipy_find_peaks_troughs(bins: np.ndarray, rdf: np.ndarray, return_rdf: bool = False, **kwargs: Any) -> Union[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray, np.ndarray]]:
+def scipy_find_peaks_troughs(
+    bins: np.ndarray, rdf: np.ndarray, return_rdf: bool = False, **kwargs: Any
+) -> Union[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray, np.ndarray]]:
     """
     Finds the indices of the peaks and troughs of an RDF.

@@ -231,7 +248,7 @@ def identify_cutoff_scipy(
     failure_behavior: str = "warn",
     min_trough_depth: float = 0.02,
     default: Optional[float] = None,
-    **kwargs: Any
+    **kwargs: Any,
 ) -> Optional[float]:
     """
     Identifies the solvation cutoff of an RDF.
@@ -254,7 +271,7 @@ def identify_cutoff_scipy(
     default : float, optional
         the value to return if no solvation shell is found
     min_trough_depth : float
-        the minimum depth of a trough to be considered a valid solvation cutoff
+        the minimum trough depth to be considered a valid solvation cutoff
     kwargs : passed to the scipy.find_peaks function

     Returns
@@ -263,16 +280,20 @@ def identify_cutoff_scipy(
         the solvation cutoff of the RDF
     """
     peaks, troughs = scipy_find_peaks_troughs(bins, rdf, **kwargs)
-    if not good_cutoff_scipy(cutoff_region, min_trough_depth, peaks, troughs, rdf, bins):
+    if not good_cutoff_scipy(
+        cutoff_region, min_trough_depth, peaks, troughs, rdf, bins
+    ):
         if failure_behavior == "silent":
             return default
         if failure_behavior == "warn":
             warnings.warn("No solvation shell detected.")
             return default
         if failure_behavior == "exception":
-            raise RuntimeError("Solute could not identify a solvation radius for at least one solvent. "
-                               "Please enter the missing radii manually by adding them to the radii dict"
-                               "and rerun the analysis.")
+            raise RuntimeError(
+                "Solute could not identify a solvation radius for at least one solvent. "
+                "Please enter the missing radii manually by adding them to the radii dict "
+                "and rerun the analysis."
+            )
     cutoff = bins[troughs[0]]
     return cutoff

@@ -302,7 +323,9 @@
     fig, ax : matplotlib pyplot
         Figure and Axis for the fit
     """
-    peaks, troughs, smooth_rdf = scipy_find_peaks_troughs(bins, rdf, return_rdf=True, **kwargs)
+    peaks, troughs, smooth_rdf = scipy_find_peaks_troughs(
+        bins, rdf, return_rdf=True, **kwargs
+    )
     fig, ax = plt.subplots()
     ax.plot(bins, rdf, "b--", label="rdf")
     ax.plot(bins, smooth_rdf, "g-", label="smooth_rdf")
@@ -321,7 +344,7 @@ def identify_cutoff_poly(
     failure_behavior: str = "warn",
     cutoff_region: tuple[float, float] = (1.5, 4),
     floor: float = 0.05,
-    cutoff: float = 5
+    cutoff: float = 5,
 ) -> float:
     """
     Identifies the solvation cutoff of an RDF using a polynomial interpolation.
@@ -356,7 +379,9 @@
             warnings.warn("No solvation shell detected.")
             return np.NaN
         if failure_behavior == "exception":
-            raise RuntimeError("Solute could not identify a solvation radius for at least one solvent. "
-                               "Please enter the missing radii manually by adding them to the radii dict"
-                               "and rerun the analysis.")
+            raise RuntimeError(
+                "Solute could not identify a solvation radius for at least one solvent. "
+                "Please enter the missing radii manually by adding them to the radii dict "
+                "and rerun the analysis."
+            )
     return cr_pts[1]
diff --git a/solvation_analysis/residence.py b/solvation_analysis/residence.py
index 0410b2b..6577acb 100644
--- a/solvation_analysis/residence.py
+++ b/solvation_analysis/residence.py
@@ -15,6 +15,7 @@
 as an attribute of the Solute class. This makes instantiating it and
 calculating the solvation data a non-issue.
 """
+
 import math
 import warnings

@@ -24,7 +25,13 @@
 from statsmodels.tsa.stattools import acovf
 from scipy.optimize import curve_fit

-from solvation_analysis._column_names import *
+import solvation_analysis
+from solvation_analysis._column_names import (
+    SOLVENT,
+    SOLUTE_ATOM_IX,
+    SOLVENT_ATOM_IX,
+    SOLUTE_IX,
+)
 from solvation_analysis._utils import calculate_adjacency_dataframe

@@ -86,14 +93,15 @@ class Residence:
     def __init__(self, solvation_data: pd.DataFrame, step: int) -> None:
         self.solvation_data = solvation_data
         self._auto_covariances = self._calculate_auto_covariance_dict()
-        self._residence_times_cutoff = self._calculate_residence_times_with_cutoff(self._auto_covariances, step)
-        self._residence_times_fit, self._fit_parameters = self._calculate_residence_times_with_fit(
-            self._auto_covariances,
-            step
+        self._residence_times_cutoff = self._calculate_residence_times_with_cutoff(
+            self._auto_covariances, step
+        )
+        self._residence_times_fit, self._fit_parameters = (
+            self._calculate_residence_times_with_fit(self._auto_covariances, step)
         )

     @staticmethod
-    def from_solute(solute: 'Solute') -> 'Residence':
+    def from_solute(solute: "solvation_analysis.Solute") -> "Residence":
         """
         Generate a Residence object from a solute.
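For orientation, the two residence-time estimators used below reduce to: (1) the first frame where the normalized autocovariance drops below 1/e, and (2) tau = 1/b from an a * exp(-b * x) + c fit. A self-contained sketch on synthetic data (all values invented):

import numpy as np
from scipy.optimize import curve_fit

# synthetic normalized autocovariance of a solute-solvent contact,
# decaying with a true residence time of ~20 frames
frames = np.arange(200)
rng = np.random.default_rng(42)
auto_cov = np.exp(-frames / 20) + rng.normal(0, 0.01, 200)

# cutoff estimate: first frame below 1/e, as in
# _calculate_residence_times_with_cutoff
cutoff_frame = np.argmax(auto_cov < 1 / np.e)

# fit estimate: tau = 1/b from the exponential fit, as in _fit_exponential
def exp_decay(x, a, b, c):
    return a * np.exp(-b * x) + c

params, _ = curve_fit(exp_decay, frames, auto_cov, p0=(1, 0.1, 0))
tau = 1 / params[1]
print(cutoff_frame, tau)  # both near 20 frames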
@@ -106,15 +114,14 @@ def from_solute(solute: 'Solute') -> 'Residence': Residence """ assert solute.has_run, "The solute must be run before calling from_solute" - return Residence( - solute.solvation_data, - solute.step - ) + return Residence(solute.solvation_data, solute.step) def _calculate_auto_covariance_dict(self) -> dict[str, np.ndarray]: partial_index = self.solvation_data.index.droplevel(SOLVENT_ATOM_IX) unique_indices = np.unique(partial_index) - frame_solute_index = pd.MultiIndex.from_tuples(unique_indices, names=partial_index.names) + frame_solute_index = pd.MultiIndex.from_tuples( + unique_indices, names=partial_index.names + ) auto_covariance_dict = {} for res_name, res_solvation_data in self.solvation_data.groupby([SOLVENT]): if isinstance(res_name, tuple): @@ -128,14 +135,20 @@ def _calculate_auto_covariance_dict(self) -> dict[str, np.ndarray]: return auto_covariance_dict @staticmethod - def _calculate_residence_times_with_cutoff(auto_covariances: dict[str, np.ndarray], step: int, convergence_cutoff: float = 0.1) -> dict[str, float]: + def _calculate_residence_times_with_cutoff( + auto_covariances: dict[str, np.ndarray], + step: int, + convergence_cutoff: float = 0.1, + ) -> dict[str, float]: residence_times = {} for res_name, auto_covariance in auto_covariances.items(): if np.min(auto_covariance) > convergence_cutoff: residence_times[res_name] = np.nan - warnings.warn(f'the autocovariance for {res_name} does not converge to zero ' - 'so a residence time cannot be calculated. A longer simulation ' - 'is required to get a valid estimate of the residence time.') + warnings.warn( + f"the autocovariance for {res_name} does not converge to zero " + "so a residence time cannot be calculated. A longer simulation " + "is required to get a valid estimate of the residence time." 
+            )
             unassigned = True
             for frame, val in enumerate(auto_covariance):
                 if val < 1 / math.e:
@@ -147,13 +160,18 @@ def _calculate_residence_times_with_cutoff(auto_covariances: dict[str, np.ndarra
         return residence_times

     @staticmethod
-    def _calculate_residence_times_with_fit(auto_covariances: dict[str, np.ndarray], step: int) -> tuple[dict[str, float], dict[str, tuple[float, float, float]]]:
+    def _calculate_residence_times_with_fit(
+        auto_covariances: dict[str, np.ndarray], step: int
+    ) -> tuple[dict[str, float], dict[str, tuple[float, float, float]]]:
         # calculate the residence times
         residence_times = {}
         fit_parameters = {}
         for res_name, auto_covariance in auto_covariances.items():
             res_time, params = Residence._fit_exponential(auto_covariance, res_name)
-            residence_times[res_name], fit_parameters[res_name] = round(res_time * step, 2), params
+            residence_times[res_name], fit_parameters[res_name] = (
+                round(res_time * step, 2),
+                params,
+            )
         return residence_times, fit_parameters

     def plot_auto_covariance(self, res_name: str) -> tuple[plt.Figure, plt.Axes]:
@@ -175,16 +193,22 @@ def plot_auto_covariance(self, res_name: str) -> tuple[plt.Figure, plt.Axes]:
         auto_covariance = self.auto_covariances[res_name]
         frames = np.arange(len(auto_covariance))
         params = self.fit_parameters[res_name]
-        exp_func = lambda x: self._exponential_decay(x, *params)
+
+        def exp_func(x):
+            return self._exponential_decay(x, *params)
+
         exp_fit = np.array(map(exp_func, frames))
         fig, ax = plt.subplots()
         ax.plot(frames, auto_covariance, "b-", label="auto covariance")
         try:
             ax.scatter(frames, exp_fit, label="exponential fit")
-        except:
-            warnings.warn(f'The fit for {res_name} failed so the exponential '
-                          f'fit will not be plotted.')
-        ax.hlines(y=1/math.e, xmin=frames[0], xmax=frames[-1], label='1/e cutoff')
+        # TODO: check this
+        except RuntimeError:
+            warnings.warn(
+                f"The fit for {res_name} failed so the exponential "
+                f"fit will not be plotted."
+            )
+        ax.hlines(y=1 / math.e, xmin=frames[0], xmax=frames[-1], label="1/e cutoff")
         ax.set_xlabel("Timestep (frames)")
         ax.set_ylabel("Normalized Autocovariance")
         ax.set_ylim(0, 1)
@@ -208,7 +232,9 @@ def _exponential_decay(x: np.ndarray, a: float, b: float, c: float) -> np.ndarra
         return a * np.exp(-b * x) + c

     @staticmethod
-    def _fit_exponential(auto_covariance: np.ndarray, res_name: str) -> tuple[float, tuple[float, float, float]]:
+    def _fit_exponential(
+        auto_covariance: np.ndarray, res_name: str
+    ) -> tuple[float, tuple[float, float, float]]:
         auto_covariance_norm = auto_covariance / auto_covariance[0]
         try:
             params, param_covariance = curve_fit(
@@ -219,9 +245,11 @@ def _fit_exponential(auto_covariance: np.ndarray, res_name: str) -> tuple[float,
             )
             tau = 1 / params[1]  # p
         except RuntimeError:
-            warnings.warn(f'The fit for {res_name} failed so its values in'
-                          f'residence_time_fits and fit_parameters will be'
-                          f'set to np.nan.')
+            warnings.warn(
+                f"The fit for {res_name} failed so its values in "
+                f"residence_time_fits and fit_parameters will be "
+                f"set to np.nan."
+ ) tau, params = np.nan, (np.nan, np.nan, np.nan) return tau, params @@ -230,17 +258,21 @@ def _calculate_auto_covariance(adjacency_matrix: pd.DataFrame) -> np.ndarray: auto_covariances = [] timesteps = adjacency_matrix.index.levels[0] - for solute_ix, solute_df in adjacency_matrix.groupby([SOLUTE_IX, SOLUTE_ATOM_IX]): + for solute_ix, solute_df in adjacency_matrix.groupby( + [SOLUTE_IX, SOLUTE_ATOM_IX] + ): # this is needed to make sure auto-covariances can be concatenated later - new_solute_df = solute_df.droplevel([SOLUTE_IX, SOLUTE_ATOM_IX]).reindex(timesteps, fill_value=0) + new_solute_df = solute_df.droplevel([SOLUTE_IX, SOLUTE_ATOM_IX]).reindex( + timesteps, fill_value=0 + ) non_zero_cols = new_solute_df.loc[:, (solute_df != 0).any(axis=0)] auto_covariance_df = non_zero_cols.apply( acovf, axis=0, - result_type='expand', + result_type="expand", demean=False, adjusted=True, - fft=True + fft=True, ) # timesteps with no binding are getting skipped, we need to make sure to include all timesteps auto_covariances.append(auto_covariance_df.values) @@ -278,6 +310,6 @@ def residence_times_fit(self) -> dict[str, float]: def fit_parameters(self) -> dict[str, tuple[float, float, float]]: """ A dictionary where keys are residue names and values are the - arameters for the exponential fit to the autocorrelation function. + parameters for the exponential fit to the autocorrelation function. """ return self._fit_parameters diff --git a/solvation_analysis/solute.py b/solvation_analysis/solute.py index 82e5e20..9165496 100644 --- a/solvation_analysis/solute.py +++ b/solvation_analysis/solute.py @@ -105,6 +105,7 @@ solvation shell, returning an AtomGroup for visualization or further analysis. This is covered in the visualization tutorial. """ + from collections import defaultdict from functools import reduce from typing import Any, Callable, Optional, Union @@ -119,7 +120,12 @@ from MDAnalysis.lib.distances import capped_distance import numpy as np -from solvation_analysis._utils import verify_solute_atoms, verify_solute_atoms_dict, get_closest_n_mol, get_radial_shell +from solvation_analysis._utils import ( + verify_solute_atoms, + verify_solute_atoms_dict, + get_closest_n_mol, + get_radial_shell, +) from solvation_analysis.rdf_parser import identify_cutoff_scipy from solvation_analysis.coordination import Coordination from solvation_analysis.networking import Networking @@ -127,7 +133,25 @@ from solvation_analysis.residence import Residence from solvation_analysis.speciation import Speciation -from solvation_analysis._column_names import * +from solvation_analysis._column_names import ( + FRAME, + DISTANCE, + SOLUTE, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + SOLVENT_ATOM_IX, + SOLUTE_ATOM_IX, +) + +try: + import rdkit + from rdkit.Chem import Draw + from rdkit.Chem import rdCoordGen + from rdkit.Chem.Draw.MolDrawing import DrawingOptions + +except ImportError: + rdkit = None class Solute(AnalysisBase): @@ -239,7 +263,7 @@ def __init__( self, solute_atoms: mda.AtomGroup, solvents: dict[str, mda.AtomGroup], - atom_solutes: Optional[dict[str, 'Solute']] = None, + atom_solutes: Optional[dict[str, "Solute"]] = None, radii: Optional[dict[str, float]] = None, rdf_kernel: Optional[Callable[[np.ndarray, np.ndarray], float]] = None, kernel_kwargs: Optional[dict[str, Any]] = None, @@ -257,8 +281,10 @@ def __init__( ``from_atoms``, ``from_atoms_dict`` or ``from_solute_list`` to create a Solute. 
""" if not internal_call: - raise RuntimeError("Please use Solute.from_atoms, Solute.from_atoms_dict, or " - "Solute.from_solute_list instead of the default constructor.") + raise RuntimeError( + "Please use Solute.from_atoms, Solute.from_atoms_dict, or " + "Solute.from_solute_list instead of the default constructor." + ) super(Solute, self).__init__(solute_atoms.universe.trajectory, verbose=verbose) self.solute_atoms = solute_atoms # TODO: this shit! @@ -266,7 +292,9 @@ def __init__( if self.atom_solutes is None or len(atom_solutes) <= 1: self.atom_solutes = {solute_name: self} self.radii = radii or {} - self.solvent_counts = {name: atoms.n_residues for name, atoms in solvents.items()} + self.solvent_counts = { + name: atoms.n_residues for name, atoms in solvents.items() + } self.kernel = rdf_kernel or identify_cutoff_scipy self.kernel_kwargs = kernel_kwargs or {} self.rdf_init_kwargs = rdf_init_kwargs or {} @@ -274,17 +302,19 @@ def __init__( self.has_run = False self.u = solute_atoms.universe self.n_solutes = solute_atoms.n_residues - self.solute_res_ix = pd.Series(solute_atoms.atoms.resindices, solute_atoms.atoms.ix) + self.solute_res_ix = pd.Series( + solute_atoms.atoms.resindices, solute_atoms.atoms.ix + ) self.solute_name = solute_name self.solvents = solvents if skip_rdf: - assert set(self.radii.keys()) >= set(self.solvents.keys()), ( - "To skip RDF generation, all solvent radii must be specified." - ) + assert set(self.radii.keys()) >= set( + self.solvents.keys() + ), "To skip RDF generation, all solvent radii must be specified." self.skip_rdf = skip_rdf # instantiate the res_name_map - self.res_name_map = pd.Series(['none'] * len(self.u.residues)) + self.res_name_map = pd.Series(["none"] * len(self.u.residues)) self.res_name_map[solute_atoms.residues.ix] = self.solute_name for name, solvent in solvents.items(): self.res_name_map[solvent.residues.ix] = name @@ -292,8 +322,14 @@ def __init__( # instantiate analysis classes. if analysis_classes is None: self.analysis_classes = ["pairing", "coordination", "speciation"] - elif analysis_classes == 'all': - self.analysis_classes = ["pairing", "coordination", "speciation", "residence", "networking"] + elif analysis_classes == "all": + self.analysis_classes = [ + "pairing", + "coordination", + "speciation", + "residence", + "networking", + ] else: self.analysis_classes = [cls.lower() for cls in analysis_classes] if "networking" in self.analysis_classes and networking_solvents is None: @@ -308,8 +344,8 @@ def from_atoms( solute_atoms: mda.AtomGroup, solvents: dict[str, mda.AtomGroup], rename_solutes: Optional[dict[str, str]] = None, - **kwargs: Any - ) -> 'Solute': + **kwargs: Any, + ) -> "Solute": """ Create a Solute from a single AtomGroup. The solute_atoms AtomGroup must should contain identical residues and identical atoms on each residue. @@ -349,14 +385,16 @@ def from_atoms( rename_solutes.get(f"solute_{i}") or f"solute_{i}": atom_group for i, atom_group in solute_atom_group_dict.items() } - return Solute.from_atoms_dict(solute_atom_group_dict_renamed, solvents, **kwargs) + return Solute.from_atoms_dict( + solute_atom_group_dict_renamed, solvents, **kwargs + ) @staticmethod def from_atoms_dict( solute_atoms_dict: dict[str, mda.AtomGroup], solvents: dict[str, mda.AtomGroup], - **kwargs: Any - ) -> 'Solute': + **kwargs: Any, + ) -> "Solute": """ Create a Solute object from a dictionary of solute atoms. 
@@ -378,7 +416,7 @@ def from_atoms_dict(
         """

         # all solute AtomGroups in one AtomGroup + verification
-        assert isinstance(solute_atoms_dict, dict), ("Solute_atoms_dict must be a dict.")
+        assert isinstance(solute_atoms_dict, dict), "Solute_atoms_dict must be a dict."
         solute_atom_group = verify_solute_atoms_dict(solute_atoms_dict)

         # create the solutes for each atom
@@ -388,7 +426,7 @@ def from_atoms_dict(
                 atoms,
                 solvents,
                 internal_call=True,
-                **{**kwargs, "solute_name": solute_name}
+                **{**kwargs, "solute_name": solute_name},
             )
         # create the solute for the whole solute
         solute = Solute(
@@ -396,18 +434,16 @@ def from_atoms_dict(
             solvents,
             atom_solutes=atom_solutes,
             internal_call=True,
-            **kwargs
+            **kwargs,
         )
         if len(atom_solutes) > 1:
-            solute.run = solute._run_solute_atoms 
+            solute.run = solute._run_solute_atoms
         return solute

     @staticmethod
     def from_solute_list(
-        solutes: list['Solute'],
-        solvents: dict[str, mda.AtomGroup],
-        **kwargs: Any
-    ) -> 'Solute':
+        solutes: list["Solute"], solvents: dict[str, mda.AtomGroup], **kwargs: Any
+    ) -> "Solute":
         """
         Create a Solute from a list of Solutes. All Solutes must have only
         a single solute atom on each solute residue. Essentially, from_solute_list
@@ -434,15 +470,17 @@ def from_solute_list(
         # check types and name uniqueness
         for solute in solutes:
             assert type(solute) == Solute, "solutes must be a list of Solute objects."
-            assert len(solute.solute_atoms.atoms) == len(solute.solute_atoms.atoms.residues), (
-                "Each Solute in solutes must have only a single atom per residue."
-            )
+            assert len(solute.solute_atoms.atoms) == len(
+                solute.solute_atoms.atoms.residues
+            ), "Each Solute in solutes must have only a single atom per residue."
         solute_names = [solute.solute_name for solute in solutes]
-        assert len(np.unique(solute_names)) == len(solute_names), (
-            "The solute_name for each solute must be unique."
-        )
+        assert len(np.unique(solute_names)) == len(
+            solute_names
+        ), "The solute_name for each solute must be unique."

-        solute_atom_group = reduce(lambda x, y: x | y, [solute.solute_atoms for solute in solutes])
+        solute_atom_group = reduce(
+            lambda x, y: x | y, [solute.solute_atoms for solute in solutes]
+        )
         verify_solute_atoms(solute_atom_group)

         atom_solutes = {solute.solute_name: solute for solute in solutes}
@@ -451,18 +489,18 @@ def from_solute_list(
             solvents,
             atom_solutes=atom_solutes,
             internal_call=True,
-            **kwargs
+            **kwargs,
        )
         if len(atom_solutes) > 1:
-            solute.run = solute._run_solute_atoms 
+            solute.run = solute._run_solute_atoms
         return solute

     def _run_solute_atoms(
-            self,
-            start: Optional[int] = None,
-            stop: Optional[int] = None,
-            step: Optional[int] = None,
-            verbose: Optional[bool] = None
+        self,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+        step: Optional[int] = None,
+        verbose: Optional[bool] = None,
     ):
         # like prepare
         atom_solutes = {}
@@ -478,13 +516,17 @@ def _run_solute_atoms(
             if not solute.has_run:
                 solute.run(start=start, stop=stop, step=step, verbose=verbose)
             if (start, stop, step) != (solute.start, solute.stop, solute.step):
-                warnings.warn(f"The start, stop, or step for {solute.solute_name} do not"
-                              f"match the start, stop, or step for the run command so it "
-                              f"is being re-run.")
+                warnings.warn(
+                    f"The start, stop, or step for {solute.solute_name} do not "
+                    f"match the start, stop, or step for the run command so it "
+                    f"is being re-run."
+ ) solute.run(start=start, stop=stop, step=step, verbose=verbose) if self.solvents != solute.solvents: - warnings.warn(f"The solvents for {solute.solute_name} do not match the " - f"solvents for the run command so it is being re-run.") + warnings.warn( + f"The solvents for {solute.solute_name} do not match the " + f"solvents for the run command so it is being re-run." + ) solute.run(start=start, stop=stop, step=step, verbose=verbose) atom_solutes[solute.solute_name] = solute rdf_data[solute.solute_name] = solute.rdf_data[solute.solute_name] @@ -498,17 +540,25 @@ def _run_solute_atoms( # like conclude analysis_classes = { - 'speciation': Speciation, - 'pairing': Pairing, - 'coordination': Coordination, - 'residence': Residence, - 'networking': Networking, + "speciation": Speciation, + "pairing": Pairing, + "coordination": Coordination, + "residence": Residence, + "networking": Networking, } for analysis_class in self.analysis_classes: - if analysis_class == 'networking': - setattr(self, 'networking', Networking.from_solute(self, self.networking_solvents)) + if analysis_class == "networking": + setattr( + self, + "networking", + Networking.from_solute(self, self.networking_solvents), + ) else: - setattr(self, analysis_class, analysis_classes[analysis_class].from_solute(self)) + setattr( + self, + analysis_class, + analysis_classes[analysis_class].from_solute(self), + ) def _prepare(self): """ @@ -523,14 +573,22 @@ def _prepare(self): self.rdf_data = None break # set kwargs with defaults - self.rdf_init_kwargs["range"] = self.rdf_init_kwargs.get("range") or (0, 7.5) + self.rdf_init_kwargs["range"] = self.rdf_init_kwargs.get("range") or ( + 0, + 7.5, + ) self.rdf_init_kwargs["norm"] = self.rdf_init_kwargs.get("norm") or "density" self.rdf_run_kwargs["stop"] = self.rdf_run_kwargs.get("stop") or self.stop self.rdf_run_kwargs["step"] = self.rdf_run_kwargs.get("step") or self.step - self.rdf_run_kwargs["start"] = self.rdf_run_kwargs.get("start") or self.start + self.rdf_run_kwargs["start"] = ( + self.rdf_run_kwargs.get("start") or self.start + ) # generate and save RDFs rdf = InterRDF( - self.solute_atoms, solvent, **self.rdf_init_kwargs, exclude_same="residue" + self.solute_atoms, + solvent, + **self.rdf_init_kwargs, + exclude_same="residue", ) rdf.run(**self.rdf_run_kwargs) bins, data = rdf.results.bins, rdf.results.rdf @@ -538,10 +596,11 @@ def _prepare(self): # generate and save plots if name not in self.radii.keys(): self.radii[name] = self.kernel(bins, data, **self.kernel_kwargs) - calculated_radii = set([name for name, radius in self.radii.items() - if not np.isnan(radius)]) + calculated_radii = set( + [name for name, radius in self.radii.items() if not np.isnan(radius)] + ) missing_solvents = set(self.solvents.keys()) - calculated_radii - missing_solvents_str = ' '.join([str(i) for i in missing_solvents]) + missing_solvents_str = " ".join([str(i) for i in missing_solvents]) assert len(missing_solvents) == 0, ( f"Solute could not identify a solvation radius for " f"{missing_solvents_str}. 
Please manually enter missing radii " @@ -565,7 +624,10 @@ def _single_frame(self): box=self.u.dimensions, ) # make sure pairs don't include intra-molecular interactions - filter = self.solute_atoms.resindices[pairs[:, 0]] == solvent.resindices[pairs[:, 1]] + filter = ( + self.solute_atoms.resindices[pairs[:, 0]] + == solvent.resindices[pairs[:, 1]] + ) pairs = pairs[~filter] dist = dist[~filter] # replace local ids with absolute ids @@ -593,7 +655,7 @@ def _single_frame(self): dist_array, solute_res_name_array, solvent_res_name_array, - solvent_res_ix_array + solvent_res_ix_array, ) ) # add the current frame to the growing list of solvation arrays @@ -615,38 +677,59 @@ def _conclude(self): DISTANCE, SOLUTE, SOLVENT, - SOLVENT_IX - ] + SOLVENT_IX, + ], ) # clean up solvation_data df - for column in [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX, DISTANCE, SOLVENT_IX]: + for column in [ + FRAME, + SOLUTE_IX, + SOLUTE_ATOM_IX, + SOLVENT_ATOM_IX, + DISTANCE, + SOLVENT_IX, + ]: solvation_data_df[column] = pd.to_numeric(solvation_data_df[column]) - solvation_data_df = solvation_data_df.sort_values([FRAME, SOLUTE_ATOM_IX, DISTANCE]) - solvation_data_duplicates = solvation_data_df.duplicated(subset=[FRAME, SOLUTE_ATOM_IX, SOLVENT_IX]) + solvation_data_df = solvation_data_df.sort_values( + [FRAME, SOLUTE_ATOM_IX, DISTANCE] + ) + solvation_data_duplicates = solvation_data_df.duplicated( + subset=[FRAME, SOLUTE_ATOM_IX, SOLVENT_IX] + ) solvation_data = solvation_data_df[~solvation_data_duplicates] - self.solvation_data = solvation_data.set_index([FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX]) + self.solvation_data = solvation_data.set_index( + [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX] + ) duplicates = solvation_data_df[solvation_data_duplicates] - self.solvation_data_duplicates = duplicates.set_index([FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX]) + self.solvation_data_duplicates = duplicates.set_index( + [FRAME, SOLUTE_IX, SOLUTE_ATOM_IX, SOLVENT_ATOM_IX] + ) # instantiate analysis classes self.has_run = True analysis_classes = { - 'speciation': Speciation, - 'pairing': Pairing, - 'coordination': Coordination, - 'residence': Residence, - 'networking': Networking, + "speciation": Speciation, + "pairing": Pairing, + "coordination": Coordination, + "residence": Residence, + "networking": Networking, } for analysis_class in self.analysis_classes: - if analysis_class == 'networking': - setattr(self, 'networking', Networking.from_solute(self, self.networking_solvents)) + if analysis_class == "networking": + setattr( + self, + "networking", + Networking.from_solute(self, self.networking_solvents), + ) else: - setattr(self, analysis_class, analysis_classes[analysis_class].from_solute(self)) + setattr( + self, + analysis_class, + analysis_classes[analysis_class].from_solute(self), + ) @staticmethod def _plot_solvation_radius( - bins: np.ndarray, - data: np.ndarray, - radius: float + bins: np.ndarray, data: np.ndarray, radius: float ) -> tuple[plt.Figure, plt.Axes]: """ Plot a solvation radius on an RDF. 
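Two practical consequences of the hunks above, sketched with hypothetical values and continuing the names from the earlier sketch (u, li_atoms, solvents). First, _prepare only fills in RDF kwargs the user left unset, so both the solvation radii and the RDF window can be overridden at construction; second, _conclude leaves solvation_data as a MultiIndexed DataFrame keyed by frame/solute/atom indices.

# (1) Override what _prepare would otherwise compute (values hypothetical):
solute = Solute.from_atoms(
    li_atoms,
    solvents,
    radii={"EC": 2.8},                   # manual cutoff, skips the kernel for EC
    rdf_init_kwargs={"range": (0, 10)},  # replaces the (0, 7.5) default
)
solute.run()

# (2) Inspect the solvation_data built in _conclude; index level names
# follow _column_names (frame, solute_ix, solute_atom_ix, solvent_atom_ix):
df = solute.solvation_data
first_frame = df.xs(solute.frames[0], level="frame")

If the kernel cannot identify a radius for some solvent (NaN), the assertion above fires; supplying the radius manually or widening the RDF range as in (1) are the two escape hatches this code provides.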
@@ -676,9 +759,7 @@ def _plot_solvation_radius(
        return fig, ax

    def plot_solvation_radius(
-        self,
-        solute_name: str,
-        solvent_name: str
+        self, solute_name: str, solvent_name: str
    ) -> tuple[plt.Figure, plt.Axes]:
        """
        Plot the RDF of a solvent molecule
@@ -707,10 +788,10 @@ def plot_solvation_radius(
        return fig, ax

    def draw_molecule(
-            self,
-            residue: Union[str, mda.core.groups.Residue],
-            filename: Optional[str] = None
-    ) -> 'rdkit.Chem.rdchem.Mol':
+        self,
+        residue: Union[str, mda.core.groups.Residue],
+        filename: Optional[str] = None,
+    ) -> "rdkit.Chem.rdchem.Mol":
        """

        Returns
@@ -724,17 +805,22 @@ def draw_molecule(
            RDKit.Chem.rdchem.Mol
        """
-        from rdkit.Chem import Draw
-        from rdkit.Chem import rdCoordGen
-        from rdkit.Chem.Draw.MolDrawing import DrawingOptions
+        if rdkit is None:
+            raise ImportError(
+                "The RDKit package is required to use this function. "
+                "Please install RDKit with `conda install -c conda-forge rdkit`."
+            )
+
        DrawingOptions.atomLabelFontSize = 100
        if isinstance(residue, str):
            if residue in [self.solute_name, "solute"]:
                mol = self.solute_atoms.residues[0].atoms.convert_to("RDKIT")
                mol_mda_ix = self.solute_atoms.residues[0].atoms.ix
-                solute_atoms_ix0 = {solute.solute_atoms.atoms.ix[0]: solute_name
-                                    for solute_name, solute in self.atom_solutes.items()}
+                solute_atoms_ix0 = {
+                    solute.solute_atoms.atoms.ix[0]: solute_name
+                    for solute_name, solute in self.atom_solutes.items()
+                }
                for i, atom in enumerate(mol.GetAtoms()):
                    atom_name = solute_atoms_ix0.get(mol_mda_ix[i])
                    label = f"{i}, " + atom_name if atom_name else str(i)
@@ -744,8 +830,10 @@ def draw_molecule(
                for i, atom in enumerate(mol.GetAtoms()):
                    atom.SetProp("atomNote", str(i))
            else:
-                raise ValueError("If the residue is a string, it must be the name of a solute, "
-                                 "the name of a solvent, or 'solute'.")
+                raise ValueError(
+                    "If the residue is a string, it must be the name of a solute, "
+                    "the name of a solvent, or 'solute'."
+                )
        else:
            assert isinstance(residue, mda.core.groups.Residue)
            mol = residue.atoms.convert_to("RDKIT")
@@ -757,12 +845,12 @@ def draw_molecule(
        return mol

    def get_shell(
-            self,
-            solute_index: int,
-            frame: int,
-            as_df: bool = False,
-            remove_mols: Optional[dict[str, int]] = None,
-            closest_n_only: Optional[int] = None
+        self,
+        solute_index: int,
+        frame: int,
+        as_df: bool = False,
+        remove_mols: Optional[dict[str, int]] = None,
+        closest_n_only: Optional[int] = None,
    ) -> Union[mda.AtomGroup, pd.DataFrame]:
        """
        Select the solvation shell of the solute.
@@ -800,8 +888,9 @@ def get_shell(
        """
        assert self.has_run, "Solute.run() must be called first."
-        assert frame in self.frames, ("The requested frame must be one "
-                                      "of an analyzed frames in self.frames.")
+        assert frame in self.frames, (
+            "The requested frame must be one of the analyzed frames in self.frames."
+ ) remove_mols = {} if remove_mols is None else remove_mols # select shell of interest shell = self.solvation_data.xs((frame, solute_index), level=(FRAME, SOLUTE_IX)) @@ -813,7 +902,7 @@ def get_shell( mol_count = len(res_ix) n_remove = min(mol_count, n_remove) # then truncate resnames to remove mols - remove_ix = res_ix[(mol_count - n_remove):] + remove_ix = res_ix[(mol_count - n_remove) :] # then apply to original shell remove = shell[SOLVENT_IX].isin(remove_ix) shell = shell[np.invert(remove)] @@ -821,20 +910,24 @@ def get_shell( if closest_n_only: assert closest_n_only > 0, "closest_n_only must be at least 1" closest_n_only = min(len(shell), closest_n_only) - shell = shell[0: closest_n_only] + shell = shell[0:closest_n_only] if as_df: return shell else: return self._df_to_atom_group(shell, solute_index=solute_index) def get_closest_n_mol( - self, - solute_atom_ix: int, - n_mol: int, - guess_radius: Union[float, int] = 3, - return_ordered_resix: bool = False, - return_radii: bool = False, - ) -> Union[mda.AtomGroup, tuple[mda.AtomGroup, np.ndarray], tuple[mda.AtomGroup, np.ndarray, np.ndarray]]: + self, + solute_atom_ix: int, + n_mol: int, + guess_radius: Union[float, int] = 3, + return_ordered_resix: bool = False, + return_radii: bool = False, + ) -> Union[ + mda.AtomGroup, + tuple[mda.AtomGroup, np.ndarray], + tuple[mda.AtomGroup, np.ndarray, np.ndarray], + ]: """ Select the n closest mols to the solute. @@ -876,9 +969,7 @@ def get_closest_n_mol( ) def radial_shell( - self, - solute_atom_ix: int, - radius: Union[float, int] + self, solute_atom_ix: int, radius: Union[float, int] ) -> mda.AtomGroup: """ Select all residues with atoms within r of the solute. @@ -901,9 +992,7 @@ def radial_shell( return get_radial_shell(self.solute_atoms[solute_atom_ix], radius) def _df_to_atom_group( - self, - df: pd.DataFrame, - solute_index: Optional[int] = None + self, df: pd.DataFrame, solute_index: Optional[int] = None ) -> mda.AtomGroup: """ Selects an MDAnalysis.AtomGroup from a pandas.DataFrame with solvent. diff --git a/solvation_analysis/speciation.py b/solvation_analysis/speciation.py index fe7288a..6ce3dbe 100644 --- a/solvation_analysis/speciation.py +++ b/solvation_analysis/speciation.py @@ -21,7 +21,14 @@ import pandas as pd -from solvation_analysis._column_names import * +import solvation_analysis +from solvation_analysis._column_names import ( + FRAME, + SOLUTE_IX, + SOLVENT, + SOLVENT_IX, + COUNT, +) class Speciation: @@ -51,7 +58,9 @@ class Speciation: The number of solutes in solvation_data. """ - def __init__(self, solvation_data: pd.DataFrame, n_frames: int, n_solutes: int) -> None: + def __init__( + self, solvation_data: pd.DataFrame, n_frames: int, n_solutes: int + ) -> None: self.solvation_data = solvation_data self.n_frames = n_frames self.n_solutes = n_solutes @@ -59,7 +68,7 @@ def __init__(self, solvation_data: pd.DataFrame, n_frames: int, n_solutes: int) self._solvent_co_occurrence = self._solvent_co_occurrence() @staticmethod - def from_solute(solute: 'Solute') -> 'Speciation': + def from_solute(solute: "solvation_analysis.Solute") -> "Speciation": """ Generate a Speciation object from a solute. 
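Taken together, the shell accessors reformatted above compose as follows. A hedged sketch continuing the earlier example; the solute_index, frame, and molecule counts are hypothetical, and get_shell requires run() to have completed.

# Sketch of the shell accessors (indices hypothetical):
shell = solute.get_shell(solute_index=0, frame=0)        # mda.AtomGroup
shell_df = solute.get_shell(0, 0, as_df=True)            # DataFrame view
trimmed = solute.get_shell(0, 0, remove_mols={"EC": 1})  # drop the farthest EC

# get_closest_n_mol seeds its search with guess_radius and, per the Union
# return annotation above, can also return ordered residue indices and radii:
closest, res_ix, radii = solute.get_closest_n_mol(
    solute_atom_ix=0, n_mol=4, return_ordered_resix=True, return_radii=True
)

Because solvation_data is sorted by distance before duplicates are dropped, remove_mols trims from the far edge of the shell and closest_n_only keeps the nearest entries, consistent with the slicing above.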
@@ -79,7 +88,9 @@ def from_solute(solute: 'Solute') -> 'Speciation': ) def _compute_speciation(self) -> tuple[pd.DataFrame, pd.DataFrame]: - counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX] + counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[ + SOLVENT_IX + ] counts_re = counts.reset_index([SOLVENT]) speciation_data = counts_re.pivot(columns=[SOLVENT]).fillna(0).astype(int) res_names = speciation_data.columns.levels[1] @@ -87,11 +98,15 @@ def _compute_speciation(self) -> tuple[pd.DataFrame, pd.DataFrame]: sum_series = speciation_data.groupby(speciation_data.columns.to_list()).size() sum_sorted = sum_series.sort_values(ascending=False) speciation_fraction = sum_sorted.reset_index().rename(columns={0: COUNT}) - speciation_fraction[COUNT] = speciation_fraction[COUNT] / (self.n_frames * self.n_solutes) + speciation_fraction[COUNT] = speciation_fraction[COUNT] / ( + self.n_frames * self.n_solutes + ) return speciation_data, speciation_fraction @classmethod - def _mean_speciation(cls, speciation_frames: pd.DataFrame, solute_number: int, frame_number: int) -> pd.Series: + def _mean_speciation( + cls, speciation_frames: pd.DataFrame, solute_number: int, frame_number: int + ) -> pd.Series: means = speciation_frames.sum(axis=1) / (solute_number * frame_number) return means @@ -167,11 +182,13 @@ def _solvent_co_occurrence(self) -> pd.DataFrame: actual_solvents_list = [] for solvent in self.speciation_data.columns.values: # calculate number of available coordinating solvent slots - shells_w_solvent = self.speciation_data.query(f'`{solvent}` > 0') + shells_w_solvent = self.speciation_data.query(f"`{solvent}` > 0") n_solvents = shells_w_solvent.sum() # calculate expected number of coordinating solvents n_coordination_slots = n_solvents.sum() - len(shells_w_solvent) - coordination_fraction = self.speciation_data.sum() / self.speciation_data.sum().sum() + coordination_fraction = ( + self.speciation_data.sum() / self.speciation_data.sum().sum() + ) expected_solvents = coordination_fraction * n_coordination_slots # calculate actual number of coordinating solvents actual_solvents = n_solvents.copy()
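For completeness, a hedged sketch of consuming the Speciation results computed above. It assumes the constructor stores the two _compute_speciation() outputs as speciation_data and speciation_fraction; that assignment is elided between the hunks shown here.

# Assumes Speciation keeps _compute_speciation() results on the instance.
spec = solute.speciation            # set on the Solute by _conclude()
print(spec.speciation_data.head())  # per (frame, solute) solvent counts
# speciation_fraction lists each unique shell composition with its fraction
# of all (n_frames * n_solutes) observations, per the normalization above.
print(spec.speciation_fraction.head())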