Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more tests for collapse handler #809

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
4 changes: 3 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import COERCIONS
from numpyro.distributions.distribution import COERCIONS, ExpandedDistribution
from numpyro.primitives import _PYRO_STACK, Messenger, apply_stack, plate
from numpyro.util import not_jax_tracer

Expand Down Expand Up @@ -268,6 +268,8 @@ def process_message(self, msg):
if msg["type"] == "sample":
if msg["value"] is None:
msg["value"] = msg["name"]
if isinstance(msg["fn"], ExpandedDistribution):
msg["fn"] = msg["fn"].base_dist

if isinstance(msg["fn"], Funsor) or isinstance(msg["value"], (str, Funsor)):
msg["stop"] = True
Expand Down
4 changes: 3 additions & 1 deletion numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.distributions.distribution import Distribution
from numpyro.util import identity

_PYRO_STACK = []
Expand Down Expand Up @@ -316,7 +317,8 @@ def process_message(self, msg):
cond_indep_stack = msg['cond_indep_stack']
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
if msg['type'] == 'sample':
# only expand if fn is Distribution, not a Funsor
if msg['type'] == 'sample' and isinstance(msg['fn'], Distribution):
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
expected_shape = self._get_batch_shape(cond_indep_stack)
dist_batch_shape = msg['fn'].batch_shape
if 'sample_shape' in msg['kwargs']:
Expand Down
86 changes: 85 additions & 1 deletion test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,6 @@ def guide():
svi.update(svi_state)


@pytest.mark.xfail(reason="missing pattern in Funsor")
def test_collapse_beta_binomial_plate():
data = np.array([0., 1., 5., 5.])

Expand All @@ -591,6 +590,91 @@ def guide():
svi.update(svi_state)


def test_collapse_normal_normal():
data = np.array(0.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_normal_plate():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
y = numpyro.sample("y", dist.Normal(x, 1.))
with handlers.plate("data", len(data)):
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_plate_normal():
data = np.arange(5.)

def model():
x = numpyro.sample("x", dist.Normal(0, 1))
with handlers.collapse():
with handlers.plate("data", len(data)):
y = numpyro.sample("y", dist.Normal(x, 1.))
numpyro.sample("z", dist.Normal(y, 1.), obs=data)

def guide():
loc = numpyro.param("loc", 0.)
scale = numpyro.param("scale", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Normal(loc, scale))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_collapse_normal_mvn_mvn():
T, d, S = 5, 2, 3
data = jnp.ones((T, S))

def model():
x = numpyro.sample("x", dist.Exponential(1))
with handlers.collapse():
with numpyro.plate("d", d):
# TODO: verify that to_event works here
beta0 = numpyro.sample("beta0", dist.Normal(0, 1).expand([S]).to_event(1))
# TODO: address beta0 is a str, which cannot do infer_param_domain
beta = numpyro.sample("beta", dist.MultivariateNormal(beta0, jnp.eye(S)))
# FIXME: beta is a string here, how to apply numeric operators
mean = jnp.ones((T, d)) @ beta
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
with numpyro.plate("data", T, dim=-2):
numpyro.sample("obs", dist.MultivariateNormal(mean, jnp.eye(S)), obs=data)

def guide():
rate = numpyro.param("rate", 1., constraint=constraints.positive)
numpyro.sample("x", dist.Exponential(rate))

svi = SVI(model, guide, numpyro.optim.Adam(1), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
svi.update(svi_state)


def test_prng_key():
assert numpyro.prng_key() is None

Expand Down