Skip to content

Commit

Permalink
Add tensor reporter ctx manager (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonBoothroyd authored Feb 20, 2024
1 parent 2db8f60 commit 4ef91cb
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 108 deletions.
132 changes: 27 additions & 105 deletions examples/md-simulations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,12 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "e178932166969df5",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:39.088878928Z",
"start_time": "2023-11-08T12:44:32.342740491Z"
},
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "70aebb3be8324be38a93fcc339ea9763",
"version_major": 2,
"version_minor": 0
},
"text/plain": []
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n",
"/bin/bash: warning: setlocale: LC_ALL: cannot change locale (en_US.UTF-8)\n"
]
}
],
"outputs": [],
"source": [
"import openff.interchange\n",
"import openff.toolkit\n",
Expand Down Expand Up @@ -102,13 +55,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "38700a4d251b1ab9",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:39.103624993Z",
"start_time": "2023-11-08T12:44:39.097640693Z"
},
"collapsed": false
},
"outputs": [],
Expand All @@ -129,13 +78,9 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"id": "58a584bf7997e194",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:39.383031948Z",
"start_time": "2023-11-08T12:44:39.099505793Z"
},
"collapsed": false
},
"outputs": [],
Expand Down Expand Up @@ -169,13 +114,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "ccba93245cf83ff7",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:39.474284381Z",
"start_time": "2023-11-08T12:44:39.387381350Z"
},
"collapsed": false
},
"outputs": [],
Expand Down Expand Up @@ -232,17 +173,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"id": "d485cce5c1a0fce3",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:39.482417652Z",
"start_time": "2023-11-08T12:44:39.478860928Z"
},
"collapsed": false
},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"import torch\n",
"\n",
"\n",
Expand All @@ -254,37 +193,32 @@
" # box vectors and kinetic energies\n",
" coords, box_vectors = smee.mm.generate_system_coords(system, force_field)\n",
"\n",
" with tempfile.NamedTemporaryFile() as tmp_path:\n",
" # save the simulation output every 1000th frame (2 ps) to a temporary file.\n",
" # we could also save the trajectory more permanently, but as we do nothing\n",
" # with it after computing the averages in this example, we simply want to\n",
" # discard it.\n",
"\n",
" reporter_interval = 1000\n",
"\n",
" with open(tmp_path.name, \"wb\") as tmp_file:\n",
" reporter = smee.mm.TensorReporter(\n",
" tmp_file, reporter_interval, beta, pressure\n",
" )\n",
" interval = 1000\n",
"\n",
" smee.mm.simulate(\n",
" system,\n",
" force_field,\n",
" coords,\n",
" box_vectors,\n",
" equilibrate_config,\n",
" production_config,\n",
" [reporter],\n",
" )\n",
" # save the simulation output every 1000th frame (2 ps) to a temporary file.\n",
" # we could also save the trajectory more permanently, but as we do nothing\n",
" # with it after computing the averages in this example, we simply want to\n",
" # discard it.\n",
" with (\n",
" tempfile.NamedTemporaryFile() as tmp_file,\n",
" smee.mm.tensor_reporter(tmp_file.name, interval, beta, pressure) as reporter,\n",
" ):\n",
" smee.mm.simulate(\n",
" system,\n",
" force_field,\n",
" coords,\n",
" box_vectors,\n",
" equilibrate_config,\n",
" production_config,\n",
" [reporter],\n",
" )\n",
"\n",
" # we can then compute the ensemble averages from the trajectory. generating\n",
" # the trajectory separately from computing the ensemble averages allows us\n",
" # to run the simulation in parallel with other simulations more easily, without\n",
" # having to worry about copying gradients between workers / processes.\n",
" import pathlib\n",
"\n",
" avgs, stds = smee.mm.compute_ensemble_averages(\n",
" system, force_field, pathlib.Path(tmp_path.name), temperature, pressure\n",
" system, force_field, pathlib.Path(tmp_file.name), temperature, pressure\n",
" )\n",
" return avgs"
]
Expand All @@ -304,10 +238,6 @@
"execution_count": null,
"id": "3156bcfc509380f7",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:40.364614982Z",
"start_time": "2023-11-08T12:44:39.485185718Z"
},
"collapsed": false
},
"outputs": [],
Expand Down Expand Up @@ -336,10 +266,6 @@
"execution_count": null,
"id": "38b9a27d7cd06c1a",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:40.369070019Z",
"start_time": "2023-11-08T12:44:40.365546037Z"
},
"collapsed": false
},
"outputs": [],
Expand Down Expand Up @@ -379,10 +305,6 @@
"execution_count": null,
"id": "dd3ccdfe61a0cd09",
"metadata": {
"ExecuteTime": {
"end_time": "2023-11-08T12:44:40.371248178Z",
"start_time": "2023-11-08T12:44:40.370055198Z"
},
"collapsed": false
},
"outputs": [],
Expand Down
3 changes: 2 additions & 1 deletion smee/mm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
compute_ensemble_averages,
reweight_ensemble_averages,
)
from smee.mm._reporters import TensorReporter, unpack_frames
from smee.mm._reporters import TensorReporter, tensor_reporter, unpack_frames

__all__ = [
"compute_ensemble_averages",
Expand All @@ -19,5 +19,6 @@
"NotEnoughSamplesError",
"SimulationConfig",
"TensorReporter",
"tensor_reporter",
"unpack_frames",
]
26 changes: 25 additions & 1 deletion smee/mm/_reporters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""OpenMM simulation reporters"""

import contextlib
import math
import os
import typing

import msgpack
Expand Down Expand Up @@ -37,7 +39,8 @@ def _decoder(obj, chain=None):


class TensorReporter:
"""A reporter which stores coords, box vectors, and kinetic energy using msgpack."""
"""A reporter which stores coords, box vectors, reduced potentials and kinetic
energy using msgpack."""

def __init__(
self,
Expand Down Expand Up @@ -109,3 +112,24 @@ def unpack_frames(

for frame in unpacker:
yield frame


@contextlib.contextmanager
def tensor_reporter(
output_path: os.PathLike,
report_interval: int,
beta: openmm.unit.Quantity,
pressure: openmm.unit.Quantity | None,
) -> TensorReporter:
"""Create a ``TensorReporter`` capable of writing frames to a file.
Args:
output_path: The path to write the frames to.
report_interval: The interval (in steps) at which to write frames.
beta: The inverse temperature the simulation is being run at.
pressure: The pressure the simulation is being run at, or ``None`` if NVT /
vacuum.
"""
with open(output_path, "wb") as output_file:
reporter = TensorReporter(output_file, report_interval, beta, pressure)
yield reporter
15 changes: 14 additions & 1 deletion smee/tests/mm/test_reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import openmm.unit
import pytest

from smee.mm._reporters import TensorReporter, unpack_frames
from smee.mm._reporters import TensorReporter, tensor_reporter, unpack_frames


class TestTensorReporter:
Expand Down Expand Up @@ -81,3 +81,16 @@ def test_report_energy_check(self, potential, contains, mocker):
with pytest.raises(ValueError, match=f"total energy is {contains}"):
reporter = TensorReporter(mocker.MagicMock(), 1, beta, None)
reporter.report(None, mock_state)


def test_tensor_reporter(tmp_path):
output = tmp_path / "frames.msgpack"

beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * 298.15 * openmm.unit.kelvin)

pressure = 1.0 * openmm.unit.atmospheres

with tensor_reporter(output, 2, beta, pressure) as reporter:
assert isinstance(reporter, TensorReporter)

assert output.exists() and output.is_file()

0 comments on commit 4ef91cb

Please sign in to comment.