Skip to content

Commit

Permalink
WIP: ForwardDiff extension
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Dec 13, 2024
1 parent b6f41d8 commit e06611d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
18 changes: 18 additions & 0 deletions ext/FFTWForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module FFTWForwardDiffExt
# AbstractFFTs.complexfloat(x::AbstractArray{<:Dual}) = float.(x .+ 0im)


plan_r2r(x::AbstractArray{<:Dual}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)
plan_r2r(x::AbstractArray{<:Complex{<:Dual}}, FLAG, dims=1:ndims(x)) = plan_r2r(dual2array(x), FLAG, 1 .+ dims)

Check warning on line 6 in ext/FFTWForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWForwardDiffExt.jl#L5-L6

Added lines #L5 - L6 were not covered by tests

for plan in (:plan_irfft, :plan_brfft) # these take an extra argument, only when complex?
@eval begin
$plan(x::AbstractArray{<:Dual}, dims=1:ndims(x)) = $plan(dual2array(x), 1 .+ dims)
$plan(x::AbstractArray{<:Complex{<:Dual}}, d::Integer, dims=1:ndims(x)) = $plan(dual2array(x), d, 1 .+ dims)

Check warning on line 11 in ext/FFTWForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWForwardDiffExt.jl#L10-L11

Added lines #L10 - L11 were not covered by tests
end
end

r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x

Check warning on line 16 in ext/FFTWForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWForwardDiffExt.jl#L15-L16

Added lines #L15 - L16 were not covered by tests

end #module
18 changes: 18 additions & 0 deletions test/fftwforwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testset "r2r" begin
x1 = Dual.(1:4.0, 2:5, 3:6)
t = FFTW.r2r(x1, FFTW.R2HC)

@test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC)

t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC)
@test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)

f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
@test derivative(f, 0.1) 1.0

@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)
end

0 comments on commit e06611d

Please sign in to comment.