diff --git a/examples/scanvi/scanvi.py b/examples/scanvi/scanvi.py index 3a289afe27..fca7113711 100644 --- a/examples/scanvi/scanvi.py +++ b/examples/scanvi/scanvi.py @@ -104,7 +104,7 @@ def forward(self, x): # Transform the counts x to log space for increased numerical stability. # Note that we only use this transform here; in particular the observation # distribution in the model is a proper count distribution. - x = torch.log(1 + x) + x = torch.log1p(x) h1, h2 = split_in_half(self.fc(x)) z2_loc, z2_scale = h1[..., :-1], softplus(h2[..., :-1]) l_loc, l_scale = h1[..., -1:], softplus(h2[..., -1:])