Skip to content

Commit

Permalink
fix: nested ad when using direct eval in function (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Jul 3, 2024
1 parent b319371 commit 91e74b1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.5.58"
version = "0.5.59"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion src/helpers/stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ end
function CRC.rrule(::typeof(getproperty), s::StatefulLuxLayer, name::Symbol)
y = getproperty(s, name)
∇getproperty = @closure Δ -> begin
name === :ps && return NoTangent(), (; ps=Δ), NoTangent()
name === :ps && return NoTangent(), CRC.Tangent{typeof(s)}(; ps=Δ), NoTangent()
return NoTangent(), NoTangent(), NoTangent()
end
return y, ∇getproperty
Expand Down
25 changes: 25 additions & 0 deletions test/helpers/nestedad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,28 @@ end

@test_nowarn jacobian_vector_product(ftest, AutoForwardDiff(), nt, u)
end

@testitem "Nested AD: Issue #743 (eval + gradient)" setup=[SharedTestSetup] tags=[:autodiff] begin
using Zygote, Optimisers, Random, ForwardDiff, ComponentArrays

function loss_function(model, ps, st, x)
smodel = StatefulLuxLayer{true}(model, ps, st)
y_pred = smodel(x)
dy_pred = only(Zygote.gradient(sum smodel, x))
loss = sum(dy_pred .+ y_pred .^ 2 / 2)
return loss
end

rng = StableRNG(1234)
model = Chain(Dense(1 => 8, sigmoid), Dense(8 => 1))
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 1, 12)

_, ∂ps, _, ∂x = Zygote.gradient(loss_function, model, ps, st, x)

∂ps_fd = ForwardDiff.gradient(ps -> loss_function(model, ps, st, x), ComponentArray(ps))
∂x_fd = ForwardDiff.gradient(x -> loss_function(model, ps, st, x), x)

@test ComponentArray(∂ps)∂ps_fd rtol=1e-3 atol=1e-3
@test ∂x∂x_fd rtol=1e-3 atol=1e-3
end

1 comment on commit 91e74b1

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 91e74b1 Previous: 0a3aaa8 Ratio
Dense(2 => 2)/cpu/reverse/ReverseDiff (compiled)/(2, 128) 3685.75 ns 3656.875 ns 1.01
Dense(2 => 2)/cpu/reverse/Zygote/(2, 128) 7148.5 ns 7169.4 ns 1.00
Dense(2 => 2)/cpu/reverse/Tracker/(2, 128) 20699 ns 21440 ns 0.97
Dense(2 => 2)/cpu/reverse/ReverseDiff/(2, 128) 9712.2 ns 9774.2 ns 0.99
Dense(2 => 2)/cpu/reverse/Flux/(2, 128) 9162.25 ns 8976.8 ns 1.02
Dense(2 => 2)/cpu/reverse/SimpleChains/(2, 128) 4474.75 ns 4495.875 ns 1.00
Dense(2 => 2)/cpu/reverse/Enzyme/(2, 128) 1154.0357142857142 ns 1164.6814814814816 ns 0.99
Dense(2 => 2)/cpu/forward/NamedTuple/(2, 128) 1125.087837837838 ns 1121.9607843137255 ns 1.00
Dense(2 => 2)/cpu/forward/ComponentArray/(2, 128) 1191.7015503875969 ns 1179.7481481481482 ns 1.01
Dense(2 => 2)/cpu/forward/Flux/(2, 128) 1810.909090909091 ns 1782 ns 1.02
Dense(2 => 2)/cpu/forward/SimpleChains/(2, 128) 180.22708039492244 ns 180.18014184397163 ns 1.00
Dense(20 => 20)/cpu/reverse/ReverseDiff (compiled)/(20, 128) 17383 ns 17202 ns 1.01
Dense(20 => 20)/cpu/reverse/Zygote/(20, 128) 16872 ns 17022 ns 0.99
Dense(20 => 20)/cpu/reverse/Tracker/(20, 128) 39133.5 ns 37590 ns 1.04
Dense(20 => 20)/cpu/reverse/ReverseDiff/(20, 128) 29375 ns 29355 ns 1.00
Dense(20 => 20)/cpu/reverse/Flux/(20, 128) 20198 ns 19917 ns 1.01
Dense(20 => 20)/cpu/reverse/SimpleChains/(20, 128) 17558 ns 17322 ns 1.01
Dense(20 => 20)/cpu/reverse/Enzyme/(20, 128) 4326.571428571428 ns 4358.142857142857 ns 0.99
Dense(20 => 20)/cpu/forward/NamedTuple/(20, 128) 3836 ns 3886 ns 0.99
Dense(20 => 20)/cpu/forward/ComponentArray/(20, 128) 3948.625 ns 3954.875 ns 1.00
Dense(20 => 20)/cpu/forward/Flux/(20, 128) 4900.714285714285 ns 4926.285714285715 ns 0.99
Dense(20 => 20)/cpu/forward/SimpleChains/(20, 128) 1667.1 ns 1658.1 ns 1.01
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 3, 128) 41632816 ns 39601602.5 ns 1.05
Conv((3, 3), 3 => 3)/cpu/reverse/Zygote/(64, 64, 3, 128) 57885427 ns 58038779 ns 1.00
Conv((3, 3), 3 => 3)/cpu/reverse/Tracker/(64, 64, 3, 128) 78862627 ns 76782216 ns 1.03
Conv((3, 3), 3 => 3)/cpu/reverse/ReverseDiff/(64, 64, 3, 128) 84887285 ns 87946168 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Flux/(64, 64, 3, 128) 78953692 ns 73096588.5 ns 1.08
Conv((3, 3), 3 => 3)/cpu/reverse/SimpleChains/(64, 64, 3, 128) 11935490 ns 12291038 ns 0.97
Conv((3, 3), 3 => 3)/cpu/reverse/Enzyme/(64, 64, 3, 128) 18024088 ns 17932809 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/NamedTuple/(64, 64, 3, 128) 7085466 ns 7036947 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/ComponentArray/(64, 64, 3, 128) 7046918 ns 6997914 ns 1.01
Conv((3, 3), 3 => 3)/cpu/forward/Flux/(64, 64, 3, 128) 11692974 ns 10260008 ns 1.14
Conv((3, 3), 3 => 3)/cpu/forward/SimpleChains/(64, 64, 3, 128) 6448135 ns 6396686 ns 1.01
vgg16/cpu/reverse/Zygote/(32, 32, 3, 16) 741248602 ns 745149713 ns 0.99
vgg16/cpu/reverse/Zygote/(32, 32, 3, 64) 2567595466 ns 2558261206 ns 1.00
vgg16/cpu/reverse/Zygote/(32, 32, 3, 2) 147817513.5 ns 144356291 ns 1.02
vgg16/cpu/reverse/Tracker/(32, 32, 3, 16) 916051072.5 ns 790487012 ns 1.16
vgg16/cpu/reverse/Tracker/(32, 32, 3, 64) 3341104801 ns 2945619535 ns 1.13
vgg16/cpu/reverse/Tracker/(32, 32, 3, 2) 218429947.5 ns 202413420.5 ns 1.08
vgg16/cpu/reverse/Flux/(32, 32, 3, 16) 781576790 ns 649182419 ns 1.20
vgg16/cpu/reverse/Flux/(32, 32, 3, 64) 2857754077 ns 2806215597.5 ns 1.02
vgg16/cpu/reverse/Flux/(32, 32, 3, 2) 127855610 ns 123733562 ns 1.03
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 16) 173040428.5 ns 173997599 ns 0.99
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 64) 647207749.5 ns 647829882.5 ns 1.00
vgg16/cpu/forward/NamedTuple/(32, 32, 3, 2) 45528949 ns 34347535 ns 1.33
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 16) 165399408.5 ns 164824067 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 64) 644554490 ns 644089958 ns 1.00
vgg16/cpu/forward/ComponentArray/(32, 32, 3, 2) 29943548 ns 30272455.5 ns 0.99
vgg16/cpu/forward/Flux/(32, 32, 3, 16) 213344862 ns 185792984 ns 1.15
vgg16/cpu/forward/Flux/(32, 32, 3, 64) 769117397.5 ns 757568425 ns 1.02
vgg16/cpu/forward/Flux/(32, 32, 3, 2) 36771569 ns 35418144 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 64, 128) 1243936673.5 ns 1212967132 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/Zygote/(64, 64, 64, 128) 1888739190 ns 1867896074 ns 1.01
Conv((3, 3), 64 => 64)/cpu/reverse/Tracker/(64, 64, 64, 128) 2444598607 ns 2380054521 ns 1.03
Conv((3, 3), 64 => 64)/cpu/reverse/ReverseDiff/(64, 64, 64, 128) 2567167786 ns 2478572600 ns 1.04
Conv((3, 3), 64 => 64)/cpu/reverse/Flux/(64, 64, 64, 128) 1870827117 ns 1781654425.5 ns 1.05
Conv((3, 3), 64 => 64)/cpu/reverse/Enzyme/(64, 64, 64, 128) 566986727 ns 566970305 ns 1.00
Conv((3, 3), 64 => 64)/cpu/forward/NamedTuple/(64, 64, 64, 128) 329803795 ns 325256461 ns 1.01
Conv((3, 3), 64 => 64)/cpu/forward/ComponentArray/(64, 64, 64, 128) 327658424 ns 322487937 ns 1.02
Conv((3, 3), 64 => 64)/cpu/forward/Flux/(64, 64, 64, 128) 461879142.5 ns 444562040 ns 1.04
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 1, 128) 12019081 ns 12022963.5 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Zygote/(64, 64, 1, 128) 17861815 ns 17911867 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Tracker/(64, 64, 1, 128) 19225164.5 ns 19149422 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/ReverseDiff/(64, 64, 1, 128) 23898726 ns 23910762 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/Flux/(64, 64, 1, 128) 17926150.5 ns 17963694 ns 1.00
Conv((3, 3), 1 => 1)/cpu/reverse/SimpleChains/(64, 64, 1, 128) 1170420 ns 1203966.5 ns 0.97
Conv((3, 3), 1 => 1)/cpu/reverse/Enzyme/(64, 64, 1, 128) 5897461 ns 5913573 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/NamedTuple/(64, 64, 1, 128) 2063880 ns 2052656 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/ComponentArray/(64, 64, 1, 128) 2048965 ns 2043938 ns 1.00
Conv((3, 3), 1 => 1)/cpu/forward/Flux/(64, 64, 1, 128) 2087170 ns 2075728 ns 1.01
Conv((3, 3), 1 => 1)/cpu/forward/SimpleChains/(64, 64, 1, 128) 209221 ns 212601 ns 0.98
Dense(200 => 200)/cpu/reverse/ReverseDiff (compiled)/(200, 128) 295356 ns 293096 ns 1.01
Dense(200 => 200)/cpu/reverse/Zygote/(200, 128) 268615 ns 265103 ns 1.01
Dense(200 => 200)/cpu/reverse/Tracker/(200, 128) 368925 ns 363096 ns 1.02
Dense(200 => 200)/cpu/reverse/ReverseDiff/(200, 128) 410503 ns 407228.5 ns 1.01
Dense(200 => 200)/cpu/reverse/Flux/(200, 128) 275723 ns 274961.5 ns 1.00
Dense(200 => 200)/cpu/reverse/SimpleChains/(200, 128) 409250 ns 411922 ns 0.99
Dense(200 => 200)/cpu/reverse/Enzyme/(200, 128) 83457.5 ns 83845 ns 1.00
Dense(200 => 200)/cpu/forward/NamedTuple/(200, 128) 81958 ns 81722 ns 1.00
Dense(200 => 200)/cpu/forward/ComponentArray/(200, 128) 83466 ns 81672 ns 1.02
Dense(200 => 200)/cpu/forward/Flux/(200, 128) 87272 ns 86912 ns 1.00
Dense(200 => 200)/cpu/forward/SimpleChains/(200, 128) 104607 ns 104544 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff (compiled)/(64, 64, 16, 128) 201257753 ns 189224051 ns 1.06
Conv((3, 3), 16 => 16)/cpu/reverse/Zygote/(64, 64, 16, 128) 327080609 ns 326445067 ns 1.00
Conv((3, 3), 16 => 16)/cpu/reverse/Tracker/(64, 64, 16, 128) 425910289.5 ns 393391322 ns 1.08
Conv((3, 3), 16 => 16)/cpu/reverse/ReverseDiff/(64, 64, 16, 128) 475827775.5 ns 458551247 ns 1.04
Conv((3, 3), 16 => 16)/cpu/reverse/Flux/(64, 64, 16, 128) 387637862 ns 370604679 ns 1.05
Conv((3, 3), 16 => 16)/cpu/reverse/SimpleChains/(64, 64, 16, 128) 337223270 ns 351721620 ns 0.96
Conv((3, 3), 16 => 16)/cpu/reverse/Enzyme/(64, 64, 16, 128) 101307461.5 ns 100845783.5 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/NamedTuple/(64, 64, 16, 128) 43845028.5 ns 43799088 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/ComponentArray/(64, 64, 16, 128) 43893493 ns 43801954 ns 1.00
Conv((3, 3), 16 => 16)/cpu/forward/Flux/(64, 64, 16, 128) 56749581 ns 59997614 ns 0.95
Conv((3, 3), 16 => 16)/cpu/forward/SimpleChains/(64, 64, 16, 128) 29066458 ns 29600942 ns 0.98
Dense(2000 => 2000)/cpu/reverse/ReverseDiff (compiled)/(2000, 128) 19174288 ns 19020152 ns 1.01
Dense(2000 => 2000)/cpu/reverse/Zygote/(2000, 128) 19690891 ns 19737206 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Tracker/(2000, 128) 23502746 ns 23831711 ns 0.99
Dense(2000 => 2000)/cpu/reverse/ReverseDiff/(2000, 128) 24277302 ns 24363296 ns 1.00
Dense(2000 => 2000)/cpu/reverse/Flux/(2000, 128) 19749481 ns 19887751 ns 0.99
Dense(2000 => 2000)/cpu/reverse/Enzyme/(2000, 128) 6575644 ns 6645866 ns 0.99
Dense(2000 => 2000)/cpu/forward/NamedTuple/(2000, 128) 6558677 ns 6603096 ns 0.99
Dense(2000 => 2000)/cpu/forward/ComponentArray/(2000, 128) 6573677.5 ns 6599702 ns 1.00
Dense(2000 => 2000)/cpu/forward/Flux/(2000, 128) 6685988 ns 6597575 ns 1.01

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.