Enzyme fails with MultiHeadAttention layer #2448
The same code run on CUDA:
@btime CUDA.@sync Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end
  11.983 ms (2583 allocations: 137.55 KiB)
whereas
@btime CUDA.@sync gradient_ez(mha) do m
    sum(first(m($x, $x, $x)))
end
....
[2] EnzymeCreateAugmentedPrimal(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnUsed::Bool, shadowReturnUsed::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, forceAnonymousTape::Bool, width::Int64, atomicAdd::Bool)
@ Enzyme.API ~/.julia/packages/Enzyme/srACB/src/api.jl:190
[3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::NTuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:3141
[4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5074
[5] codegen
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:4481 [inlined]
[6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771
[7] _thunk
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5771 [inlined]
[8] cached_compilation
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5809 [inlined]
[9] (::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{4, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5875
[10] JuliaContext(f::Enzyme.Compiler.var"#560#561"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, NTuple{…}, Int64, Bool, Bool, UInt64, DataType}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
[11] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
[12] #s2027#559
@ ~/.julia/packages/Enzyme/srACB/src/compiler.jl:5827 [inlined]
So it can be reproduced with the following packages and Julia 1.10.3
Thanks! |
@mashu can you post the whole log? |
I was convinced I had attached it earlier, but apparently I didn't, so here it is:
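using Enzyme
using Flux
using CUDA

# Zero out numbers and arrays; leave every other leaf (e.g. functions) untouched.
_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
# Build a zeroed copy of the model to use as the Enzyme shadow.
make_zero(model) = fmap(_make_zero, model)

# Enzyme counterpart of Flux.gradient: returns one gradient per argument in xs.
function gradient_ez(f, xs...)
    args = []
    for x in xs
        if x isa Number
            # Scalars are differentiated as Active values.
            push!(args, Active(x))
        else
            # Everything else gets a zeroed shadow that accumulates the gradient.
            push!(args, Duplicated(x, make_zero(x)))
        end
    end
    ret = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
    g = ntuple(i -> xs[i] isa Number ? ret[1][i] : args[i].dval, length(xs))
    return g
end

x = CUDA.rand(Float32, 64, 100, 512)
mha = MultiHeadAttention(64 => 64 => 64) |> gpu

# Zygote (Flux default)
Flux.gradient(mha) do m
    sum(first(m(x, x, x)))
end

# Enzyme
Δ = gradient_ez(mha) do m
    sum(first(m(x, x, x)))
end |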
Also, does this work on CPU? |
@wsmoses Initially I got a compilation error with the CPU version, but after moving the code to a separate project (MWE) it only fails on GPU. Having said that, I still can't figure out why it fails in my main project, as the packages are up to date and at basically the same versions. But this GPU failure is at least reproducible. |
GPU support is in progress, so the report is super helpful, but the failure is also presently expected. Maybe check the current versions of the packages in your project and see if something is forcing an older Enzyme? |
It's the same version of |
Ah, but what's your Enzyme version (rather than Enzyme_jll, which is a dependency)? |
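A minimal sketch of one way to compare the resolved versions in the working and broken projects, assuming only the standard `Pkg.dependencies()` API. The helper name `report_versions` and the package list are illustrative, not something from this thread.

```julia
using Pkg

# Hypothetical helper: collect the resolved version of each named package
# in the currently active project (direct and indirect dependencies).
function report_versions(wanted)
    versions = Dict{String,Any}()
    for (_, info) in Pkg.dependencies()
        if info.name in wanted
            versions[info.name] = info.version
        end
    end
    return versions
end

wanted = ["Enzyme", "Enzyme_jll", "Flux", "CUDA", "NNlib"]
versions = report_versions(wanted)
for name in wanted
    # Packages absent from the active project are reported as such.
    println(rpad(name, 12), get(versions, name, "not in this project"))
end
```

Running this once in each project (after `Pkg.activate` on the respective directory) gives two lists that can be compared line by line. In the Pkg REPL, `]st --outdated` should additionally show which packages are being held back, on Julia versions that support that flag.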
Looks the same, v0.12.6.
Working MWE ]st
Broken one ]st
|
I am also including the log with the error that happens on the CPU side in the broken project; not sure if that helps, though. |
From the log, I think the simplest answer here is that we should just add the attention custom derivative in NNlib. I assume there's one already for ChainRules (CR)? If so, you can try our macro for importing a CR rule into Enzyme as a test to see if anything else fails, while in the interim we can look at making a fast rule (CR rules will be slower and come with caveats). |
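A minimal sketch of what importing a CR rule could look like, assuming the `Enzyme.@import_rrule` macro that Enzyme provides when ChainRulesCore is loaded. `double_it` and its `rrule` are hypothetical stand-ins for illustration only; no NNlib attention function is used, and whether the import behaves well for a real attention rule is untested here.

```julia
using Enzyme
using ChainRulesCore

# Hypothetical array-to-array function standing in for a real kernel.
double_it(x::AbstractVector) = 2 .* x

# Hand-written ChainRules reverse rule for the stand-in function.
function ChainRulesCore.rrule(::typeof(double_it), x::AbstractVector)
    y = double_it(x)
    double_it_pullback(dy) = (NoTangent(), 2 .* unthunk(dy))
    return y, double_it_pullback
end

# Ask Enzyme to reuse the rrule above as a custom reverse-mode rule
# for calls with a Vector{Float64} argument.
Enzyme.@import_rrule(typeof(double_it), Vector{Float64})

# Scalar loss that calls the function covered by the imported rule.
loss(x) = sum(double_it(x))

x = [1.0, 2.0, 3.0]
dx = zero(x)  # shadow that accumulates the gradient
Enzyme.autodiff(Reverse, loss, Active, Duplicated(x, dx))
println(dx)
```

The example keeps both the input and the output of `double_it` as arrays, since the macro's documentation describes it as intended for functions that return array-like (non-scalar) results and warns that imported rules may be slower than native ones.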
@wsmoses Long story short, I wanted to use Enzyme because I often lack the skills to write an rrule, and there is none for MultiHeadAttention in NNlib. The longer answer is that I am currently using NeuralAttentionlib.jl, which is part of Transformers.jl; it has the layer customization I need and an rrule that makes that variant of MHA a couple of times faster on GPU. My hope was that Enzyme might do a better job than Zygote when it comes to the performance of the code it produces (when no rrule is provided). |
If you can wait a short bit (it's currently unregistered and there's a bunch of small things we should add), Reactant.jl is an execution engine (e.g. it does tons of fancy optimizations/kernel fusion), is both Enzyme and GPU compatible out of the box, and might be what you're looking for. In the interim I'll push on the GPU support for native Enzyme here too, but just throwing that out there in case it's helpful. |
I am attaching an MWE where Zygote (the default in Flux) works fine but Enzyme fails during compilation (@wsmoses).