Skip to content

Commit

Permalink
Add uncertainties to thermo target (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Nov 16, 2023
1 parent 032865c commit 65aaf50
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 66 deletions.
132 changes: 92 additions & 40 deletions descent/targets/thermo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pyarrow
import pydantic
import smee.mm
import smee.utils
import torch
from rdkit import Chem

Expand Down Expand Up @@ -108,6 +109,10 @@ class SimulationConfig(pydantic.BaseModel):
..., description="Configuration for generating initial coordinates."
)

apply_hmr: bool = pydantic.Field(
False, description="Whether to apply hydrogen mass repartitioning."
)

equilibrate: list[
smee.mm.MinimizationConfig | smee.mm.SimulationConfig
] = pydantic.Field(..., description="Configuration for equilibration simulations.")
Expand All @@ -120,6 +125,15 @@ class SimulationConfig(pydantic.BaseModel):
)


class _Observables(typing.NamedTuple):
"""Ensemble averages of the observables computed from a simulation."""

mean: dict[str, torch.Tensor]
"""The mean value of each observable with ``shape=()``."""
std: dict[str, torch.Tensor]
"""The standard deviation of each observable with ``shape=()``."""


_SystemDict = dict[SimulationKey, smee.TensorSystem]


Expand Down Expand Up @@ -399,17 +413,18 @@ def _simulate(
config.equilibrate,
config.production,
[reporter],
config.apply_hmr,
)


def _compute_averages(
def _compute_observables(
phase: Phase,
key: SimulationKey,
system: smee.TensorSystem,
force_field: smee.TensorForceField,
output_dir: pathlib.Path,
cached_dir: pathlib.Path | None,
) -> dict[str, torch.Tensor]:
) -> _Observables:
traj_hash = hashlib.sha256(pickle.dumps(key)).hexdigest()
traj_name = f"{phase}-{traj_hash}-frames.msgpack"

Expand All @@ -420,9 +435,12 @@ def _compute_averages(

if cached_path is not None and cached_path.exists():
with contextlib.suppress(smee.mm.NotEnoughSamplesError):
return smee.mm.reweight_ensemble_averages(
means = smee.mm.reweight_ensemble_averages(
system, force_field, cached_path, temperature, pressure
)
stds = {key: smee.utils.tensor_like(torch.nan, means[key]) for key in means}

return _Observables(means, stds)

if cached_path is not None:
_LOGGER.debug(f"unable to re-weight {key}: data exists={cached_path.exists()}")
Expand All @@ -432,80 +450,104 @@ def _compute_averages(
config = default_config(phase, key.temperature, key.pressure)
_simulate(system, force_field, config, output_path)

return smee.mm.compute_ensemble_averages(
system, force_field, output_path, temperature, pressure
return _Observables(
*smee.mm.compute_ensemble_averages(
system, force_field, output_path, temperature, pressure
)
)


def _predict_density(
entry: DataEntry, averages: dict[str, torch.Tensor]
) -> torch.Tensor:
entry: DataEntry, observables: _Observables
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert entry["units"] == "g/mL"
return averages["density"]
return observables.mean["density"], observables.std["density"]


def _predict_hvap(
entry: DataEntry,
averages_bulk: dict[str, torch.Tensor],
averages_vacuum: dict[str, torch.Tensor],
observables_bulk: _Observables,
observables_vacuum: _Observables,
system_bulk: smee.TensorSystem,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor]:
assert entry["units"] == "kcal/mol"

temperature = entry["temperature"] * openmm.unit.kelvin
n_mols = sum(system_bulk.n_copies)

potential_bulk = averages_bulk["potential_energy"] / sum(system_bulk.n_copies)
potential_vacuum = averages_vacuum["potential_energy"]
potential_bulk = observables_bulk.mean["potential_energy"] / n_mols
potential_bulk_std = observables_bulk.std["potential_energy"] / n_mols

potential_vacuum = observables_vacuum.mean["potential_energy"]
potential_vacuum_std = observables_vacuum.std["potential_energy"]

rt = (temperature * openmm.unit.MOLAR_GAS_CONSTANT_R).value_in_unit(
openmm.unit.kilocalorie_per_mole
)
return potential_vacuum - potential_bulk + rt

value = potential_vacuum - potential_bulk + rt
std = torch.sqrt(potential_vacuum_std**2 + potential_bulk_std**2)

return value, std


def _predict_hmix(
entry: DataEntry,
averages_mix: dict[str, torch.Tensor],
averages_0: dict[str, torch.Tensor],
averages_1: dict[str, torch.Tensor],
observables_mix: _Observables,
observables_0: _Observables,
observables_1: _Observables,
system_mix: smee.TensorSystem,
system_0: smee.TensorSystem,
system_1: smee.TensorSystem,
) -> torch.Tensor:
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert entry["units"] == "kcal/mol"

x_0 = system_mix.n_copies[0] / sum(system_mix.n_copies)
n_mols_mix = sum(system_mix.n_copies)
n_mols_0 = sum(system_0.n_copies)
n_mols_1 = sum(system_1.n_copies)

x_0 = system_mix.n_copies[0] / n_mols_mix
x_1 = 1.0 - x_0

enthalpy_mix = averages_mix["enthalpy"] / sum(system_mix.n_copies)
enthalpy_mix = observables_mix.mean["enthalpy"] / n_mols_mix
enthalpy_mix_std = observables_mix.std["enthalpy"] / n_mols_mix

enthalpy_0 = averages_0["enthalpy"] / sum(system_0.n_copies)
enthalpy_1 = averages_1["enthalpy"] / sum(system_1.n_copies)
enthalpy_0 = observables_0.mean["enthalpy"] / n_mols_0
enthalpy_0_std = observables_0.std["enthalpy"] / n_mols_0
enthalpy_1 = observables_1.mean["enthalpy"] / n_mols_1
enthalpy_1_std = observables_1.std["enthalpy"] / n_mols_1

return enthalpy_mix - x_0 * enthalpy_0 - x_1 * enthalpy_1
value = enthalpy_mix - x_0 * enthalpy_0 - x_1 * enthalpy_1
std = torch.sqrt(
enthalpy_mix_std**2
+ x_0**2 * enthalpy_0_std**2
+ x_1**2 * enthalpy_1_std**2
)

return value, std


def _predict(
entry: DataEntry,
keys: dict[str, SimulationKey],
averages: dict[Phase, dict[SimulationKey, dict[str, torch.Tensor]]],
observables: dict[Phase, dict[SimulationKey, _Observables]],
systems: dict[Phase, dict[SimulationKey, smee.TensorSystem]],
):
) -> tuple[torch.Tensor, torch.Tensor]:
if entry["type"] == "density":
value = _predict_density(entry, averages["bulk"][keys["bulk"]])
value = _predict_density(entry, observables["bulk"][keys["bulk"]])
elif entry["type"] == "hvap":
value = _predict_hvap(
entry,
averages["bulk"][keys["bulk"]],
averages["vacuum"][keys["vacuum"]],
observables["bulk"][keys["bulk"]],
observables["vacuum"][keys["vacuum"]],
systems["bulk"][keys["bulk"]],
)
elif entry["type"] == "hmix":
value = _predict_hmix(
entry,
averages["bulk"][keys["bulk"]],
averages["bulk"][keys["bulk_0"]],
averages["bulk"][keys["bulk_1"]],
observables["bulk"][keys["bulk"]],
observables["bulk"][keys["bulk_0"]],
observables["bulk"][keys["bulk_1"]],
systems["bulk"][keys["bulk"]],
systems["bulk"][keys["bulk_0"]],
systems["bulk"][keys["bulk_1"]],
Expand All @@ -523,7 +565,7 @@ def predict(
output_dir: pathlib.Path,
cached_dir: pathlib.Path | None = None,
per_type_scales: dict[DataType, float] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict the properties in a dataset using molecular simulation, or by reweighting
previous simulation data.
Expand All @@ -542,9 +584,9 @@ def predict(
entries: list[DataEntry] = [*descent.utils.dataset.iter_dataset(dataset)]

required_simulations, entry_to_simulation = _plan_simulations(entries, topologies)
averages = {
observables = {
phase: {
key: _compute_averages(
key: _compute_observables(
phase, key, system, force_field, output_dir, cached_dir
)
for key, system in systems.items()
Expand All @@ -553,19 +595,29 @@ def predict(
}

predicted = []
predicted_std = []
reference = []
reference_std = []

per_type_scales = per_type_scales if per_type_scales is not None else {}

for entry, keys in zip(entries, entry_to_simulation):
value = _predict(entry, keys, averages, required_simulations)
value, std = _predict(entry, keys, observables, required_simulations)

type_scale = per_type_scales.get(entry["type"], 1.0)

predicted.append(value * per_type_scales.get(entry["type"], 1.0))
reference.append(
torch.tensor(entry["value"]) * per_type_scales.get(entry["type"], 1.0)
predicted.append(value * type_scale)
predicted_std.append(torch.nan if std is None else std * abs(type_scale))

reference.append(entry["value"] * type_scale)
reference_std.append(
torch.nan if entry["std"] is None else entry["std"] * abs(type_scale)
)

predicted = torch.stack(predicted)
reference = torch.stack(reference).to(predicted.device)
predicted_std = torch.stack(predicted_std)

reference = smee.utils.tensor_like(reference, predicted)
reference_std = smee.utils.tensor_like(reference_std, predicted_std)

return reference, predicted
return reference, reference_std, predicted, predicted_std
Loading

0 comments on commit 65aaf50

Please sign in to comment.