From 83aefe9f74d1dd79ab868dae18188edb7e0691a6 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 14 Oct 2023 17:29:30 +0200 Subject: [PATCH 1/3] tests --- tests/test_mclmc.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/test_mclmc.py b/tests/test_mclmc.py index e804518..e002e9e 100644 --- a/tests/test_mclmc.py +++ b/tests/test_mclmc.py @@ -1,4 +1,7 @@ -import sys +import sys + +from pytest import raises +import pytest sys.path.insert(0, '../../') sys.path.insert(0, './') @@ -31,15 +34,38 @@ def prior_draw(self, key): return jax.random.normal(key, shape = (self.d, ), dtype = 'float64') * 4 -target = StandardGaussian(d = 10) -sampler = Sampler(target, varEwanted = 5e-4) def test_mclmc(): + target = StandardGaussian(d = 10) + sampler = Sampler(target, varEwanted = 5e-4) + + samples1 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0)) samples2 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0)) samples3 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(1)) assert jnp.array_equal(samples1,samples2), "sampler should be pure" assert not jnp.array_equal(samples1,samples3), "this suggests that seed is not being used" + # run without key + sampler.sample(20) # run with multiple chains - sampler.sample(100, 3) + sampler.sample(20, 3) + # run with different output types + # sampler.sample(20, 3, output='ess') + # sampler.sample(20, 3, output='expectation') + sampler.sample(20, 3, output='detailed') + sampler.sample(20, 3, output='normal') + + # run with leapfrog + sampler = Sampler(target, varEwanted = 5e-4, integrator='LF') + sampler.sample(20) + # run without autotune + sampler = Sampler(target, varEwanted = 5e-4, frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,) + sampler.sample(20,) + + # with a specific initial point + sampler.sample(20, x_initial=jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0))) + + # running with wrong dimensions causes TypeError + with raises(TypeError) as excinfo: + sampler.sample(20, x_initial=jax.random.normal(shape=(11,), key=jax.random.PRNGKey(0))) From 6be4b1470a317c7fa1d12a376104be72406cb617 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 14 Oct 2023 17:29:56 +0200 Subject: [PATCH 2/3] tests --- tests/test_mclmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_mclmc.py b/tests/test_mclmc.py index e002e9e..e25e3cd 100644 --- a/tests/test_mclmc.py +++ b/tests/test_mclmc.py @@ -51,8 +51,8 @@ def test_mclmc(): # run with multiple chains sampler.sample(20, 3) # run with different output types - # sampler.sample(20, 3, output='ess') - # sampler.sample(20, 3, output='expectation') + sampler.sample(20, 3, output='ess') + sampler.sample(20, 3, output='expectation') sampler.sample(20, 3, output='detailed') sampler.sample(20, 3, output='normal') From 109210e912f4194fa338f49b33b2904c0f13a45d Mon Sep 17 00:00:00 2001 From: = Date: Mon, 16 Oct 2023 11:32:32 +0200 Subject: [PATCH 3/3] tests --- tests/test_mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mclmc.py b/tests/test_mclmc.py index a27a9c1..f312b33 100644 --- a/tests/test_mclmc.py +++ b/tests/test_mclmc.py @@ -49,7 +49,7 @@ def test_mclmc(): sampler.sample(20, 3, output='detailed') sampler.sample(20, 3, output='normal') - with raises(AttributeError) as foo: + with raises(AttributeError) as excinfo: sampler.sample(20, 3, output='ess') # run with leapfrog