Skip to content

Commit

Permalink
Remove check on Julia>1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Oct 26, 2024
1 parent d4029ab commit a24140a
Showing 1 changed file with 53 additions and 55 deletions.
108 changes: 53 additions & 55 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,70 +17,68 @@ using Test: @test, @test_throws, @testset
using Turing

@testset "Testing inference.jl with $adbackend" for adbackend in ADUtils.adbackends
# Only test threading if 1.3+.
if VERSION > v"1.2"
@testset "threaded sampling" begin
# Test that chains with the same seed will sample identically.
@testset "rng" begin
model = gdemo_default

# multithreaded sampling with PG causes segfaults on Julia 1.5.4
# https://github.com/TuringLang/Turing.jl/issues/1571
samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0"
(
HMC(0.1, 7; adtype=adbackend),
PG(10),
IS(),
MH(),
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
else
(
HMC(0.1, 7; adtype=adbackend),
IS(),
MH(),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
end
for sampler in samplers
Random.seed!(5)
chain1 = sample(model, sampler, MCMCThreads(), 1000, 4)
@testset "threaded sampling" begin
# Test that chains with the same seed will sample identically.
@testset "rng" begin
model = gdemo_default

# multithreaded sampling with PG causes segfaults on Julia 1.5.4
# https://github.com/TuringLang/Turing.jl/issues/1571
samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0"
(
HMC(0.1, 7; adtype=adbackend),
PG(10),
IS(),
MH(),
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
else
(
HMC(0.1, 7; adtype=adbackend),
IS(),
MH(),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
end
for sampler in samplers
Random.seed!(5)
chain1 = sample(model, sampler, MCMCThreads(), 1000, 4)

Random.seed!(5)
chain2 = sample(model, sampler, MCMCThreads(), 1000, 4)
Random.seed!(5)
chain2 = sample(model, sampler, MCMCThreads(), 1000, 4)

@test chain1.value == chain2.value
end
@test chain1.value == chain2.value
end

# Should also be stable with am explicit RNG
seed = 5
rng = Random.MersenneTwister(seed)
for sampler in samplers
Random.seed!(rng, seed)
chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
# Should also be stable with am explicit RNG
seed = 5
rng = Random.MersenneTwister(seed)
for sampler in samplers
Random.seed!(rng, seed)
chain1 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)

Random.seed!(rng, seed)
chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)
Random.seed!(rng, seed)
chain2 = sample(rng, model, sampler, MCMCThreads(), 1000, 4)

@test chain1.value == chain2.value
end
@test chain1.value == chain2.value
end
end

# Smoke test for default sample call.
Random.seed!(100)
chain = sample(
gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4
)
check_gdemo(chain)
# Smoke test for default sample call.
Random.seed!(100)
chain = sample(
gdemo_default, HMC(0.1, 7; adtype=adbackend), MCMCThreads(), 1000, 4
)
check_gdemo(chain)

# run sampler: progress logging should be disabled and
# it should return a Chains object
sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default)
chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4)
@test chains isa MCMCChains.Chains
end
# run sampler: progress logging should be disabled and
# it should return a Chains object
sampler = Sampler(HMC(0.1, 7; adtype=adbackend), gdemo_default)
chains = sample(gdemo_default, sampler, MCMCThreads(), 1000, 4)
@test chains isa MCMCChains.Chains
end

@testset "chain save/resume" begin
Random.seed!(1234)

Expand Down

0 comments on commit a24140a

Please sign in to comment.