Add unit distribution for adding to log prob
Summary:
Addresses [#1041](#1041). Imported the `Unit` distribution (with some minor modifications) from Pyro. This allows users to add arbitrary terms to the model's joint log density:
```python
import beanmachine.ppl as bm
import torch
from beanmachine.ppl.distribution import Unit
from torch.distributions import Normal

@bm.random_variable
def increment_log_prob():
    val = Normal(0.0, 1.0).log_prob(torch.tensor(1.0))
    return Unit(val)
```
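Since `Unit.log_prob` returns `val` regardless of the value passed to it, inference simply sees `val` added to the model's joint log density.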

In the future we can wrap this with a `factor` statement.
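For illustration, a hypothetical sketch of what such a wrapper could look like; the helper name and signature below are assumptions, not part of this commit:

```python
import beanmachine.ppl as bm
import torch
from beanmachine.ppl.distribution import Unit

def make_factor(log_factor: torch.Tensor):
    # Hypothetical: wraps the Unit pattern above so that including the
    # returned random variable in a model adds `log_factor` to the
    # joint log density.
    @bm.random_variable
    def factor_rv():
        return Unit(log_factor)

    return factor_rv
```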

Reviewed By: neerajprad

Differential Revision: D31516303

fbshipit-source-id: b8dc3245e012788c7d5b468aed26535d1cc1b83e
jpchen authored and facebook-github-bot committed Oct 12, 2021
1 parent c55031f commit 9a7b599
Showing 2 changed files with 68 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/distribution/__init__.py
```diff
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 from beanmachine.ppl.distribution.flat import Flat
+from beanmachine.ppl.distribution.unit import Unit
 
 
-__all__ = ["Flat"]
+__all__ = ["Flat", "Unit"]
```
66 changes: 66 additions & 0 deletions src/beanmachine/ppl/distribution/unit.py
@@ -0,0 +1,66 @@
```python
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.distributions import constraints


def broadcast_shape(*shapes, **kwargs):
    """
    Similar to ``np.broadcast()`` but for shapes.
    Equivalent to ``np.broadcast(*map(np.empty, shapes)).shape``.

    :param tuple shapes: shapes of tensors.
    :param bool strict: whether to use extend-but-not-resize broadcasting.
    :returns: broadcasted shape
    :rtype: tuple
    :raises: ValueError
    """
    strict = kwargs.pop("strict", False)
    reversed_shape = []
    for shape in shapes:
        for i, size in enumerate(reversed(shape)):
            if i >= len(reversed_shape):
                reversed_shape.append(size)
            elif reversed_shape[i] == 1 and not strict:
                reversed_shape[i] = size
            elif reversed_shape[i] != size and (size != 1 or strict):
                raise ValueError(
                    "shape mismatch: objects cannot be broadcast to a single shape: {}".format(
                        " vs ".join(map(str, shapes))
                    )
                )
    return tuple(reversed(reversed_shape))


class Unit(torch.distributions.Distribution):
    """
    Trivial nonnormalized distribution representing the unit type.
    The unit type has a single value with no data, i.e. ``value.numel() == 0``.
    This is used for :func:`pyro.factor` statements.
    """

    arg_constraints = {"log_factor": constraints.real}
    support = constraints.real

    def __init__(self, log_factor, validate_args=None):
        log_factor = torch.as_tensor(log_factor)
        batch_shape = log_factor.shape
        event_shape = torch.Size((0,))  # This satisfies .numel() == 0.
        self.log_factor = log_factor
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Unit, _instance)
        new.log_factor = self.log_factor.expand(batch_shape)
        super(Unit, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def sample(self, sample_shape=torch.Size()):  # noqa: B008
        # Include the size-0 event dimension so samples carry no data.
        return self.log_factor.new_empty(
            sample_shape + self.batch_shape + self.event_shape
        )

    def log_prob(self, value):
        # The value itself is ignored; log_prob is just the stored log
        # factor, broadcast against the value's batch shape.
        shape = broadcast_shape(self.batch_shape, value.shape[:-1])
        return self.log_factor.expand(shape)
```
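A quick usage sketch (not part of the diff) of the two pieces above, assuming both are imported from the new module:

```python
import torch

from beanmachine.ppl.distribution.unit import Unit, broadcast_shape

# broadcast_shape follows NumPy broadcasting rules unless strict=True.
assert broadcast_shape((3, 1), (1, 4)) == (3, 4)
assert broadcast_shape((2, 3), (3,)) == (2, 3)
# With strict=True, size-1 dimensions may not be resized, so e.g.
# broadcast_shape((3, 1), (3, 4), strict=True) raises ValueError.

# Unit carries a log-density term; log_prob ignores the value's content.
d = Unit(torch.tensor(-1.5))
value = d.sample()  # shape (0,), i.e. value.numel() == 0
assert value.numel() == 0
assert d.log_prob(value).item() == -1.5
```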
