Skip to content

Commit

Permalink
Add pre-commit hook and lint all files.
Browse files Browse the repository at this point in the history
  • Loading branch information
orionarcher committed May 7, 2024
1 parent d550b02 commit bff5f2d
Show file tree
Hide file tree
Showing 14 changed files with 615 additions and 337 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:
cache-dependency-path: pyproject.toml

- uses: pre-commit/action@v3.0.0
with:
extra_args: --files solvation_analysis/*

test:
name: Test on ${{ matrix.os }}, Python ${{ matrix.python-version }}
Expand Down Expand Up @@ -91,7 +93,7 @@ jobs:
file: ./coverage.xml
flags: unittests
name: codecov-${{ matrix.os }}-py${{ matrix.python-version }}

test_pip_install:
name: pip (PEP517) install on ${{ matrix.os }}, Python ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
Expand Down
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ repos:
args: [--remove]
- id: end-of-file-fixer
- id: trailing-whitespace
#- repo: https://github.com/asottile/blacken-docs
# rev: 1.16.0
# hooks:
# - id: blacken-docs
# additional_dependencies: [black]
# exclude: README.md
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
Expand Down
12 changes: 8 additions & 4 deletions solvation_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@
SolvationAnalysis
An MDAnalysis rmodule for solvation analysis.
"""

from . import _version
from solvation_analysis.solute import Solute

# Handle versioneer
from ._version import get_versions

versions = get_versions()
__version__ = versions['version']
__git_revision__ = versions['full-revisionid']
__version__ = versions["version"]
__git_revision__ = versions["full-revisionid"]
del get_versions, versions

from . import _version
__version__ = _version.get_versions()['version']

__version__ = _version.get_versions()["version"]
__all__ = ["Solute"]
1 change: 0 additions & 1 deletion solvation_analysis/_column_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
1. change the variable name in all files
"""


# for solvation_data
FRAME = "frame"
SOLUTE_IX = "solute_ix"
Expand Down
82 changes: 53 additions & 29 deletions solvation_analysis/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import MDAnalysis as mda
from MDAnalysis.analysis import distances

from solvation_analysis._column_names import *
from solvation_analysis._column_names import FRAME, SOLUTE_IX, SOLVENT_IX, DISTANCE


def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomGroup]:
Expand All @@ -16,63 +16,74 @@ def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomG
# and that the residues are all the same length
# then this should work
all_res_len = np.array([res.atoms.n_atoms for res in solute_atom_group.residues])
assert np.all(all_res_len[0] == all_res_len), (
"All residues must be the same length."
)
assert np.all(
all_res_len[0] == all_res_len
), "All residues must be the same length."
res_atom_local_ix = defaultdict(list)
res_atom_ix = defaultdict(list)

for atom in solute_atom_group.atoms:
res_atom_local_ix[atom.resindex].append(atom.ix - atom.residue.atoms[0].ix)
res_atom_ix[atom.resindex].append(atom.index)
res_occupancy = np.array([len(ix) for ix in res_atom_local_ix.values()])
assert np.all(res_occupancy[0] == res_occupancy), (
"All residues must have the same number of solute_atoms atoms on them."
)
assert np.all(
res_occupancy[0] == res_occupancy
), "All residues must have the same number of solute_atoms atoms on them."

res_atom_array = np.array(list(res_atom_local_ix.values()))
assert np.all(res_atom_array[0] == res_atom_array), (
"All residues must have the same solute_atoms atoms on them."
)
assert np.all(
res_atom_array[0] == res_atom_array
), "All residues must have the same solute_atoms atoms on them."

res_atom_ix_array = np.array(list(res_atom_ix.values()))
solute_atom_group_dict = {}
for i in range(0, res_atom_ix_array.shape[1]):
solute_atom_group_dict[i] = solute_atom_group.universe.atoms[res_atom_ix_array[:, i]]
solute_atom_group_dict[i] = solute_atom_group.universe.atoms[
res_atom_ix_array[:, i]
]
return solute_atom_group_dict


def verify_solute_atoms_dict(solute_atoms_dict: dict[str, mda.AtomGroup]) -> mda.AtomGroup:
def verify_solute_atoms_dict(
solute_atoms_dict: dict[str, mda.AtomGroup],
) -> mda.AtomGroup:
# first we verify the input format
atom_group_lengths = []
for solute_name, solute_atom_group in solute_atoms_dict.items():
assert isinstance(solute_name, str), (
"The keys of solutes_dict must be strings."
)
assert isinstance(solute_name, str), "The keys of solutes_dict must be strings."
assert isinstance(solute_atom_group, mda.AtomGroup), (
f"The values of solutes_dict must be MDAnalysis.AtomGroups. But the value"
f"for {solute_name} is a {type(solute_atom_group)}."
)
assert len(solute_atom_group) == len(solute_atom_group.residues), (
"The solute_atom_group must have a single atom on each residue."
)
assert len(solute_atom_group) == len(
solute_atom_group.residues
), "The solute_atom_group must have a single atom on each residue."
atom_group_lengths.append(len(solute_atom_group))
assert np.all(np.array(atom_group_lengths) == atom_group_lengths[0]), (
"AtomGroups in solutes_dict must have the same length because there should be"
"one atom per solute residue."
)

# verify that the solute_atom_groups have no overlap
solute_atom_group = reduce(lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()])
assert solute_atom_group.n_atoms == sum([atoms.n_atoms for atoms in solute_atoms_dict.values()]), (
"The solute_atom_groups must not overlap."
solute_atom_group = reduce(
lambda x, y: x | y, [atoms for atoms in solute_atoms_dict.values()]
)
assert solute_atom_group.n_atoms == sum(
[atoms.n_atoms for atoms in solute_atoms_dict.values()]
), "The solute_atom_groups must not overlap."
verify_solute_atoms(solute_atom_group)

return solute_atom_group


def get_atom_group(selection: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup]) -> mda.AtomGroup:
def get_atom_group(
selection: Union[
mda.core.groups.Residue,
mda.core.groups.ResidueGroup,
mda.core.groups.Atom,
mda.core.groups.AtomGroup,
],
) -> mda.AtomGroup:
"""
Cast an MDAnalysis.Atom, MDAnalysis.Residue, or MDAnalysis.ResidueGroup to AtomGroup.
Expand Down Expand Up @@ -103,14 +114,22 @@ def get_atom_group(selection: Union[mda.core.groups.Residue, mda.core.groups.Res
return selection



def get_closest_n_mol(
central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup],
central_species: Union[
mda.core.groups.Residue,
mda.core.groups.ResidueGroup,
mda.core.groups.Atom,
mda.core.groups.AtomGroup,
],
n_mol: int,
guess_radius: Union[float, int] = 3,
return_ordered_resix: bool = False,
return_radii: bool = False,
) -> Union[mda.AtomGroup, tuple[mda.AtomGroup, np.ndarray], tuple[mda.AtomGroup, np.ndarray, np.ndarray]]:
) -> Union[
mda.AtomGroup,
tuple[mda.AtomGroup, np.ndarray],
tuple[mda.AtomGroup, np.ndarray, np.ndarray],
]:
"""
Returns the closest n molecules to the central species, an array of their resix,
and an array of the distance of the closest atom in each molecule.
Expand Down Expand Up @@ -152,15 +171,15 @@ def get_closest_n_mol(
n_mol,
guess_radius + 1,
return_ordered_resix=return_ordered_resix,
return_radii=return_radii
return_radii=return_radii,
)
radii = distances.distance_array(coords, partial_shell.positions, box=u.dimensions)[
0
]
ordering = np.argsort(radii)
ordered_resix = shell_resix[ordering]
closest_n_resix = np.sort(np.unique(ordered_resix, return_index=True)[1])[
0: n_mol + 1
0 : n_mol + 1
]
str_resix = " ".join(str(resix) for resix in ordered_resix[closest_n_resix])
full_shell = u.select_atoms(f"resindex {str_resix}")
Expand All @@ -179,8 +198,13 @@ def get_closest_n_mol(


def get_radial_shell(
central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup],
radius: Union[float, int]
central_species: Union[
mda.core.groups.Residue,
mda.core.groups.ResidueGroup,
mda.core.groups.Atom,
mda.core.groups.AtomGroup,
],
radius: Union[float, int],
) -> mda.AtomGroup:
"""
Returns all molecules with atoms within the radius of the central species.
Expand Down
Loading

0 comments on commit bff5f2d

Please sign in to comment.