diff --git a/Project.toml b/Project.toml index fa49a335..f35b1df4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.6.2" +version = "0.6.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains" AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq" [compat] -AbstractMCMC = "5" +AbstractMCMC = "5.6" ArgCheck = "1, 2" CUDA = "3, 4, 5" DocStringExtensions = "0.8, 0.9" diff --git a/research/tests/runtests.jl b/research/tests/runtests.jl index 803458c6..2bb8e8d3 100644 --- a/research/tests/runtests.jl +++ b/research/tests/runtests.jl @@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl") include("relativistic_hmc.jl") include("riemannian_hmc.jl") -@main function runtests(patterns...; dry::Bool = false) +Comonicon.@main function runtests(patterns...; dry::Bool = false) retest(patterns...; dry = dry, verbose = Inf) end diff --git a/src/abstractmcmc.jl b/src/abstractmcmc.jl index 24ce799c..9ebedabe 100644 --- a/src/abstractmcmc.jl +++ b/src/abstractmcmc.jl @@ -30,6 +30,20 @@ getadaptor(state::HMCState) = state.adaptor getmetric(state::HMCState) = state.metric getintegrator(state::HMCState) = state.κ.τ.integrator +function AbstractMCMC.getparams(state::HMCState) + return state.transition.z.θ +end + +function AbstractMCMC.setparams!!(model, state::HMCState, params) + hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model) + return Setfield.@set state.transition.z = AdvancedHMC.phasepoint( + hamiltonian, + params, + state.transition.z.r; + ℓκ = state.transition.z.ℓκ, + ) +end + """ $(TYPEDSIGNATURES) diff --git a/test/abstractmcmc.jl b/test/abstractmcmc.jl index 25359cd6..e2355f77 100644 --- a/test/abstractmcmc.jl +++ b/test/abstractmcmc.jl @@ -21,6 +21,18 @@ using Statistics: mean LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo), ) + @testset "getparams and setparams!!" begin + t, s = AbstractMCMC.step(rng, model, nuts;) + + θ = AbstractMCMC.getparams(s) + @test θ == t.z.θ + @test AbstractMCMC.setparams!!(model, s, θ) == s + + new_θ = randn(rng, 2) + new_state = AbstractMCMC.setparams!!(model, s, new_θ) + @test AbstractMCMC.getparams(new_state) == new_θ + end + samples_nuts = AbstractMCMC.sample( rng, model,