diff --git a/tensordict/nn/ensemble.py b/tensordict/nn/ensemble.py index 749ee8fef..bd6841948 100644 --- a/tensordict/nn/ensemble.py +++ b/tensordict/nn/ensemble.py @@ -1,8 +1,9 @@ import warnings import torch -from tensordict import TensorDict, TensorDictBase -from tensordict.nn import make_functional, TensorDictModuleBase +from tensordict import TensorDict +from .common import TensorDictBase, TensorDictModuleBase +from .functional_modules import make_functional from torch import nn