-
-
Notifications
You must be signed in to change notification settings - Fork 212
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use ForwardDiff.jacobian in place of Zygote.forward_jacobian #1468
base: master
Are you sure you want to change the base?
Conversation
There are some enzyme related errors in NNlib integration tests but they seem unrelated to this PR. |
@ToucheSir, LMK I need to add more tests. here's a working MWE with Lux. This also resolves #1348 with the change in this PR, this code is working: using Random
using Lux, CUDA, LuxCUDA, ComponentArrays
using Zygote, ForwardDiff
CUDA.allowscalar(false)
#==========================#
function testhessian(
NN::Lux.AbstractExplicitLayer,
data::Tuple;
device = cpu_device(),
)
p, st = Lux.setup(Random.default_rng(), NN)
st = Lux.testmode(st)
p = ComponentArray(p)
xdata, ydata = data |> device
p, st = (p, st) |> device
function loss(optx)
ypred, _ = NN(xdata, optx, st)
sum(abs2, ydata - ypred)
end
g(p) = Zygote.gradient(loss, p)[1]
H(p) = ForwardDiff.jacobian(g, p)
Zygote.hessian(loss, p)
end
#==========================#
NN = Chain(Dense(1, 3), Dense(3, 1))
data = ntuple(_ -> rand(1, 10), 2)
device = Lux.gpu_device()
H = testhessian(NN, data; device) julia> include("hess.jl")
10×10 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
0.236781 -0.075257 -1.20583 0.31846 -0.101217 -1.62179 -0.713834 0.503548 -1.14138 1.98508
-0.075257 0.0239192 0.383253 -0.101217 0.0321702 0.515458 0.0296168 -0.780695 0.362769 -0.630924
-1.20583 0.383253 6.1408 -1.62179 0.515458 8.2591 0.474545 -2.56436 5.19194 -10.1092
0.318461 -0.101217 -1.62179 0.514738 -0.163601 -2.62135 -2.09317 0.677249 -1.53511 3.20854
-0.101217 0.0321702 0.515458 -0.163601 0.0519977 0.833151 0.0398333 -2.18309 0.487909 -1.01978
-1.62179 0.515458 8.2591 -2.62135 0.833151 13.3494 0.638242 -3.44895 5.84984 -16.3398
-0.713834 0.0296168 0.474545 -2.09317 0.0398333 0.638242 0.0366717 -0.198167 0.449183 -0.781213
0.503548 -0.780695 -2.56436 0.677249 -2.18309 -3.44895 -0.198167 1.07086 -2.4273 4.22154
-1.14138 0.362769 5.19194 -1.53511 0.487909 5.84984 0.449183 -2.4273 5.50193 -9.56889
1.98508 -0.630924 -10.1092 3.20854 -1.01978 -16.3398 -0.781213 4.22154 -9.56889 20.0
(hess) pkg> st
Status `~/.julia/dev/GeometryLearning.jl/hess/Project.toml`
[052768ef] CUDA v5.0.0
[b0b7db55] ComponentArrays v0.15.4
[f6369f11] ForwardDiff v0.10.36
[b2108857] Lux v0.5.8
[d0bbae9a] LuxCUDA v0.3.1
[e88e6eb3] Zygote v0.6.67 `~/.julia/dev/Zygote` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a little confused. This looks like the same change as #1270, just with no tests? My comment at #1270 (comment) and @mcabbott's at #1270 (comment) still very much apply, so those need to be addressed.
Did this ever reach a conclusion? I'm in need of the ability to take the jacobian with respect to the inputs of a (Lux) model output and then optimize that object using gradient descent updates on the (Lux) model parameters. Something like the following
Or should I be looking towards JAX for this sort of thing? The use case is thermodynamics. |
That's a better question for the SciML/Lux help channels, not this issue tracker. |
This PR changes the implementation used internally for FwdDiff-over-Zygote. It didn't get much attention as it was a little unclear what this solves -- see requests above for tests which fail before the change. Your example wants to do Zygote-over-ForwardDiff, which won't work, and would not be changed by this PR. (Zygote has a rule for |
Pursuant to #1270