Refactoring #29

Merged
3 commits merged on Oct 16, 2023
22 changes: 4 additions & 18 deletions sampling/dynamics.py
@@ -5,7 +5,6 @@

lambda_c = 0.1931833275037836 #critical value of the lambda parameter for the minimal norm integrator

grad_evals = {'MN' : 2, 'LF' : 1}



@@ -93,22 +92,6 @@ def step(x, u, g, eps, sigma):



def hamiltonian(integrator, grad_nlogp, d, sequential = True):

T = update_position(grad_nlogp)
V = update_momentum(d, sequential)

if integrator == "LF": #leapfrog (first updates the velocity)
return leapfrog(d, T, V)

elif integrator== 'MN': #minimal norm integrator (first updates the velocity)
return minimal_norm(d, T, V)

else:
raise Exception("Integrator must be either MN (minimal_norm) or LF (leapfrog)")



def mclmc(hamiltonian_dynamics, partially_refresh_momentum, d):

def step(x, u, g, random_key, L, eps, sigma):
@@ -168,4 +151,7 @@ def rng_parallel(u, random_key, nu):
return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key


return rng_sequential if sequential else rng_parallel


grad_evals = {minimal_norm : 2, leapfrog : 1}
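For orientation (not part of the diff): grad_evals is now keyed by the integrator functions themselves rather than by the string codes 'MN'/'LF', so an integrator's per-step cost can be looked up directly from the function object. A minimal sketch under that assumption; gradient_calls is a hypothetical helper, not something this PR adds.

# Illustrative sketch only -- assumes sampling.dynamics exposes minimal_norm,
# leapfrog, and the function-keyed grad_evals table shown above.
from sampling import dynamics

def gradient_calls(integrator, num_steps):
    # total gradient evaluations = cost per step (from grad_evals) * number of steps
    return dynamics.grad_evals[integrator] * num_steps

# minimal_norm costs 2 gradient evaluations per step, leapfrog costs 1, e.g.:
# gradient_calls(dynamics.minimal_norm, 1000) -> 2000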
100 changes: 57 additions & 43 deletions sampling/sampler.py
@@ -1,5 +1,6 @@
## style note: general preference here for functional style (e.g. global function definitions, purity, code sharing)

from enum import Enum
import jax
import jax.numpy as jnp
import numpy as np
@@ -23,11 +24,14 @@ def prior_draw(self, key):

raise Exception("Not implemented")

OutputType = Enum('Output', ['normal', 'detailed', 'expectation', 'ess'])


class Sampler:
"""the MCHMC (q = 0 Hamiltonian) sampler"""

def __init__(self, Target : Target, L = None, eps = None,
integrator = 'MN', varEwanted = 5e-4,
integrator = dynamics.minimal_norm, varEwanted = 5e-4,
diagonal_preconditioning= False,
frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,
):
@@ -38,7 +42,7 @@ def __init__(self, Target : Target, L = None, eps = None,

eps: initial integration step-size (it is then automatically tuned before the sampling starts, unless you turn off the tuning by setting frac_tune1 and frac_tune2 to zero (see below))

integrator: 'LF' (leapfrog) or 'MN' (minimal norm). Typically MN performs better.
integrator: leapfrog or minimal_norm. Typically minimal_norm performs better.

varEwanted: if your posteriors are biased, try smaller values (or larger values if the convergence is too slow). This is perhaps the parameter whose default value is least well determined.

@@ -57,10 +61,12 @@ def __init__(self, Target : Target, L = None, eps = None,
self.sigma = jnp.ones(Target.d)

self.integrator = integrator
self.T = dynamics.update_position(self.Target.grad_nlogp)
self.V = dynamics.update_momentum(self.Target.d, sequential=True)

### integrator ###
## NOTE: sigma does not arise from any tuning here: it is a fixed parameter
self.dynamics = dynamics.mclmc(dynamics.hamiltonian(integrator=self.integrator, grad_nlogp=self.Target.grad_nlogp, d=self.Target.d),
self.dynamics = dynamics.mclmc(self.integrator(T=self.T, V=self.V,d=self.Target.d),
dynamics.partially_refresh_momentum(self.Target.d, True), self.Target.d)
self.random_unit_vector = dynamics.random_unit_vector(self.Target.d, True)

@@ -144,32 +150,27 @@ def get_initial_conditions(self, x_initial, random_key):
key = random_key

### initial conditions ###
if isinstance(x_initial, str):
if x_initial == 'prior': # draw the initial x from the prior
key, prior_key = jax.random.split(key)
x = self.Target.prior_draw(prior_key)
else: # if not 'prior' the x_initial should specify the initial condition
raise KeyError('x_initial = "' + x_initial + '" is not a valid argument. \nIf you want to draw initial condition from a prior use x_initial = "prior", otherwise specify the initial condition with an array')
else: #initial x is given
x = x_initial

l, g = self.Target.grad_nlogp(x)
if x_initial is None: # draw the initial x from the prior
key, prior_key = jax.random.split(key)
x_initial = self.Target.prior_draw(prior_key)

l, g = self.Target.grad_nlogp(x_initial)

u, key = self.random_unit_vector(key)
#u = - g / jnp.sqrt(jnp.sum(jnp.square(g))) #initialize momentum in the direction of the gradient of log p

return x, u, l, g, key
return x_initial, u, l, g, key



def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= None, output = 'normal', thinning= 1):
def sample(self, num_steps, num_chains = 1, x_initial = None, random_key= None, output = OutputType.normal, thinning= 1):
"""Args:
num_steps: number of integration steps to take.

num_chains: number of independent chains, defaults to 1. If different from 1, jax will parallelize the computation with the number of available devices (CPU, GPU, TPU),
as returned by jax.local_device_count().

x_initial: initial condition for x, shape: (d, ). Defaults to 'prior' in which case the initial condition is drawn from the prior distribution (self.Target.prior_draw).
x_initial: initial condition for x, shape: (d, ). Defaults to None in which case the initial condition is drawn from the prior distribution (self.Target.prior_draw).

random_key: jax random seed, defaults to jax.random.PRNGKey(0)

@@ -193,7 +194,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non

if num_chains == 1:
results = self.single_chain_sample(num_steps, x_initial, random_key, output, thinning) #the function which actually does the sampling
if output == 'ess':
if output == OutputType.ess:
return self.bias_plot(results)

else:
@@ -205,14 +206,11 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
else:
key = random_key

if isinstance(x_initial, str):
if x_initial == 'prior': # draw the initial x from the prior
keys_all = jax.random.split(key, num_chains * 2)
x0 = jnp.array([self.Target.prior_draw(keys_all[num_chains+i]) for i in range(num_chains)])
keys = keys_all[:num_chains]
if x_initial is None: # draw the initial x from the prior
keys_all = jax.random.split(key, num_chains * 2)
x0 = jnp.array([self.Target.prior_draw(keys_all[num_chains+i]) for i in range(num_chains)])
keys = keys_all[:num_chains]

else: # if not 'prior' the x_initial should specify the initial condition
raise KeyError('x_initial = "' + x_initial + '" is not a valid argument. \nIf you want to draw initial condition from a prior use x_initial = "prior", otherwise specify the initial condition with an array')
else: #initial x is given
x0 = jnp.copy(x_initial)
keys = jax.random.split(key, num_chains)
@@ -223,7 +221,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non
if num_cores != 1: #run the chains on parallel cores
parallel_function = jax.pmap(jax.vmap(f))
results = parallel_function(jnp.arange(num_chains).reshape(num_cores, num_chains // num_cores))
if output == 'ess':
if output == OutputType.ess:
return self.bias_plot(results.reshape(num_chains, num_steps))

### reshape results ###
Expand All @@ -242,7 +240,7 @@ def sample(self, num_steps, num_chains = 1, x_initial = 'prior', random_key= Non

results = jax.vmap(f)(jnp.arange(num_chains))

if output == 'ess':
if output == OutputType.ess:
return self.bias_plot(results)

else:
@@ -270,24 +268,40 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning

### sampling ###

if output == 'normal' or output == 'detailed':
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
if output == 'detailed':
return X, E, L, eps
else:
match output:
case OutputType.normal:
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
return X
elif output == 'expectation':
return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)

elif output == 'ess':
try:
self.Target.variance
except:
raise AttributeError("Target.variance should be defined")
return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

else:
raise ValueError('output = ' + output + ' is not a valid argument for the Sampler.sample')
case OutputType.detailed:
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
return X, E, L, eps
case OutputType.expectation:
return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)
case OutputType.ess:
try:
self.Target.variance
except:
raise AttributeError("Target.variance should be defined")
return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

# if output == OutputType.normal or output == OutputType.detailed:
# X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
# if output == 'detailed':
# return X, E, L, eps
# else:
# return X
# elif output == OutputType.expectation:
# return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)

# elif output == OutputType.ess:
# try:
# self.Target.variance
# except:
# raise AttributeError("Target.variance should be defined")
# return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

# else:
# raise ValueError('output = ' + output + ' is not a valid argument for the Sampler.sample')


### for loops which do the sampling steps: ###
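Taken together, the user-facing changes in sampler.py are: the integrator argument is a function (dynamics.leapfrog or dynamics.minimal_norm) instead of the strings 'LF'/'MN', output is selected with the OutputType enum instead of strings, and x_initial defaults to None (draw from the prior) instead of the string 'prior'. A hedged usage sketch of the new interface follows; the ToyGaussian target is hypothetical and only duck-types the attributes this diff shows the Sampler using (d, grad_nlogp, prior_draw) -- the real Target base class may require more.

# Illustrative usage sketch only, not part of the diff.
import jax
import jax.numpy as jnp
from sampling import dynamics
from sampling.sampler import OutputType, Sampler

class ToyGaussian:
    # standard Gaussian in d dimensions; negative log-density 0.5 * |x|^2
    d = 10
    def grad_nlogp(self, x):
        return 0.5 * jnp.sum(jnp.square(x)), x   # value and its gradient
    def prior_draw(self, key):
        return jax.random.normal(key, shape=(self.d,))

sampler = Sampler(ToyGaussian(), integrator=dynamics.minimal_norm)   # was integrator='MN'
X = sampler.sample(1000)                                             # x_initial=None: drawn via prior_draw
X, E, L, eps = sampler.sample(1000, output=OutputType.detailed)      # was output='detailed'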
16 changes: 9 additions & 7 deletions tests/test_mclmc.py
@@ -1,7 +1,9 @@
import sys

from pytest import raises
import pytest

from sampling.dynamics import leapfrog
sys.path.insert(0, '../../')
sys.path.insert(0, './')

@@ -10,7 +12,7 @@
import numpy as np
import matplotlib.pyplot as plt

from sampling.sampler import Sampler, Target
from sampling.sampler import OutputType, Sampler, Target

nlogp = lambda x: 0.5*jnp.sum(jnp.square(x))

@@ -45,15 +47,15 @@ def test_mclmc():
# run with multiple chains
sampler.sample(20, 3)
# run with different output types
sampler.sample(20, 3, output='expectation')
sampler.sample(20, 3, output='detailed')
sampler.sample(20, 3, output='normal')
sampler.sample(20, 1, output=OutputType.expectation)
sampler.sample(20, 1, output=OutputType.detailed)
sampler.sample(20, 1, output=OutputType.normal)

with raises(AttributeError) as excinfo:
sampler.sample(20, 3, output='ess')
sampler.sample(20, 1, output=OutputType.ess)

# run with leapfrog
sampler = Sampler(target, varEwanted = 5e-4, integrator='LF')
sampler = Sampler(target, varEwanted = 5e-4, integrator=leapfrog)
sampler.sample(20)
# run without autotune
sampler = Sampler(target, varEwanted = 5e-4, frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,)
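One note on the ess output path exercised by the raises(AttributeError) check above: OutputType.ess still requires the target to define a variance attribute. A hedged sketch of a target that satisfies the check; the exact meaning of Target.variance (presumably the ground-truth marginal variances used for the bias/ESS diagnostic) should be confirmed against the Target definition, and all names below are illustrative.

# Illustrative sketch only, not part of the diff.
import jax
import jax.numpy as jnp

class ToyGaussianWithVariance:
    # same toy interface as in the earlier sketch, plus ground-truth variances
    d = 10
    variance = jnp.ones(10)          # standard Gaussian: every marginal variance is 1
    def grad_nlogp(self, x):
        return 0.5 * jnp.sum(jnp.square(x)), x
    def prior_draw(self, key):
        return jax.random.normal(key, shape=(self.d,))

# sampler = Sampler(ToyGaussianWithVariance(), integrator=dynamics.leapfrog)   # was integrator='LF'
# ess_curve = sampler.sample(20, output=OutputType.ess)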