Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getparams and setparams!! following AbstractMCMC v5.5 and v5.6 #378

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedHMC"
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
version = "0.6.2"
version = "0.6.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.5"
ArgCheck = "1, 2"
CUDA = "3, 4, 5"
DocStringExtensions = "0.8, 0.9"
Expand Down
9 changes: 9 additions & 0 deletions src/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ getadaptor(state::HMCState) = state.adaptor
getmetric(state::HMCState) = state.metric
getintegrator(state::HMCState) = state.κ.τ.integrator

function AbstractMCMC.getparams(state::HMCState)
# TODO(sunxd): should we return a copy?
return state.transition.z.θ
end

function AbstractMCMC.setparams!!(state::HMCState, θ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One dangerous aspect there is that state.transition caches the logjoint and gradient computations. So when you use the @set macro here, you're going to also keep the cached log-joint and gradient computation, which will then be out of sync with the parameters.

If you then naively pass this transition somewhere, say, into the next step call, AHMC.jl will use the incorrect logjoint eval in the MH step.

IMO the safe way is to use the explicit constructor of PhasePoint I believe without passing in the cached values. This should result in receomputation of this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that model is not passed into setparams!!, I think for now, we only set the parameters, then when use the transition, logp should be recomputed. (we can also later introduce some like setlogp or compute_logp etc.) I'll add a comment in the code

Copy link
Member

@yebai yebai Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed with @torfjelde's concern. Can we add logdensitymodel to

  • setparams!!(state, logdensitymodel, params)
  • getparam(state, logdensitymodel).

where logdensitymodel follows the LogDensityProblem interface. This would allow us to recompute the model's log probability (i.e., recompute_logp) inside these setparam!! functions on demand, which Turing's new Gibbs sampler uses.

return @set state.transition.z.θ = θ
sunxd3 marked this conversation as resolved.
Show resolved Hide resolved
end

"""
$(TYPEDSIGNATURES)

Expand Down
12 changes: 12 additions & 0 deletions test/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ using Statistics: mean
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo),
)

@testset "getparams and setparams!!" begin
t, s = AbstractMCMC.step(rng, model, nuts;)

θ = AbstractMCMC.getparams(s)
@test θ == t.z.θ
@test AbstractMCMC.setparams!!(s, θ) == s

new_θ = randn(rng, 2)
new_state = AbstractMCMC.setparams!!(s, new_θ)
@test AbstractMCMC.getparams(new_state) == new_θ
end

samples_nuts = AbstractMCMC.sample(
rng,
model,
Expand Down
Loading