Skip to content

Commit

Permalink
Merge branch 'master' of github.com:JakobRobnik/MicroCanonicalHMC
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobRobnik committed Oct 16, 2023
2 parents a12b37e + 5286b71 commit 90f4e38
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 6,456 deletions.
396 changes: 0 additions & 396 deletions notebooks/benchmark_sampling.py

This file was deleted.

3,969 changes: 0 additions & 3,969 deletions notebooks/mathematica/ErrorAnalysis.nb

This file was deleted.

1,098 changes: 0 additions & 1,098 deletions notebooks/mathematica/Fokker-Planck.nb

This file was deleted.

438 changes: 0 additions & 438 deletions notebooks/mathematica/Microcanonical_Nose-Hoover.nb

This file was deleted.

395 changes: 0 additions & 395 deletions notebooks/mathematica/poisson_brackets.nb

This file was deleted.

123 changes: 0 additions & 123 deletions notebooks/mathematica/theoretically_worst_case_convex.nb

This file was deleted.

74 changes: 46 additions & 28 deletions notebooks/tutorials/intro_tutorial.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sampling/correlation_length.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from scipy.fftpack import next_fast_len
from scipy.fftpack import next_fast_len #type: ignore



Expand Down
16 changes: 15 additions & 1 deletion sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,26 @@
from . import dynamics
from .correlation_length import ess_corr

class Target():

def __init__(self, d, nlogp):
self.d = d
self.nlogp = nlogp
self.grad_nlogp = jax.value_and_grad(self.nlogp)

def transform(self, x):
return x

def prior_draw(self, key):
"""Args: jax random key
Returns: one random sample from the prior"""

raise Exception("Not implemented")

class Sampler:
"""the MCHMC (q = 0 Hamiltonian) sampler"""

def __init__(self, Target, L = None, eps = None,
def __init__(self, Target : Target, L = None, eps = None,
integrator = 'MN', varEwanted = 5e-4,
diagonal_preconditioning= False,
frac_tune1 = 0.1, frac_tune2 = 0.1, frac_tune3 = 0.1,
Expand Down
42 changes: 42 additions & 0 deletions tests/test_mclmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import sys
sys.path.insert(0, '../../')
sys.path.insert(0, './')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt

from sampling.sampler import Sampler, Target

nlogp = lambda x: 0.5*jnp.sum(jnp.square(x))

class StandardGaussian(Target):

def __init__(self, d, nlogp):
Target.__init__(self,d,nlogp)

def transform(self, x):
return x[:2]

def prior_draw(self, key):
"""Args: jax random key
Returns: one random sample from the prior"""

return jax.random.normal(key, shape = (self.d, ), dtype = 'float64') * 4

target = StandardGaussian(d = 10, nlogp=nlogp)
sampler = Sampler(target, varEwanted = 5e-4)

target_simple = Target(d = 10, nlogp=nlogp)

def test_mclmc():
samples1 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0))
samples2 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(0))
samples3 = sampler.sample(100, 1, random_key=jax.random.PRNGKey(1))
assert jnp.array_equal(samples1,samples2), "sampler should be pure"
assert not jnp.array_equal(samples1,samples3), "this suggests that seed is not being used"
# run with multiple chains
sampler.sample(100, 3)

Sampler(target_simple).sample(100, x_initial = jax.random.normal(shape=(10,), key=jax.random.PRNGKey(0)))
13 changes: 6 additions & 7 deletions tests/test_momentum_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

import jax.numpy as jnp

def update_momentum_unstable(d, eps):
def update_momentum_unstable(d):

def update(u, g):
def update(eps, u, g):
g_norm = jnp.linalg.norm(g)
e = - g / g_norm
delta = eps * g_norm / (d-1)
Expand All @@ -26,11 +26,10 @@ def test_momentum_update():
u = jax.random.uniform(key=jax.random.PRNGKey(0),shape=(d,))
u = u / jnp.linalg.norm(u)
g = jax.random.uniform(key=jax.random.PRNGKey(1),shape=(d,))
update_stable = update_momentum(d, eps)
update_unstable = update_momentum_unstable(d, eps)
update1 = update_stable(u, g)[0]
update2 = update_unstable(u, g)
update_stable = update_momentum(d, sequential=True)
update_unstable = update_momentum_unstable(d)
update1 = update_stable(eps, u, g)[0]
update2 = update_unstable(eps, u, g)
print(update1, update2)
assert jnp.allclose(update1,update2)


0 comments on commit 90f4e38

Please sign in to comment.