Skip to content

Commit

Permalink
Merge pull request #30 from JakobRobnik/case3
Browse files Browse the repository at this point in the history
Remove case
  • Loading branch information
reubenharry authored Oct 16, 2023
2 parents f6b1687 + 4e2016f commit 2ab7430
Showing 1 changed file with 15 additions and 29 deletions.
44 changes: 15 additions & 29 deletions sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,38 +273,24 @@ def single_chain_sample(self, num_steps, x_initial, random_key, output, thinning

### sampling ###

match output:
case OutputType.normal:
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
return X
case OutputType.detailed:
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)

if output == OutputType.normal or output == OutputType.detailed:
X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
if output == 'detailed':
return X, E, L, eps
case OutputType.expectation:
return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)
case OutputType.ess:
return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

# if output == OutputType.normal or output == OutputType.detailed:
# X, _, E = self.sample_normal(num_steps, x, u, l, g, key, L, eps, sigma, thinning)
# if output == 'detailed':
# return X, E, L, eps
# else:
# return X
# elif output == OutputType.expectation:
# return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)

# elif output == OutputType.ess:
# try:
# self.Target.variance
# except:
# raise AttributeError("Target.variance should be defined")
# return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)

# else:
# raise ValueError('output = ' + output + ' is not a valid argument for the Sampler.sample')
else:
return X
elif output == OutputType.expectation:
return self.sample_expectation(num_steps, x, u, l, g, key, L, eps, sigma)

elif output == OutputType.ess:
try:
self.Target.variance
except:
raise AttributeError("Target.variance should be defined")
return self.sample_ess(num_steps, x, u, l, g, key, L, eps, sigma)


### for loops which do the sampling steps: ###

def sample_normal(self, num_steps, x, u, l, g, random_key, L, eps, sigma, thinning):
Expand Down

0 comments on commit 2ab7430

Please sign in to comment.