From 8522bb6e0058e03e2409c3c0f8dee9110a20075f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Jun 2024 13:26:49 +0100 Subject: [PATCH] [BugFix] Fix deterministic fallback when the dist has no support (#830) --- tensordict/nn/probabilistic.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index 4a3e67a5f..bf8ac049c 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -472,9 +472,15 @@ def _dist_sample( try: return dist.deterministic_sample except AttributeError: - fallback = ( - "mean" if isinstance(dist.support, D.constraints._Real) else "mode" - ) + try: + support = dist.support + fallback = ( + "mean" if isinstance(support, D.constraints._Real) else "mode" + ) + except NotImplementedError: + # Some custom dists don't have a support + # We arbitrarily fall onto 'mean' in these cases + fallback = "mean" try: if fallback == "mean": return dist.mean