diff --git a/sampling/sampler.py b/sampling/sampler.py index 2950557..1a819bc 100644 --- a/sampling/sampler.py +++ b/sampling/sampler.py @@ -280,6 +280,10 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma) elif output == 'ess': + try: + self.Target.variance + except: + raise AttributeError("Target.variance should be defined") return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma) else: diff --git a/tests/test_mclmc.py b/tests/test_mclmc.py index fd75b01..f312b33 100644 --- a/tests/test_mclmc.py +++ b/tests/test_mclmc.py @@ -1,4 +1,7 @@ -import sys +import sys + +from pytest import raises +import pytest sys.path.insert(0, '../../') sys.path.insert(0, './') @@ -25,18 +28,47 @@ def prior_draw(self, key): return jax.random.normal(key, shape = (self.d, ), dtype = 'float64') * 4 -target = StandardGaussian(d = 10, nlogp=nlogp) -sampler = Sampler(target, varEwanted = 5e-4) -target_simple = Target(d = 10, nlogp=nlogp) def test_mclmc(): + target = StandardGaussian(d = 10, nlogp=nlogp) + sampler = Sampler(target, varEwanted = 5e-4) + + samples1 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0)) samples2 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0)) samples3 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(1)) assert jnp.array_equal(samples1,samples2), "sampler should be pure" assert not jnp.array_equal(samples1,samples3), "this suggests that seed is not being used" + # run without key + sampler.sample(20) # run with multiple chains - sampler.sample(100, 3) + sampler.sample(20, 3) + # run with different output types + sampler.sample(20, 3, output='expectation') + sampler.sample(20, 3, output='detailed') + sampler.sample(20, 3, output='normal') + + with raises(AttributeError) as excinfo: + sampler.sample(20, 3, output='ess') + + # run with leapfrog + sampler = Sampler(target, varEwanted = 5e-4, integrator='LF') + sampler.sample(20) + # run without autotune + sampler = Sampler(target, varEwanted = 5e-4, frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,) + sampler.sample(20,) + # with a specific initial point + sampler.sample(20, x_initial=jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0))) + + # running with wrong dimensions causes TypeError + with raises(TypeError) as excinfo: + sampler.sample(20, x_initial=jax.random.normal(shape=(11,), key=jax.random.PRNGKey(0))) + + # multiple chains + sampler.sample(20, 3) + + # simple target + target_simple = Target(d = 10, nlogp=nlogp) Sampler(target_simple).sample(100, x_initial = jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0)))