I was trying to combine subsampling and MCMC. Is this possible? I receive the following error: `AssertionError: Missing random key to generate subsample indices.`
I've looked into the code but couldn't figure out where the `rng_key` should be passed to `_subsample_fn`.
```
Traceback (most recent call last):
  File "src/DNN_flax.py", line 101, in <module>
    mcmc.run(x=X, y=y, batch_size=256, rng_key=random.PRNGKey(63547901))
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 572, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 383, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 353, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 320, in _body_fn
    val = body_fun(val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 174, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 760, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 460, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 400, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1181, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1165, in _body_fn
    tree = _double_tree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 917, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1065, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/util.py", line 129, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 1007, in _body_fn
    new_leaf = _build_basetree(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 859, in _build_basetree
    z_new, r_new, potential_energy_new, z_new_grad = vv_update(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 298, in update_fn
    potential_energy, z_grad = _value_and_grad(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/hmc_util.py", line 247, in _value_and_grad
    return value_and_grad(f)(x)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 227, in potential_energy
    log_joint, model_trace = log_density_(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/infer/util.py", line 53, in log_density
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/handlers.py", line 165, in get_trace
    self(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "src/DNN_flax.py", line 78, in model
    with numpyro.plate("batch", x.shape[0], subsample_size=batch_size):
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 444, in __init__
    self.dim, self._indices = self._subsample(
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 467, in _subsample
    apply_stack(msg)
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 35, in apply_stack
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/usr/local/Caskroom/miniconda/base/envs/BNN/lib/python3.8/site-packages/numpyro/primitives.py", line 385, in _subsample_fn
    assert rng_key is not None, "Missing random key to generate subsample indices."
AssertionError: Missing random key to generate subsample indices.
```
fehiepsi changed the title from "Is it possible to use subsampling in combination with MCMC?" to "Raise better error message when using HMC for models with subsample" on Jan 21, 2022.