diff --git a/smee/mm/_ops.py b/smee/mm/_ops.py index bda9d2d..9e90cc7 100644 --- a/smee/mm/_ops.py +++ b/smee/mm/_ops.py @@ -169,7 +169,10 @@ def _compute_frame_observables( The observables for this frame. """ - values = {"potential_energy": potential_energy} + values = { + "potential_energy": potential_energy, + "potential_energy^2": potential_energy**2, + } reduced_potential = beta * potential_energy if not system.is_periodic: @@ -177,7 +180,8 @@ def _compute_frame_observables( return values volume = torch.det(box_vectors) - values["volume"] = volume + + values.update({"volume": volume, "volume^2": volume**2}) total_mass = _compute_mass(system) @@ -185,10 +189,14 @@ def _compute_frame_observables( if pressure is not None: pv_term = volume * pressure + values["enthalpy"] = potential_energy + kinetic_energy + pv_term + values["enthalpy^2"] = values["enthalpy"] ** 2 reduced_potential += beta * pv_term + values["enthalpy_volume"] = values["enthalpy"] * values["volume"] + values["reduced_potential"] = reduced_potential return values @@ -324,6 +332,18 @@ def backward(ctx, *grad_outputs): grads = [None] * len(theta) + energy = values[:, ctx.columns.index("potential_energy")] + volume = ( + None + if "volume" not in ctx.columns + else values[:, ctx.columns.index("volume")] + ) + enthalpy = ( + None + if "enthalpy" not in ctx.columns + else values[:, ctx.columns.index("enthalpy")] + ) + for i in range(len(du_d_theta)): if du_d_theta[i] is None: continue @@ -332,10 +352,21 @@ def backward(ctx, *grad_outputs): avg_d_output_d_theta_i = { "potential_energy": avg_du_d_theta_i, + "potential_energy^2": (2 * energy * du_d_theta[i]).mean(dim=-1), "volume": torch.zeros_like(avg_du_d_theta_i), + "volume^2": torch.zeros_like(avg_du_d_theta_i), "density": torch.zeros_like(avg_du_d_theta_i), "enthalpy": avg_du_d_theta_i, + "enthalpy^2": ( + None + if enthalpy is None + else (2 * enthalpy * du_d_theta[i]).mean(dim=-1) + ), + "enthalpy_volume": ( + None if volume is None else (volume * du_d_theta[i]).mean(dim=-1) + ), } + avg_d_output_d_theta_i = torch.stack( [avg_d_output_d_theta_i[column] for column in ctx.columns], dim=-1 ) @@ -417,6 +448,18 @@ def backward(ctx, *grad_outputs): grads = [None] * len(theta) + energy = values[:, ctx.columns.index("potential_energy")] + volume = ( + None + if "volume" not in ctx.columns + else values[:, ctx.columns.index("volume")] + ) + enthalpy = ( + None + if "enthalpy" not in ctx.columns + else values[:, ctx.columns.index("enthalpy")] + ) + for i in range(len(du_d_theta)): if du_d_theta[i] is None: continue @@ -435,9 +478,15 @@ def backward(ctx, *grad_outputs): d_output_d_theta_i = { "potential_energy": du_d_theta[i], + "potential_energy^2": 2 * energy * du_d_theta[i], "volume": torch.zeros_like(du_d_theta[i]), + "volume^2": torch.zeros_like(du_d_theta[i]), "density": torch.zeros_like(du_d_theta[i]), "enthalpy": du_d_theta[i], + "enthalpy^2": ( + None if enthalpy is None else 2 * enthalpy * du_d_theta[i] + ), + "enthalpy_volume": (None if volume is None else volume * du_d_theta[i]), } d_output_d_theta_i = torch.stack( [d_output_d_theta_i[column] for column in ctx.columns], dim=-1 diff --git a/smee/tests/mm/test_ops.py b/smee/tests/mm/test_ops.py index 4eddcde..90bd654 100644 --- a/smee/tests/mm/test_ops.py +++ b/smee/tests/mm/test_ops.py @@ -144,6 +144,7 @@ def test_compute_frame_observables_non_periodic(mocker): ) assert values == { "potential_energy": expected_potential, + "potential_energy^2": expected_potential**2, "reduced_potential": beta * expected_potential, } @@ -196,10 +197,16 @@ def test_compute_frame_observables(): ) assert values == { "potential_energy": torch.tensor(expected_potential), + "potential_energy^2": torch.tensor(expected_potential**2), "volume": pytest.approx(torch.tensor(expected_volume)), + "volume^2": pytest.approx(torch.tensor(expected_volume**2)), "density": pytest.approx(torch.tensor(expected_density)), "enthalpy": pytest.approx(torch.tensor(expected_enthalpy)), + "enthalpy^2": pytest.approx(torch.tensor(expected_enthalpy**2)), "reduced_potential": pytest.approx(torch.tensor(expected_reduced_potential)), + "enthalpy_volume": pytest.approx( + torch.tensor(expected_enthalpy * expected_volume) + ), } @@ -256,10 +263,13 @@ def test_compute_observables(tmp_path, mock_argon_tensors, mock_argon_params): tensor_system, tensor_ff, file, theta, beta, None ) - assert columns == ["potential_energy"] + assert columns == ["potential_energy", "potential_energy^2"] - assert values.shape == (len(expected_potential), 1) - numpy.allclose(values.numpy().flatten(), expected_potential) + assert values.shape == (len(expected_potential), 2) + assert numpy.allclose( + values.numpy(), + numpy.stack([expected_potential, expected_potential**2], axis=-1), + ) assert reduced_potential.shape == (len(expected_potential),) assert numpy.allclose(reduced_potential.numpy(), beta * expected_potential) @@ -283,9 +293,18 @@ def test_compute_ensemble_averages(mocker, tmp_path, mock_argon_tensors): output_path.write_bytes(b"") mock_outputs = torch.stack( - [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([5.0, 6.0, 20.0])] + [ + torch.tensor([1.0, 1.0, 2.0, 4.0, 3.0]), + torch.tensor([5.0, 25.0, 6.0, 36.0, 20.0]), + ] ) - mock_columns = ["potential_energy", "volume", "density"] + mock_columns = [ + "potential_energy", + "potential_energy^2", + "volume", + "volume^2", + "density", + ] mock_du_d_theta = (torch.tensor([[[9.0, 10.0], [11.0, 12.0]]]), None) mock_compute_observables = mocker.patch( @@ -336,7 +355,7 @@ def test_compute_ensemble_averages(mocker, tmp_path, mock_argon_tensors): beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * temperature) beta = beta.value_in_unit(openmm.unit.kilocalorie_per_mole**-1) - energy, volume, density = mock_outputs[:, 0], mock_outputs[:, 1], mock_outputs[:, 2] + energy, volume, density = mock_outputs[:, 0], mock_outputs[:, 2], mock_outputs[:, 4] du_d_eps = mock_du_d_theta[0][0, 0, :] expected_d_avg_energy_d_eps = du_d_eps.mean(-1) - beta * ( @@ -362,13 +381,22 @@ def test_reweight_ensemble_averages(mocker, tmp_path, mock_argon_tensors): beta = 1.0 / (openmm.unit.MOLAR_GAS_CONSTANT_R * temperature) mock_outputs = torch.stack( - [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([5.0, 6.0, 20.0])] + [ + torch.tensor([1.0, 1.0, 2.0, 4.0, 3.0]), + torch.tensor([5.0, 25.0, 6.0, 36.0, 20.0]), + ] ) mock_reduced = ( beta.value_in_unit(openmm.unit.kilocalories_per_mole**-1) * mock_outputs[:, 0] ) - mock_columns = ["potential_energy", "volume", "density"] + mock_columns = [ + "potential_energy", + "potential_energy^2", + "volume", + "volume^2", + "density", + ] mock_du_d_theta = (torch.tensor([[[-9.0, 10.0], [11.0, -12.0]]]), None) mocker.patch(