Skip to content

Commit

Permalink
Merge pull request #25 from JakobRobnik/tests
Browse files Browse the repository at this point in the history
Tests
  • Loading branch information
reubenharry authored Oct 16, 2023
2 parents 90f4e38 + 109210e commit bd96105
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
4 changes: 4 additions & 0 deletions sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,10 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning
return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)

elif output == 'ess':
try:
self.Target.variance
except:
raise AttributeError("Target.variance should be defined")
return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

else:
Expand Down
42 changes: 37 additions & 5 deletions tests/test_mclmc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import sys
import sys

from pytest import raises
import pytest
sys.path.insert(0, '../../')
sys.path.insert(0, './')

Expand All @@ -25,18 +28,47 @@ def prior_draw(self, key):

return jax.random.normal(key, shape = (self.d, ), dtype = 'float64') * 4

target = StandardGaussian(d = 10, nlogp=nlogp)
sampler = Sampler(target, varEwanted = 5e-4)

target_simple = Target(d = 10, nlogp=nlogp)

def test_mclmc():
target = StandardGaussian(d = 10, nlogp=nlogp)
sampler = Sampler(target, varEwanted = 5e-4)


samples1 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0))
samples2 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0))
samples3 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(1))
assert jnp.array_equal(samples1,samples2), "sampler should be pure"
assert not jnp.array_equal(samples1,samples3), "this suggests that seed is not being used"
# run without key
sampler.sample(20)
# run with multiple chains
sampler.sample(100, 3)
sampler.sample(20, 3)
# run with different output types
sampler.sample(20, 3, output='expectation')
sampler.sample(20, 3, output='detailed')
sampler.sample(20, 3, output='normal')

with raises(AttributeError) as excinfo:
sampler.sample(20, 3, output='ess')

# run with leapfrog
sampler = Sampler(target, varEwanted = 5e-4, integrator='LF')
sampler.sample(20)
# run without autotune
sampler = Sampler(target, varEwanted = 5e-4, frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,)
sampler.sample(20,)

# with a specific initial point
sampler.sample(20, x_initial=jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0)))

# running with wrong dimensions causes TypeError
with raises(TypeError) as excinfo:
sampler.sample(20, x_initial=jax.random.normal(shape=(11,), key=jax.random.PRNGKey(0)))

# multiple chains
sampler.sample(20, 3)

# simple target
target_simple = Target(d = 10, nlogp=nlogp)
Sampler(target_simple).sample(100, x_initial = jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0)))

0 comments on commit bd96105

Please sign in to comment.