Skip to content

Commit

Permalink
Add type hints to all files.
Browse files Browse the repository at this point in the history
  • Loading branch information
orionarcher committed May 7, 2024
1 parent 15fb1a8 commit 7fe6f7e
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 138 deletions.
1 change: 1 addition & 0 deletions solvation_analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
SolvationAnalysis
An MDAnalysis module for solvation analysis.
"""
from solvation_analysis.solute import Solute

# Handle versioneer
from ._version import get_versions
Expand Down
31 changes: 19 additions & 12 deletions solvation_analysis/_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import numpy as np
from collections import defaultdict
from functools import reduce
from typing import Union

import numpy as np
import pandas as pd
import MDAnalysis as mda
from MDAnalysis.analysis import distances

from solvation_analysis._column_names import *


def verify_solute_atoms(solute_atom_group):
def verify_solute_atoms(solute_atom_group: mda.AtomGroup) -> dict[int, mda.AtomGroup]:
# we presume that the solute_atoms has the same number of atoms on each residue
# and that they all have the same indices on those residues
# and that the residues are all the same length
Expand Down Expand Up @@ -39,7 +42,7 @@ def verify_solute_atoms(solute_atom_group):
return solute_atom_group_dict


def verify_solute_atoms_dict(solute_atoms_dict):
def verify_solute_atoms_dict(solute_atoms_dict: dict[str, mda.AtomGroup]) -> mda.AtomGroup:
# first we verify the input format
atom_group_lengths = []
for solute_name, solute_atom_group in solute_atoms_dict.items():
Expand Down Expand Up @@ -69,7 +72,7 @@ def verify_solute_atoms_dict(solute_atoms_dict):
return solute_atom_group


def get_atom_group(selection):
def get_atom_group(selection: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup]) -> mda.AtomGroup:
"""
Cast an MDAnalysis.Atom, MDAnalysis.Residue, or MDAnalysis.ResidueGroup to AtomGroup.
Expand Down Expand Up @@ -100,13 +103,14 @@ def get_atom_group(selection):
return selection



def get_closest_n_mol(
central_species,
n_mol,
guess_radius=3,
return_ordered_resix=False,
return_radii=False,
):
central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup],
n_mol: int,
guess_radius: Union[float, int] = 3,
return_ordered_resix: bool = False,
return_radii: bool = False,
) -> Union[mda.AtomGroup, tuple[mda.AtomGroup, np.ndarray], tuple[mda.AtomGroup, np.ndarray, np.ndarray]]:
"""
Returns the closest n molecules to the central species, an array of their resix,
and an array of the distance of the closest atom in each molecule.
Expand Down Expand Up @@ -174,7 +178,10 @@ def get_closest_n_mol(
return full_shell


def get_radial_shell(central_species, radius):
def get_radial_shell(
central_species: Union[mda.core.groups.Residue, mda.core.groups.ResidueGroup, mda.core.groups.Atom, mda.core.groups.AtomGroup],
radius: Union[float, int]
) -> mda.AtomGroup:
"""
Returns all molecules with atoms within the radius of the central species.
(specifically, within the radius of the COM of central species).
Expand All @@ -199,7 +206,7 @@ def get_radial_shell(central_species, radius):
return full_shell


def calculate_adjacency_dataframe(solvation_data):
def calculate_adjacency_dataframe(solvation_data: pd.DataFrame) -> pd.DataFrame:
"""
Calculate a frame-by-frame adjacency matrix from the solvation data.
Expand Down
26 changes: 17 additions & 9 deletions solvation_analysis/coordination.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import pandas as pd
import MDAnalysis as mda

from solvation_analysis._column_names import *

Expand Down Expand Up @@ -62,7 +63,14 @@ class Coordination:
"""

def __init__(self, solvation_data, n_frames, n_solutes, solvent_counts, atom_group):
def __init__(
self,
solvation_data: pd.DataFrame,
n_frames: int,
n_solutes: int,
solvent_counts: dict[str, int],
atom_group: mda.core.groups.AtomGroup
) -> None:
self.solvation_data = solvation_data
self.n_frames = n_frames
self.n_solutes = n_solutes
Expand All @@ -73,7 +81,7 @@ def __init__(self, solvation_data, n_frames, n_solutes, solvent_counts, atom_gro
self._coordination_vs_random = self._calculate_coordination_vs_random()

@staticmethod
def from_solute(solute):
def from_solute(solute: 'Solute') -> 'Coordination':
"""
Generate a Coordination object from a solute.
Expand All @@ -94,7 +102,7 @@ def from_solute(solute):
solute.u.atoms,
)

def _mean_cn(self):
def _mean_cn(self) -> tuple[dict[str, float], pd.DataFrame]:
counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX]
cn_series = counts.groupby([SOLVENT, FRAME]).sum() / (
self.n_solutes * self.n_frames
Expand All @@ -103,7 +111,7 @@ def _mean_cn(self):
cn_dict = cn_series.groupby([SOLVENT]).sum().to_dict()
return cn_dict, cn_by_frame

def _calculate_coordinating_atoms(self, tol=0.005):
def _calculate_coordinating_atoms(self, tol: float = 0.005) -> pd.DataFrame:
"""
Determine which atom types are actually coordinating and
return the types of those atoms.
Expand Down Expand Up @@ -131,7 +139,7 @@ def _calculate_coordinating_atoms(self, tol=0.005):
)
return type_fractions[type_fractions[FRACTION] > tol]

def _calculate_coordination_vs_random(self):
def _calculate_coordination_vs_random(self) -> dict[str, float]:
"""
Calculate the coordination number relative to random coordination.
Expand All @@ -150,29 +158,29 @@ def _calculate_coordination_vs_random(self):
return coordination_vs_random

@property
def coordination_numbers(self):
def coordination_numbers(self) -> dict[str, float]:
"""
A dictionary where keys are residue names (str) and values are the
mean coordination number of that residue (float).
"""
return self._cn_dict

@property
def coordination_numbers_by_frame(self):
def coordination_numbers_by_frame(self) -> pd.DataFrame:
"""
A DataFrame of the mean coordination number in each frame of the trajectory.
"""
return self._cn_dict_by_frame

@property
def coordinating_atoms(self):
def coordinating_atoms(self) -> pd.DataFrame:
"""
Fraction of each atom_type participating in solvation, calculated for each solvent.
"""
return self._coordinating_atoms

@property
def coordination_vs_random(self):
def coordination_vs_random(self) -> dict[str, float]:
"""
Coordination number relative to random coordination.
Expand Down
30 changes: 19 additions & 11 deletions solvation_analysis/networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
as an attribute of the Solute class. This makes instantiating it and calculating the
solvation data a non-issue.
"""
from typing import Union

import pandas as pd
import numpy as np
from scipy.sparse import csr_matrix
Expand Down Expand Up @@ -61,7 +63,13 @@ class Networking:
>>> networking = Networking.from_solute(solute, 'PF6')
"""

def __init__(self, solvents, solvation_data, solute_res_ix, res_name_map):
def __init__(
self,
solvents: Union[str, list[str]],
solvation_data: pd.DataFrame,
solute_res_ix: np.ndarray,
res_name_map: pd.Series
) -> None:
self.solvents = solvents
self.solvation_data = solvation_data
solvent_present = np.isin(self.solvents, self.solvation_data[SOLVENT].unique())
Expand All @@ -77,7 +85,7 @@ def __init__(self, solvents, solvation_data, solute_res_ix, res_name_map):
self._solute_status = self._solute_status.to_dict()

@staticmethod
def from_solute(solute, solvents):
def from_solute(solute: 'Solute', solvents: Union[str, list[str]]) -> 'Networking':
"""
Generate a Networking object from a solute and solvent names.
Expand All @@ -102,7 +110,7 @@ def from_solute(solute, solvents):
)

@staticmethod
def _unwrap_adjacency_dataframe(df):
def _unwrap_adjacency_dataframe(df: pd.DataFrame) -> csr_matrix:
# this method transforms the biadjacency matrix into a proper adjacency matrix
connections = df.reset_index(FRAME).drop(columns=FRAME)
idx = connections.columns.append(connections.index)
Expand All @@ -111,7 +119,7 @@ def _unwrap_adjacency_dataframe(df):
adjacency_matrix = csr_matrix(undirected)
return adjacency_matrix

def _generate_networks(self):
def _generate_networks(self) -> pd.DataFrame:
"""
This function generates a dataframe containing all the solute-solvent networks
in every frame of the simulation. The rough approach is as follows:
Expand Down Expand Up @@ -158,15 +166,15 @@ def _generate_networks(self):
)
return cluster_df

def _calculate_network_sizes(self):
def _calculate_network_sizes(self) -> pd.DataFrame:
# This utility calculates the network sizes and returns a convenient dataframe.
cluster_df = self.network_df
cluster_sizes = cluster_df.groupby([FRAME, NETWORK]).count()
size_counts = cluster_sizes.groupby([FRAME, SOLVENT]).count().unstack(fill_value=0)
size_counts.columns = size_counts.columns.droplevel(None) # the column value is None
return size_counts

def _calculate_solute_status(self):
def _calculate_solute_status(self) -> tuple[pd.Series, pd.DataFrame]:
"""
This utility calculates the fraction of each solute with a given "status".
Namely, whether the solute is "isolated", "paired" (with a single solvent), or
Expand All @@ -182,7 +190,7 @@ def _calculate_solute_status(self):
solute_status = solute_status_by_frame.mean()
return solute_status, solute_status_by_frame

def get_network_res_ix(self, network_index, frame):
def get_network_res_ix(self, network_index: int, frame: int) -> np.ndarray:
"""
Return the indexes of all residues in a selected network.
Expand Down Expand Up @@ -217,15 +225,15 @@ def get_network_res_ix(self, network_index, frame):
return res_ix.astype(int)

@property
def network_df(self):
def network_df(self) -> pd.DataFrame:
"""
The dataframe containing all networking data. The indices are the frame and
network index, respectively. The columns are the solvent_name and res_ix.
"""
return self._network_df

@property
def network_sizes(self):
def network_sizes(self) -> pd.DataFrame:
"""
A dataframe of network sizes. The index is the frame. The column headers
are network sizes, or the number of solutes + solvents in the network, so
Expand All @@ -235,7 +243,7 @@ def network_sizes(self):
return self._network_sizes

@property
def solute_status(self):
def solute_status(self) -> dict[str, float]:
"""
A dictionary where the keys are the "status" of the solute and the values
are the fraction of solute with that status, averaged over all frames.
Expand All @@ -250,7 +258,7 @@ def solute_status(self):
return self._solute_status

@property
def solute_status_by_frame(self):
def solute_status_by_frame(self) -> pd.DataFrame:
"""
As described above, except organized into a dataframe where each
row is a unique frame and the columns are "isolated", "paired", and "networked".
Expand Down
28 changes: 17 additions & 11 deletions solvation_analysis/pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ class Pairing:
{'BN': 1.0, 'FEC': 0.210, 'PF6': 0.120}
"""

def __init__(self, solvation_data, n_frames, n_solutes, n_solvents):
def __init__(
self,
solvation_data: pd.DataFrame,
n_frames: int,
n_solutes: int,
n_solvents: dict[str, int]
) -> None:
self.solvation_data = solvation_data
self.n_frames = n_frames
self.n_solutes = n_solutes
Expand All @@ -66,7 +72,7 @@ def __init__(self, solvation_data, n_frames, n_solutes, n_solvents):
self._diluent_composition, self._diluent_composition_by_frame, self._diluent_counts = self._diluent_composition()

@staticmethod
def from_solute(solute):
def from_solute(solute: 'Solute') -> 'Pairing':
"""
Generate a Pairing object from a solute.
Expand All @@ -86,7 +92,7 @@ def from_solute(solute):
solute.solvent_counts
)

def _fraction_coordinated(self):
def _fraction_coordinated(self) -> tuple[dict[str, float], pd.DataFrame]:
# calculate the fraction of solute coordinated with each solvent
counts = self.solvation_data.groupby([FRAME, SOLUTE_IX, SOLVENT]).count()[SOLVENT_IX]
pairing_series = counts.astype(bool).groupby([SOLVENT, FRAME]).sum() / (
Expand All @@ -97,15 +103,15 @@ def _fraction_coordinated(self):
pairing_dict = pairing_normalized.groupby([SOLVENT]).sum().to_dict()
return pairing_dict, pairing_by_frame

def _fraction_free_solvent(self):
def _fraction_free_solvent(self) -> dict[str, float]:
# calculate the fraction of each solvent NOT coordinated with the solute
counts = self.solvation_data.groupby([FRAME, SOLVENT_IX, SOLVENT]).count()[DISTANCE]
totals = counts.groupby([SOLVENT]).count() / self.n_frames
n_solvents = np.array([self.solvent_counts[name] for name in totals.index.values])
free_solvents = np.ones(len(totals)) - totals / n_solvents
return free_solvents.to_dict()

def _diluent_composition(self):
def _diluent_composition(self) -> tuple[dict[str, float], pd.DataFrame, pd.DataFrame]:
coordinated_solvents = self.solvation_data.groupby([FRAME, SOLVENT]).nunique()[SOLVENT_IX]
solvent_counts = pd.Series(self.solvent_counts)
total_solvents = solvent_counts.reindex(coordinated_solvents.index, level=1)
Expand All @@ -117,45 +123,45 @@ def _diluent_composition(self):
return diluent_dict, diluent_by_frame, diluent_counts

@property
def solvent_pairing(self):
def solvent_pairing(self) -> dict[str, float]:
"""
A dictionary where keys are residue names (str) and values are the
fraction of solutes that are coordinated with that residue (float).
"""
return self._solvent_pairing

@property
def pairing_by_frame(self):
def pairing_by_frame(self) -> pd.DataFrame:
"""
A pd.DataFrame tracking the mean fraction of each residue across frames.
"""
return self._pairing_by_frame

@property
def fraction_free_solvents(self):
def fraction_free_solvents(self) -> dict[str, float]:
"""
A dictionary containing the fraction of each solvent that is free, i.e.,
not coordinated to a solute.
"""
return self._fraction_free_solvents

@property
def diluent_composition(self):
def diluent_composition(self) -> dict[str, float]:
"""
The fraction of the diluent constituted by each solvent. The diluent is
defined as everything that is not coordinated with the solute.
"""
return self._diluent_composition

@property
def diluent_composition_by_frame(self):
def diluent_composition_by_frame(self) -> pd.DataFrame:
"""
A DataFrame of the diluent composition in each frame of the trajectory.
"""
return self._diluent_composition_by_frame

@property
def diluent_counts(self):
def diluent_counts(self) -> pd.DataFrame:
"""
A DataFrame of the raw solvent counts in the diluent in each frame of the trajectory.
"""
Expand Down
Loading

0 comments on commit 7fe6f7e

Please sign in to comment.