Commit: ensembles in tensordict

smorad committed Jul 11, 2023
1 parent 40a467e commit b39d51b

Showing 4 changed files with 253 additions and 0 deletions.
31 changes: 31 additions & 0 deletions docs/source/reference/nn.rst
@@ -272,6 +272,37 @@ distinguish on a high level parameters and buffers (they are all packed together
    make_functional
    repopulate_module

Ensembles
---------
The functional approach enables a straightforward ensemble implementation.
We can duplicate and reinitialize model copies using :class:`tensordict.nn.EnsembleModule`:

.. code-block::

    >>> import torch
    >>> from torch import nn
    >>> from tensordict.nn import TensorDictModule
    >>> from tensordict.nn import EnsembleModule
    >>> from tensordict import TensorDict
    >>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
    >>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
    >>> ensemble = EnsembleModule(mod, num_copies=3)
    >>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
    >>> ensemble(data)
    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 10]),
        device=None,
        is_shared=False)
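
Ensembles can also be chained; every module after the first must be built with
``expand_input=False``, since its input already carries the ensemble dimension.
A short sketch, mirroring the example in the
:class:`~tensordict.nn.EnsembleModule` docstring:

.. code-block::

    >>> import torch
    >>> from tensordict.nn import TensorDictModule, TensorDictSequential
    >>> from tensordict.nn import EnsembleModule
    >>> from tensordict import TensorDict
    >>> module = TensorDictModule(torch.nn.Linear(2, 3), in_keys=['bork'], out_keys=['dork'])
    >>> next_module = TensorDictModule(torch.nn.Linear(3, 1), in_keys=['dork'], out_keys=['spork'])
    >>> e0 = EnsembleModule(module, num_copies=4, expand_input=True)
    >>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
    >>> seq = TensorDictSequential(e0, e1)
    >>> data = TensorDict({'bork': torch.randn(5, 2)}, batch_size=[5])
    >>> seq(data)['spork'].shape
    torch.Size([4, 5, 1])
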
.. autosummary::
    :toctree: generated/
    :template: rl_template_noinherit.rst

    EnsembleModule

Tracing and compiling
---------------------

1 change: 1 addition & 0 deletions tensordict/nn/__init__.py
@@ -11,6 +11,7 @@
    TensorDictModuleWrapper,
)
from tensordict.nn.distributions import NormalParamExtractor
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.functional_modules import (
    get_functional,
    is_functional,
124 changes: 124 additions & 0 deletions tensordict/nn/ensemble.py
@@ -0,0 +1,124 @@
import warnings
from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict.nn import make_functional, TensorDictModuleBase
from torch import nn


class EnsembleModule(TensorDictModuleBase):
    """Module that wraps a module and repeats it to form an ensemble.

    Args:
        module (TensorDictModuleBase): The module to duplicate and wrap.
        num_copies (int): The number of copies of ``module`` to make.
        expand_input (bool): Whether to expand the input TensorDict to match the number of copies. This should be
            ``True`` unless you are chaining ensemble modules together, e.g. ``EnsembleModule(cnn) -> EnsembleModule(mlp)``.
            If ``False``, ``EnsembleModule(mlp)`` will expect the previous module(s) to have already expanded the input.

    Examples:
        >>> import torch
        >>> from torch import nn
        >>> from tensordict.nn import TensorDictModule
        >>> from tensordict.nn import EnsembleModule
        >>> from tensordict import TensorDict
        >>> net = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))
        >>> mod = TensorDictModule(net, in_keys=['a'], out_keys=['b'])
        >>> ensemble = EnsembleModule(mod, num_copies=3)
        >>> data = TensorDict({'a': torch.randn(10, 4)}, batch_size=[10])
        >>> ensemble(data)
        TensorDict(
            fields={
                a: Tensor(shape=torch.Size([3, 10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                b: Tensor(shape=torch.Size([3, 10, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 10]),
            device=None,
            is_shared=False)

    To chain EnsembleModules together, turn off ``expand_input`` for every module after the first:

    Examples:
        >>> import torch
        >>> from tensordict.nn import TensorDictModule, TensorDictSequential
        >>> from tensordict.nn import EnsembleModule
        >>> from tensordict import TensorDict
        >>> module = TensorDictModule(torch.nn.Linear(2, 3), in_keys=['bork'], out_keys=['dork'])
        >>> next_module = TensorDictModule(torch.nn.Linear(3, 1), in_keys=['dork'], out_keys=['spork'])
        >>> e0 = EnsembleModule(module, num_copies=4, expand_input=True)
        >>> e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
        >>> seq = TensorDictSequential(e0, e1)
        >>> data = TensorDict({'bork': torch.randn(5, 2)}, batch_size=[5])
        >>> seq(data)
        TensorDict(
            fields={
                bork: Tensor(shape=torch.Size([4, 5, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                dork: Tensor(shape=torch.Size([4, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                spork: Tensor(shape=torch.Size([4, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 5]),
            device=None,
            is_shared=False)
    """

    def __init__(
        self,
        module: TensorDictModuleBase,
        num_copies: int,
        expand_input: bool = True,
    ):
        super().__init__()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        # Extract the module parameters into a TensorDict and expand them along
        # a new leading dimension, one slice per ensemble member.
        params_td = make_functional(module).expand(num_copies).to_tensordict()

        self.module = module
        self.params_td = params_td
        # Register the stacked parameters so optimizers and device moves see
        # them as ordinary module parameters.
        self.ensemble_parameters = nn.ParameterList(
            list(self.params_td.values(True, True))
        )
        if expand_input:
            # Broadcast the same input to every copy: map only over the
            # parameter dimension (in_dims=(None, 0)).
            self.vmapped_forward = torch.vmap(self.module, (None, 0))
        else:
            # The input already carries a leading ensemble dimension: map over
            # both the input and the parameters (in_dims=0).
            self.vmapped_forward = torch.vmap(self.module, 0)

        # Give each copy its own random initialization.
        self.reset_parameters_recursive(self.params_td)

    def forward(self, tensordict: TensorDict) -> TensorDict:
        return self.vmapped_forward(tensordict, self.params_td)

    def reset_parameters_recursive(
        self, parameters: Optional[TensorDictBase] = None
    ) -> TensorDictBase:
        """Resets the parameters of all the copies of the module.

        Args:
            parameters (TensorDict): A TensorDict of parameters for self.module. The batch dimension(s) of the tensordict
                denote the number of module copies to reset.

        Returns:
            A TensorDict of pointers to the reset parameters.
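
        Because the ensemble is functional, the parameters must be passed
        explicitly; ``self.params_td`` holds the ones created at construction
        time. A minimal usage sketch (assuming ``ensemble`` was built as in
        the class docstring):

        Examples:
            >>> ensemble.reset_parameters_recursive(ensemble.params_td)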
"""
        if parameters is None:
            raise ValueError(
                "Ensembles are functional and require passing a TensorDict of parameters to reset_parameters_recursive"
            )
        if parameters.ndim:
            # Batched parameters: peel off one ensemble dimension, reset each
            # copy independently, then stack the copies back together.
            params_pointers = []
            for params_copy in parameters.unbind(0):
                self.reset_parameters_recursive(params_copy)
                params_pointers.append(params_copy)
            return torch.stack(params_pointers, -1)
        else:
            # In case the user has added other neural networks to the EnsembleModule
            # besides those in self.module
            child_mods = [
                mod
                for name, mod in self.named_children()
                if name != "module" and name != "ensemble_parameters"
            ]
            if child_mods:
                warnings.warn(
                    "EnsembleModule.reset_parameters_recursive() only resets parameters of self.module, but other parameters were detected. These parameters will not be reset."
                )
            # Reset all self.module descendant parameters
            return self.module.reset_parameters_recursive(parameters)
97 changes: 97 additions & 0 deletions test/test_nn.py
@@ -4,7 +4,9 @@
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import pickle
import unittest.mock
import warnings

import pytest
@@ -22,6 +24,7 @@
)
from tensordict.nn.common import TensorDictModule, TensorDictModuleWrapper
from tensordict.nn.distributions import Delta, NormalParamExtractor, NormalParamWrapper
from tensordict.nn.ensemble import EnsembleModule
from tensordict.nn.functional_modules import is_functional, make_functional
from tensordict.nn.probabilistic import InteractionType, set_interaction_type
from tensordict.nn.utils import set_skip_existing, skip_existing
@@ -2686,6 +2689,100 @@ def test_nested_keys_probabilistic_normal(log_prob_key):
    assert td_out["sample_log_prob"].shape == (3, 4, 1)


class TestEnsembleModule:
    def test_init(self):
        """Ensure that we correctly initialize copied weights such that they
        are not identical to the original weights."""
        torch.manual_seed(0)
        module = TensorDictModule(
            nn.Sequential(
                nn.Linear(2, 3),
                nn.ReLU(),
                nn.Linear(3, 1),
            ),
            in_keys=["a"],
            out_keys=["b"],
        )
        mod = EnsembleModule(module, num_copies=2)
        for param in mod.ensemble_parameters:
            p0, p1 = param.unbind(0)
            assert not torch.allclose(
                p0, p1
            ), f"Ensemble params were not initialized correctly {p0}, {p1}"

    @pytest.mark.parametrize(
        "net",
        [
            nn.Linear(1, 1),
            nn.Sequential(nn.Linear(1, 1)),
            nn.Sequential(nn.Linear(1, 1), nn.ReLU(), nn.Linear(1, 1)),
        ],
    )
    def test_siso_forward(self, net):
        """Ensure that forward works for a single input and output."""
        module = TensorDictModule(
            net,
            in_keys=["bork"],
            out_keys=["dork"],
        )
        mod = EnsembleModule(module, num_copies=2)
        td = TensorDict({"bork": torch.randn(5, 1)}, batch_size=[5])
        out = mod(td)
        assert "dork" in out.keys(), "Ensemble forward failed to write keys"
        assert out["dork"].shape == torch.Size(
            [2, 5, 1]
        ), "Ensemble forward failed to expand input"
        outs = out["dork"].unbind(0)
        assert not torch.allclose(outs[0], outs[1]), "Outputs should be different"

    @pytest.mark.parametrize(
        "net",
        [
            nn.Linear(1, 1),
            nn.Sequential(nn.Linear(1, 1)),
            nn.Sequential(nn.Linear(1, 1), nn.ReLU(), nn.Linear(1, 1)),
        ],
    )
    def test_chained_ensembles(self, net):
        """Ensure that the expand_input argument works."""
        module = TensorDictModule(net, in_keys=["bork"], out_keys=["dork"])
        next_module = TensorDictModule(
            copy.deepcopy(net), in_keys=["dork"], out_keys=["spork"]
        )
        e0 = EnsembleModule(module, num_copies=4, expand_input=True)
        e1 = EnsembleModule(next_module, num_copies=4, expand_input=False)
        seq = TensorDictSequential(e0, e1)
        td = TensorDict({"bork": torch.randn(5, 1)}, batch_size=[5])
        out = seq(td)

        for out_key in ["dork", "spork"]:
            assert out_key in out.keys(), f"Ensemble forward failed to write {out_key}"
            assert out[out_key].shape == torch.Size(
                [4, 5, 1]
            ), f"Ensemble forward failed to expand input for {out_key}"
            # Compare every pair of ensemble members: entry (i, j) is True
            # wherever copy i and copy j produced (nearly) the same output.
            same_outputs = torch.isclose(
                out[out_key].repeat(4, 1, 1), out[out_key].repeat_interleave(4, dim=0)
            ).reshape(4, 4, 5, 1)
            # Off-diagonal pairs (distinct copies) must all differ.
            mask_out_diags = torch.eye(4).logical_not()
            assert not torch.any(
                same_outputs[mask_out_diags]
            ), f"Module ensemble outputs should be different for {out_key}"

    def test_reset_once(self):
        """Ensure we only call reset_parameters() once per ensemble member."""
        lin = nn.Linear(1, 1)
        lin.reset_parameters = unittest.mock.Mock()
        module = TensorDictModule(
            nn.Sequential(lin),
            in_keys=["a"],
            out_keys=["b"],
        )
        EnsembleModule(module, num_copies=2)
        assert (
            lin.reset_parameters.call_count == 2
        ), f"Reset parameters called {lin.reset_parameters.call_count} times, should be 2"


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
