Skip to content

Commit

Permalink
Merge branch 'master' into torfjelde/extract-model-values-from-chain
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Jul 15, 2023
2 parents 929c821 + 7d312cd commit 88d641a
Show file tree
Hide file tree
Showing 32 changed files with 1,385 additions and 251 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/IntegrationTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
- master
merge_group:
types: [checks_requested]
pull_request:
branches: [master]
types: [auto_merge_enabled]

jobs:
test:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.0"
version = "0.23.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.5.3"
BangBang = "0.3"
Bijectors = "0.12.4"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
Distributions = "0.23.8, 0.24, 0.25"
Expand Down
7 changes: 7 additions & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
using Documenter
using DynamicPPL
using DynamicPPL: AbstractPPL
# NOTE: This is necessary to ensure that if we print something from
# Distributions.jl in a doctest, then the shown value will not include
# a qualifier; that is, we don't want `Distributions.Normal{Float64}`
# but rather `Normal{Float64}`. The latter is what will then be printed
# in the doctest as run in `test/runtests.jl`, and so we need to stay
# consistent with that.
using Distributions

# Doctest setup
DocMeta.setdocmeta!(
Expand Down
38 changes: 36 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,34 @@ Similarly, one can specify with [`AbstractPPL.decondition`](@ref) that certain,
decondition
```

## Fixing and unfixing

We can also _fix_ a collection of variables in a [`Model`](@ref) to certain using [`fix`](@ref).

This might seem quite similar to the aforementioned [`condition`](@ref) and its siblings,
but they are indeed different operations:

- `condition`ed variables are considered to be _observations_, and are thus
included in the computation [`logjoint`](@ref) and [`loglikelihood`](@ref),
but not in [`logprior`](@ref).
- `fix`ed variables are considered to be _constant_, and are thus not included
in any log-probability computations.

The differences are more clearly spelled out in the docstring of [`fix`](@ref) below.

```@docs
fix
DynamicPPL.fixed
```

The difference between [`fix`](@ref) and [`condition`](@ref) is described in the docstring of [`fix`](@ref) above.

Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original meaning:

```@docs
unfix
```

## Utilities

It is possible to manually increase (or decrease) the accumulated log density from within a model function.
Expand All @@ -108,6 +136,12 @@ For converting a chain into a format that can more easily be fed into a `Model`
value_iterator_from_chain
```

Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref).

```@docs
extract_priors
```

```@docs
NamedDist
```
Expand Down Expand Up @@ -212,7 +246,8 @@ DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```
DynamicPPL.reconstruct
```

#### Utils

Expand Down Expand Up @@ -326,4 +361,3 @@ dot_tilde_assume
tilde_observe
dot_tilde_observe
```

7 changes: 6 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky

using DocStringExtensions

using Random: Random
Expand Down Expand Up @@ -43,7 +45,6 @@ export AbstractVarInfo,
push!!,
empty!!,
getlogp,
resetlogp!,
setlogp!!,
acclogp!!,
resetlogp!!,
Expand Down Expand Up @@ -85,6 +86,7 @@ export AbstractVarInfo,
getmissings,
getargnames,
generated_quantities,
extract_priors,
# Samplers
Sampler,
SampleFromPrior,
Expand Down Expand Up @@ -117,6 +119,8 @@ export AbstractVarInfo,
pointwise_loglikelihoods,
condition,
decondition,
fix,
unfix,
# Convenience macros
@addlogprob!,
@submodel,
Expand Down Expand Up @@ -167,5 +171,6 @@ include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

end # module
93 changes: 93 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,99 @@ variables `x` would return
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
value is reconstructed to the correct type and shape according to `dist`.
"""
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
x_recon = reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
Return linked `val` but reconstruct before linking, if necessary.
Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
return a reconstructed value, i.e. a value of the same type and shape as expected
by `dist`.
See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
"""
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
return reconstruct_and_link(dist, val)
end

"""
invlink_and_reconstruct(dist, val)
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return invlinked and reconstructed `val`.
See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
"""
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
function invlink_and_reconstruct(dist, val)
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
end
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
return invlink_and_reconstruct(dist, val)
end

"""
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
"""
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
reconstruct_and_link(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
"""
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
invlink_and_reconstruct(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])
Invlink `x` and compute the logpdf under `dist` including correction from
the invlink-transformation.
If `x` is not provided, `getval(vi, vn)` will be used.
"""
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
end
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
# NOTE: Will this cause type-instabilities or will union-splitting save us?
f = istrans(vi, vn) ? invlink_transform(dist) : identity
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
return x, logpdf(dist, x) + logjac
end

# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
Expand Down
52 changes: 45 additions & 7 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ function contextual_isassumption(context::AbstractContext, vn)
return contextual_isassumption(NodeTrait(context), context, vn)
end
function contextual_isassumption(context::ConditionContext, vn)
if hasvalue(context, vn)
val = getvalue(context, vn)
if hasconditioned(context, vn)
val = getconditioned(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
if eltype(val) >: Missing && val === missing
return true
Expand All @@ -76,14 +76,48 @@ function contextual_isassumption(context::ConditionContext, vn)
end
end

# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`
# We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}`
# so we defer to `childcontext` if we haven't concluded that anything yet.
return contextual_isassumption(childcontext(context), vn)
end
function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

isfixed(expr, vn) = false
isfixed(::Union{Symbol,Expr}, vn) = :($(DynamicPPL.contextual_isfixed)(__context__, $vn))

"""
contextual_isfixed(context, vn)
Return `true` if `vn` is considered fixed by `context`.
"""
contextual_isfixed(::IsLeaf, context, vn) = false
function contextual_isfixed(::IsParent, context, vn)
return contextual_isfixed(childcontext(context), vn)
end
function contextual_isfixed(context::AbstractContext, vn)
return contextual_isfixed(NodeTrait(context), context, vn)
end
function contextual_isfixed(context::PrefixContext, vn)
return contextual_isfixed(childcontext(context), prefix(context, vn))
end
function contextual_isfixed(context::FixedContext, vn)
if hasfixed(context, vn)
val = getfixed(context, vn)
# TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler?
if eltype(val) >: Missing && val === missing
return false
else
return true
end
end

# We might have nested contexts, e.g. `FixedContext{.., <:PrefixContext{..., <:FixedContext}}`
# so we defer to `childcontext` if we haven't concluded that anything yet.
return contextual_isfixed(childcontext(context), vn)
end

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -341,12 +375,14 @@ function generate_tilde(left, right)
$(AbstractPPL.drop_escape(varname(left))), $dist
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
if $(DynamicPPL.isfixed(left, vn))
$left = $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_tilde_assume(left, dist, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
$left = $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.tilde_observe!!)(
Expand Down Expand Up @@ -400,12 +436,14 @@ function generate_dot_tilde(left, right)
$(AbstractPPL.drop_escape(varname(left))), $right
)
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
if $(DynamicPPL.isfixed(left, vn))
$left .= $(DynamicPPL.getfixed_nested)(__context__, $vn)
elseif $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
$left .= $(DynamicPPL.getconditioned_nested)(__context__, $vn)
end

$value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)(
Expand Down
Loading

0 comments on commit 88d641a

Please sign in to comment.