Skip to content

Commit

Permalink
debugged positive constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobRobnik committed Nov 24, 2023
1 parent 2bc7a8a commit ec0a07c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
8 changes: 4 additions & 4 deletions mclmc/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,10 @@ def step(x, u, g, eps, sigma):
# V T V T V
uu, r1 = V(eps * lambda_c, u, g * sigma)
xx, ll, gg, reflect = T(eps, x, 0.5*uu*sigma)
uu = (1 - 2 * reflect) * uu
uu *= reflect
uu, r2 = V(eps * (1 - 2 * lambda_c), uu, gg * sigma)
xx, ll, gg, reflect = T(eps, xx, 0.5*uu*sigma)
uu = (1 - 2 * reflect) * uu
uu *= reflect
uu, r3 = V(eps * lambda_c, uu, gg * sigma)

#kinetic energy change
Expand All @@ -108,9 +108,9 @@ def step(x, u, g, eps, sigma):
# V T V
uu, r1 = V(eps * 0.5, u, g * sigma)
xx, l, gg, reflect = T(eps, x, uu*sigma)
uu = (1 - 2 * reflect) * uu
uu *= reflect
uu, r2 = V(eps * 0.5, uu, gg * sigma)

# kinetic energy change
kinetic_change = (r1 + r2) * (d-1)

Expand Down
9 changes: 7 additions & 2 deletions mclmc/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,11 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning

### sampling ###

hyp = Parameters(L, eps, sigma)
self.hyp = hyp

if output == OutputType.normal or output == OutputType.detailed:
X, _, E = self.sample_normal(num_steps, MCLMCState(x, u, l, g, key), Parameters(L, eps, sigma), thinning)
X, _, E = self.sample_normal(num_steps, MCLMCState(x, u, l, g, key), hyp, thinning)
if output == OutputType.detailed:
return X, E, L, eps
else:
Expand Down Expand Up @@ -537,8 +539,11 @@ def positive(self, where):
"""
mask = self.to_mask(where)

self.map = lambda x: (jnp.abs(x) * mask + x * (1- mask), x < 0.)
self.map = lambda x: (jnp.abs(x) * mask + x * (1- mask), self.to_reflect((x < 0.) * mask))

def to_reflect(self, mask):
return 1 - 2 * mask


def rectangular(self, where, a, b):
"""Used if some parameters are rectangularly constrained (a< x < b)
Expand Down

0 comments on commit ec0a07c

Please sign in to comment.