Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add getparams and setparams!! following AbstractMCMC v5.5 and v5.6 #378

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

sunxd3
Copy link

@sunxd3 sunxd3 commented Oct 21, 2024

sunxd3 and others added 4 commits October 21, 2024 07:17
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3
Copy link
Author

sunxd3 commented Oct 22, 2024

the failed test is not related to this PR, @yebai maybe you already have an answer to why it fails?

@yebai
Copy link
Member

yebai commented Oct 22, 2024

@sunxd3 This is due to some changes in Julia’s handling of script entry functions. Can you try to fix that?

yebai
yebai previously approved these changes Oct 22, 2024
Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current impl is going to lead to incorrect log-density values being used in the transition. IMO we shouuld do the safe thing by default here.

return state.transition.z.θ
end

function AbstractMCMC.setparams!!(state::HMCState, θ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One dangerous aspect there is that state.transition caches the logjoint and gradient computations. So when you use the @set macro here, you're going to also keep the cached log-joint and gradient computation, which will then be out of sync with the parameters.

If you then naively pass this transition somewhere, say, into the next step call, AHMC.jl will use the incorrect logjoint eval in the MH step.

IMO the safe way is to use the explicit constructor of PhasePoint I believe without passing in the cached values. This should result in receomputation of this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just realized that model is not passed into setparams!!, I think for now, we only set the parameters, then when use the transition, logp should be recomputed. (we can also later introduce some like setlogp or compute_logp etc.) I'll add a comment in the code

Copy link
Member

@yebai yebai Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed with @torfjelde's concern. Can we add logdensitymodel to

  • setparams!!(state, logdensitymodel, params)
  • getparam(state, logdensitymodel).

where logdensitymodel follows the LogDensityProblem interface. This would allow us to recompute the model's log probability (i.e., recompute_logp) inside these setparam!! functions on demand, which Turing's new Gibbs sampler uses.

sunxd3 and others added 2 commits October 23, 2024 02:50
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/abstractmcmc.jl Outdated Show resolved Hide resolved
sunxd3 and others added 2 commits October 22, 2024 21:25
Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@sunxd3
Copy link
Author

sunxd3 commented Oct 23, 2024

let's delay a bit to see through TuringLang/AbstractMCMC.jl#150

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix to setparams!! doesn't work unfortauntely. We need access to the model here, which we don't have.

IMO we should add model as an argument to the AbstractMCMC.jl interface. Otherwise there's no way we can do stuff like update the gradient information, which is an issue we're facing both here and in AdvancedMH.jl.

function AbstractMCMC.setparams!!(state::HMCState, θ)
return @set state.transition.z.θ = θ
function AbstractMCMC.setparams!!(state::HMCState, params)
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model isn't defined here 😕

@sunxd3 sunxd3 changed the title Add getparams and setparams!! following AbstractMCMC v5.5 Add getparams and setparams!! following AbstractMCMC v5.5 and v5.6 Oct 28, 2024
sunxd3 and others added 4 commits October 28, 2024 05:36
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3 sunxd3 self-assigned this Oct 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants