Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise better error message when using HMC for models with subsample #1293

Closed
kaijennissen opened this issue Jan 20, 2022 · 1 comment · Fixed by #1303
Closed

Raise better error message when using HMC for models with subsample #1293

kaijennissen opened this issue Jan 20, 2022 · 1 comment · Fixed by #1303

Comments

@kaijennissen
Copy link
Contributor

Hi,

I was trying to combine subsampling and MCMC. Is this possible, because I receive the following error AssertionError: Missing random key to generate subsample indices.
I've looked into the code but couldn't figure out where the rng_key should be passed to the _subsample_fn.

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from flax import linen as nn
from jax import random
from numpyro.contrib.module import random_flax_module
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO


def get_data(N: int = 30, N_test: int = 1000):
    X = jnp.asarray(np.random.uniform(-np.pi * 3 / 2, np.pi, size=(N, 1)))
    y = jnp.asarray(np.sin(X) + np.random.normal(loc=0, scale=0.2, size=(N, 1)))
    X_test = jnp.linspace(-np.pi * 2, 2 * np.pi, num=N_test).reshape((-1, 1))
    return X.ravel(), y.ravel(), X_test.ravel()


class Net(nn.Module):
    n_units: int

    @nn.compact
    def __call__(self, x):

        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_units)(x)
        x = nn.relu(x)
        mean = nn.Dense(1)(x)
        rho = nn.Dense(1)(x)
        return mean.squeeze(), rho.squeeze()


def model(x, y=None, batch_size=None):
    module = Net(n_units=16)
    net = random_flax_module("nn", module, dist.Normal(0, 1.), input_shape=x.shape)
    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
        batch_x = numpyro.subsample(x, event_dim=1)
        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
        mean, rho = net(batch_x)
        sigma = nn.softplus(rho)
        numpyro.sample("obs", dist.Normal(mean, sigma), obs=batch_y)


n_train_data = 5000
X, y, X_test = get_data(N=n_train_data)

# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_feasible)
# svi = SVI(model, guide, numpyro.optim.Adam(5e-3), TraceMeanField_ELBO())


kernel = NUTS(model, max_tree_depth=1)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=4000, num_chains=1)
mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
Traceback (most recent call last):
  File "src/DNN_flax.py", line 101, in <module>
    mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 572, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 383, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 353, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 320, in _body_fn
    val = body_fun(val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 174, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 760, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 460, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 400, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1181, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1165, in _body_fn
    tree = _double_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 917, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1065, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1007, in _body_fn
    new_leaf = _build_basetree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 859, in _build_basetree
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 298, in update_fn
    potential_energy, z_grad = _value_and_grad(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 247, in _value_and_grad
    return value_and_grad(f)(x)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 227, in potential_energy
    log_joint, model_trace = log_density_(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 53, in log_density
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/handlers.py", line 165, in get_trace
    self(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "src/DNN_flax.py", line 78, in model
    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 444, in __init__
    self.dim, self._indices = self._subsample(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 467, in _subsample
    apply_stack(msg)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 35, in apply_stack
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 385, in _subsample_fn
    assert rng_key is not None, "Missing random key to generate subsample indices."
AssertionError: Missing random key to generate subsample indices.
``
@fehiepsi fehiepsi added the question Further information is requested label Jan 20, 2022
@fehiepsi
Copy link
Member

fehiepsi commented Jan 20, 2022

I think you can try https://num.pyro.ai/en/stable/examples/hmcecs.html or some other methods for tall data https://num.pyro.ai/en/stable/examples/covtype.html

You might also want to try subposterior methods that are outlined in #277

@fehiepsi fehiepsi added warnings & errors and removed question Further information is requested labels Jan 21, 2022
@fehiepsi fehiepsi changed the title Is it possible to use subsampling in combination with MCMC? Raise better error message when using HMC for models with subsample Jan 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants