Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getparameters and setparameters!! #86

Merged
merged 24 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e2bdfb7
added state_from_transition, parameters and setparameters!!
torfjelde Oct 21, 2021
7fa8de0
Update src/AbstractMCMC.jl
torfjelde Oct 23, 2021
0a4fd17
renamed state_from_transition to updatestate!!
torfjelde Nov 17, 2021
28bdf91
adhere to julia convention
torfjelde Nov 17, 2021
86a7826
added docs
torfjelde Nov 17, 2021
e19cea7
fixed docs
torfjelde Nov 17, 2021
d86499f
fixed docs
torfjelde Nov 17, 2021
bce436d
added example for why updatestate!! is useful
torfjelde Nov 17, 2021
21f4d56
improved MixtureState example
torfjelde Nov 17, 2021
de0e5b2
further improvements to docs
torfjelde Nov 17, 2021
23b9119
renamed parameters and setparameters!! to values and setvalues!!
torfjelde Nov 19, 2021
b9f476c
fixed typo in docs
torfjelde Nov 19, 2021
f7b6096
fixed documenting values
torfjelde Nov 19, 2021
4ca57b0
improved and fixed some bugs in docs
torfjelde Nov 19, 2021
abebd59
fixed typo in docs
torfjelde Nov 19, 2021
d1d4642
renamed values and setvalues!! to realize and realize!!
torfjelde Dec 7, 2021
c6c9554
added model to updatestate!!
torfjelde Dec 7, 2021
d9f8585
Merge branch 'master' into tor/state-transition-related
torfjelde Oct 24, 2023
600d36c
Apply suggestions from code review
torfjelde Oct 10, 2024
1bfbef1
Update docs/src/api.md
torfjelde Oct 10, 2024
d9480d1
Apply suggestions from code review
torfjelde Oct 10, 2024
3f861bf
Merge branch 'master' into tor/state-transition-related
torfjelde Oct 10, 2024
ddb588c
Update docs/src/api.md
torfjelde Oct 10, 2024
d6ab10a
version bump
sunxd3 Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probabilistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "5.4.0"
version = "5.5.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
141 changes: 141 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,144 @@ For chains of this type, AbstractMCMC defines the following two methods.
AbstractMCMC.chainscat
AbstractMCMC.chainsstack
```

## Interacting with states of samplers

To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods:
```@docs
AbstractMCMC.getparams
AbstractMCMC.setparams!!
```
These methods can also be useful for implementing samplers which wraps some inner samplers, e.g. a mixture of samplers.

### Example: `MixtureSampler`

In a `MixtureSampler` we need two things:
- `components`: collection of samplers.
- `weights`: collection of weights representing the probability of choosing the corresponding sampler.

```julia
struct MixtureSampler{W,C} <: AbstractMCMC.AbstractSampler
components::C
weights::W
end
```

To implement the state, we need to keep track of a couple of things:
- `index`: the index of the sampler used in this `step`.
- `states`: the current states of _all_ the components.
We need to keep track of the states of _all_ components rather than just the state for the sampler we used previously.
The reason is that lots of samplers keep track of more than just the previous realizations of the variables, e.g. in `AdvancedHMC.jl` we keep track of the momentum used, the metric used, etc.


```julia
struct MixtureState{S}
index::Int
states::S
end
```
The `step` for a `MixtureSampler` is defined by the following generative process
```math
\begin{aligned}
i &\sim \mathrm{Categorical}(w_1, \dots, w_k) \\
X_t &\sim \mathcal{K}_i(\cdot \mid X_{t - 1})
\end{aligned}
```
where ``\mathcal{K}_i`` denotes the i-th kernel/sampler, and ``w_i`` denotes the weight/probability of choosing the i-th sampler.
[`AbstractMCMC.getparams`](@ref) and [`AbstractMCMC.setparams!!`](@ref) comes into play in defining/computing ``\mathcal{K}_i(\cdot \mid X_{t - 1})`` since ``X_{t - 1}`` could be coming from a different sampler.

If we let `state` be the current `MixtureState`, `i` the current component, and `i_prev` is the previous component we sampled from, then this translates into the following piece of code:

```julia
# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
state_current = AbstractMCMC.setparams!!(
state.states[i],
AbstractMCMC.getparams(state.states[i_prev]),
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, sampler_state;
kwargs...
)
```

The full [`AbstractMCMC.step`](@ref) implementation would then be something like:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler, state; kwargs...)
# Sample the component to use in this `step`.
i = rand(Categorical(sampler.weights))
sampler_current = sampler.components[i]

# Update the corresponding state, i.e. `state.states[i]`, using
# the state and transition from the previous iteration.
i_prev = state.index
state_current = AbstractMCMC.setparams!!(
state.states[i],
AbstractMCMC.getparams(state.states[i_prev]),
)

# Take a `step` for this sampler using the updated state.
transition, state_current = AbstractMCMC.step(
rng, model, sampler_current, state_current;
kwargs...
)

# Create the new states.
# NOTE: Code below will result in `states_new` being a `Vector`.
# If we wanted to allow usage of alternative containers, e.g. `Tuple`,
# it would be better to use something like `@set states[i] = state_current`
# where `@set` is from Setfield.jl.
states_new = map(1:length(state.states)) do j
if j == i
# Replace the i-th state with the new one.
state_current
else
# Otherwise we just carry over the previous ones.
state.states[j]
end
end

# Create the new `MixtureState`.
state_new = MixtureState(i, states_new)

return transition, state_new
end
```

And for the initial [`AbstractMCMC.step`](@ref) we have:

```julia
function AbstractMCMC.step(rng, model::AbstractMCMC.AbstractModel, sampler::MixtureSampler; kwargs...)
# Initialize every state.
transitions_and_states = map(sampler.components) do spl
AbstractMCMC.step(rng, model, spl; kwargs...)
end

# Sample the component to use this `step`.
i = rand(Categorical(sampler.weights))
# Extract the corresponding transition.
transition = first(transitions_and_states[i])
# Extract states.
states = map(last, transitions_and_states)
# Create new `MixtureState`.
state = MixtureState(i, states)

return transition, state
end
```

Suppose we then wanted to use this with some of the packages which implements AbstractMCMC.jl's interface, e.g. [`AdvancedMH.jl`](https://github.com/TuringLang/AdvancedMH.jl), then we'd simply have to implement `getparams` and `setparams!!`.


To use `MixtureSampler` with two samplers `sampler1` and `sampler2` from `AdvancedMH.jl` as components, we'd simply do

```julia
sampler = MixtureSampler([sampler1, sampler2], [0.1, 0.9])
transition, state = AbstractMCMC.step(rng, model, sampler)
while ...
transition, state = AbstractMCMC.step(rng, model, sampler, state)
end
```
21 changes: 21 additions & 0 deletions src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ The `MCMCSerial` algorithm allows users to sample serially, with no thread or pr
"""
struct MCMCSerial <: AbstractMCMCEnsemble end

"""
getparams(state[; kwargs...])

Retrieve the values of parameters from the sampler's `state` as a `Vector{<:Real}`.
"""
function getparams end

"""
setparams!!(state, params)

Set the values of parameters in the sampler's `state` from a `Vector{<:Real}`.

This function should follow the `BangBang` interface: mutate `state` in-place if possible and
return the mutated `state`. Otherwise, it should return a new `state` containing the updated parameters.

Although not enforced, it should hold that `setparams!!(state, getparams(state)) == state`. In another
word, the sampler should implement a consistent transformation between its internal representation
and the vector representation of the parameter values.
"""
function setparams!! end

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
include("samplingstats.jl")
include("logging.jl")
include("interface.jl")
Expand Down
Loading