diff --git a/sampling/dynamics.py b/sampling/dynamics.py
index 639a6f5..95783d5 100644
--- a/sampling/dynamics.py
+++ b/sampling/dynamics.py
@@ -5,7 +5,6 @@
 
 lambda_c = 0.1931833275037836 #critical value of the lambda parameter for the minimal norm integrator
 
-grad_evals = {'MN' : 2, 'LF' : 1}
 
 
 
@@ -93,22 +92,6 @@ def step(x, u, g, eps, sigma):
 
 
 
-def hamiltonian(integrator, grad_nlogp, d, sequential = True):
-
-    T = update_position(grad_nlogp)
-    V = update_momentum(d, sequential)
-
-    if integrator == "LF": #leapfrog (first updates the velocity)
-        return leapfrog(d, T, V)
-
-    elif integrator== 'MN': #minimal norm integrator (first updates the velocity)
-        return minimal_norm(d, T, V)
-
-    else:
-        raise Exception("Integrator must be either MN (minimal_norm) or LF (leapfrog)")
-
-
-
 def mclmc(hamiltonian_dynamics, partially_refresh_momentum, d):
 
     def step(x, u, g, random_key, L, eps, sigma):
@@ -168,4 +151,7 @@ def rng_parallel(u, random_key, nu):
         return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key
 
 
-    return rng_sequential if sequential else rng_parallel
\ No newline at end of file
+    return rng_sequential if sequential else rng_parallel
+
+
+grad_evals = {minimal_norm : 2, leapfrog : 1}
diff --git a/sampling/sampler.py b/sampling/sampler.py
index 1a819bc..9561203 100644
--- a/sampling/sampler.py
+++ b/sampling/sampler.py
@@ -1,5 +1,6 @@
 ## style note: general preference here for functional style (e.g. global function definitions, purity, code sharing)
 
+from enum import Enum
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -23,11 +24,14 @@ def prior_draw(self, key):
         raise Exception("Not implemented")
 
 
+OutputType = Enum('Output', ['normal', 'detailed', 'expectation', 'ess'])
+
+
 class Sampler:
     """the MCHMC (q = 0 Hamiltonian) sampler"""
 
     def __init__(self, Target : Target, L = None, eps = None,
-                 integrator = 'MN', varEwanted = 5e-4,
+                 integrator = dynamics.minimal_norm, varEwanted = 5e-4,
                  diagonal_preconditioning= False,
                  frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,
                  ):
@@ -38,7 +42,7 @@ def __init__(self, Target : Target, L = None, eps = None,
               eps: initial integration step-size (it is then automatically tuned before the sampling starts unless you turn off the tuning by setting all frac_tune1 and 2 to zero (see below))
-              integrator: 'LF' (leapfrog) or 'MN' (minimal norm). Typically MN performs better.
+              integrator: dynamics.leapfrog or dynamics.minimal_norm. Typically minimal_norm performs better.
               varEwanted: if your posteriors are biased try smaller values (or larger values: perhaps the convergence is too slow). This is perhaps the parameter whose default value is the least well determined.
@@ -57,10 +61,12 @@ def __init__(self, Target : Target, L = None, eps = None,
         self.sigma = jnp.ones(Target.d)
         self.integrator = integrator
+        self.T = dynamics.update_position(self.Target.grad_nlogp)
+        self.V = dynamics.update_momentum(self.Target.d, sequential=True)
 
         ### integrator ###
         ## NOTE: sigma does not arise from any tuning here: it is a fixed parameter
-        self.dynamics = dynamics.mclmc(dynamics.hamiltonian(integrator=self.integrator, grad_nlogp=self.Target.grad_nlogp, d=self.Target.d),
+        self.dynamics = dynamics.mclmc(self.integrator(T=self.T, V=self.V, d=self.Target.d),
                                        dynamics.partially_refresh_momentum(self.Target.d, True),
                                        self.Target.d)
 
         self.random_unit_vector = dynamics.random_unit_vector(self.Target.d, True)
@@ -144,32 +150,27 @@ def get_initial_conditions(self, x_initial, random_key):
             key = random_key
 
         ### initial conditions ###
-        if isinstance(x_initial, str):
-            if x_initial == 'prior': # draw the initial x from the prior
-                key, prior_key = jax.random.split(key)
-                x = self.Target.prior_draw(prior_key)
-            else: # if not 'prior' the x_initial should specify the initial condition
-                raise KeyError('x_initial = "' + x_initial + '" is not a valid argument. \nIf you want to draw initial condition from a prior use x_initial = "prior", otherwise specify the initial condition with an array')
-        else: #initial x is given
-            x = x_initial
-
-        l, g = self.Target.grad_nlogp(x)
+        if x_initial is None: # draw the initial x from the prior
+            key, prior_key = jax.random.split(key)
+            x_initial = self.Target.prior_draw(prior_key)
+
+        l, g = self.Target.grad_nlogp(x_initial)
 
         u, key = self.random_unit_vector(key)
         #u = - g / jnp.sqrt(jnp.sum(jnp.square(g))) #initialize momentum in the direction of the gradient of log p
 
-        return x, u, l, g, key
+        return x_initial, u, l, g, key
 
 
-    def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= None, output = 'normal', thinning= 1):
+    def sample(self, num_steps, num_chains = 1, x_initial = None, random_key= None, output = OutputType.normal, thinning= 1):
         """Args:
               num_steps: number of integration steps to take.
              num_chains: number of independent chains, defaults to 1. If different than 1, jax will parallelize the computation with the number of available devices (CPU, GPU, TPU), as returned by jax.local_device_count().
-              x_initial: initial condition for x, shape: (d, ). Defaults to 'prior' in which case the initial condition is drawn from the prior distribution (self.Target.prior_draw).
+              x_initial: initial condition for x, shape: (d, ). Defaults to None, in which case the initial condition is drawn from the prior distribution (self.Target.prior_draw).
              random_key: jax random seed, defaults to jax.random.PRNGKey(0)
@@ -193,7 +194,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
 
         if num_chains == 1:
             results = self.single_chain_sample(num_steps, x_initial, random_key, output, thinning) #the function which actually does the sampling
-            if output == 'ess':
+            if output == OutputType.ess:
                 return self.bias_plot(results)
 
             else:
@@ -205,14 +206,11 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
         else:
             key = random_key
-        if isinstance(x_initial, str):
-            if x_initial == 'prior': # draw the initial x from the prior
-                keys_all = jax.random.split(key, num_chains * 2)
-                x0 = jnp.array([self.Target.prior_draw(keys_all[num_chains+i]) for i in range(num_chains)])
-                keys = keys_all[:num_chains]
+        if x_initial is None: # draw the initial x from the prior
+            keys_all = jax.random.split(key, num_chains * 2)
+            x0 = jnp.array([self.Target.prior_draw(keys_all[num_chains+i]) for i in range(num_chains)])
+            keys = keys_all[:num_chains]
 
-            else: # if not 'prior' the x_initial should specify the initial condition
-                raise KeyError('x_initial = "' + x_initial + '" is not a valid argument. \nIf you want to draw initial condition from a prior use x_initial = "prior", otherwise specify the initial condition with an array')
 
         else: #initial x is given
            x0 = jnp.copy(x_initial)
            keys = jax.random.split(key, num_chains)
@@ -223,7 +221,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
         if num_cores != 1: #run the chains on parallel cores
             parallel_function = jax.pmap(jax.vmap(f))
             results = parallel_function(jnp.arange(num_chains).reshape(num_cores, num_chains // num_cores))
-            if output == 'ess':
+            if output == OutputType.ess:
                 return self.bias_plot(results.reshape(num_chains, num_steps))
 
             ### reshape results ###
@@ -242,7 +240,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
         else:
             results = jax.vmap(f)(jnp.arange(num_chains))
 
-            if output == 'ess':
+            if output == OutputType.ess:
                 return self.bias_plot(results)
 
             else:
@@ -270,24 +268,21 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning
 
         ### sampling ###
 
-        if output == 'normal' or output == 'detailed':
-            X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
-            if output == 'detailed':
-                return X, E, L, eps
-            else:
+        match output:
+            case OutputType.normal:
+                X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
                 return X
-        elif output == 'expectation':
-            return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)
-
-        elif output == 'ess':
-            try:
-                self.Target.variance
-            except:
-                raise AttributeError("Target.variance should be defined")
-            return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)
-
-        else:
-            raise ValueError('output = ' + output + ' is not a valid argument for the Sampler.sample')
+            case OutputType.detailed:
+                X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
+                return X, E, L, eps
+            case OutputType.expectation:
+                return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)
+            case OutputType.ess:
+                try:
+                    self.Target.variance
+                except:
+                    raise AttributeError("Target.variance should be defined")
+                return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)
 
 
     ### for loops which do the sampling steps: ###
diff --git a/tests/test_mclmc.py b/tests/test_mclmc.py
index f312b33..cb1fd15 100644
--- a/tests/test_mclmc.py
+++ b/tests/test_mclmc.py
@@ -1,7 +1,9 @@
 import sys
 from pytest import raises
-import pytest
+import pytest
+
+from sampling.dynamics import leapfrog
 
 sys.path.insert(0, '../../')
 sys.path.insert(0, './')
 
@@ -10,7 +12,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
-from sampling.sampler import Sampler, Target
+from sampling.sampler import OutputType, Sampler, Target
 
 
 nlogp = lambda x: 0.5*jnp.sum(jnp.square(x))
@@ -45,15 +47,15 @@ def test_mclmc():
     # run with multiple chains
     sampler.sample(20, 3)
 
     # run with different output types
-    sampler.sample(20, 3, output='expectation')
-    sampler.sample(20, 3, output='detailed')
-    sampler.sample(20, 3, output='normal')
+    sampler.sample(20, 1, output=OutputType.expectation)
+    sampler.sample(20, 1, output=OutputType.detailed)
+    sampler.sample(20, 1, output=OutputType.normal)
     with raises(AttributeError) as excinfo:
-        sampler.sample(20, 3, output='ess')
+        sampler.sample(20, 1, output=OutputType.ess)
 
     # run with leapfrog
-    sampler = Sampler(target, varEwanted = 5e-4, integrator='LF')
+    sampler = Sampler(target, varEwanted = 5e-4, integrator=leapfrog)
     sampler.sample(20)
     # run without autotune
     sampler = Sampler(target, varEwanted = 5e-4, frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,)
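
A minimal usage sketch of the refactored API follows (not part of the diff). It assumes `target` is an already-constructed `Target` instance with `prior_draw` defined, as in tests/test_mclmc.py; integrators are now passed as functions from `sampling.dynamics`, output types are `OutputType` members, and `x_initial=None` draws the starting point from the prior.

# Hypothetical usage sketch, not part of the diff. Assumes `target` is a
# configured Target instance whose prior_draw is implemented.
from sampling.dynamics import leapfrog, minimal_norm
from sampling.sampler import OutputType, Sampler

sampler = Sampler(target, varEwanted=5e-4, integrator=minimal_norm)

# x_initial defaults to None, so the initial condition is drawn from target.prior_draw
X = sampler.sample(1000)

# output is now an OutputType member instead of a string
X, E, L, eps = sampler.sample(1000, output=OutputType.detailed)

# the integrator is passed as a function instead of the old 'MN'/'LF' strings
sampler_lf = Sampler(target, varEwanted=5e-4, integrator=leapfrog)
X_lf = sampler_lf.sample(1000)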