Skip to content

Commit

Permalink
feat: Allow Complex valued FFT (#275)
Browse files Browse the repository at this point in the history
* feat: Allow Complex valued FFT

When using the FFT algorithm, previously there was only a method that
used the `rfft` and `irfft` methods which rely on getting a Real array
input. Added a method and specialization that uses the the old method
for Real valued arrays and a new method that calls `fft` and `ifft` for
other arrays.

* feat: `rfft` optimization for `FFT` filtering

Adds an optimization for filtering a complex with a real array so that
it allows the real array to use `rfft` instead of `fft` to save half the
computation. The benchmarks do not show a speed-up of a 1/4 which would
theoretically be expected but a lower speed-up. For some reason do not
show lower memory allocation.

[fixes: #275]
  • Loading branch information
kunzaatko authored Dec 5, 2024
1 parent 66cf9d9 commit ab90e57
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
36 changes: 28 additions & 8 deletions src/imfilter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -840,18 +840,38 @@ function _imfilter_fft!(r::AbstractCPU{FFT},
out
end

function filtfft(A, krn)
B = rfft(A)
B .*= conj!(rfft(krn))
irfft(B, length(axes(A, 1)))
end
function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant}
# NOTE: FFT followed by IFFT can be optimized using conjugate symmetry for real arrays
@inline _fft(A::AbstractArray{T}) where {T<:Real} = rfft(A)
@inline _fft(A::AbstractArray{T}) where {T<:Complex} = fft(A)
@inline _ifft(_::Type{<:Real}, A, d::Int) = irfft(A, d)
@inline _ifft(_::Type{<:Complex}, A, _::Int) = ifft(A)
# NOTE: If for one array, the optimization is used and not for the other, the two arrays do not have the same sizes
# which needs to be dealt with in the element-wise multiplication
@inline function _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, d::Int)
A_fft[1, :] .*= B_fft[1, :]
A_fft[2:(d÷2+1), 1] .*= B_fft[2:end, 1]
A_fft[(d÷2+2):end, 1] .*= conj(reverse(B_fft[2:(end-iseven(d)), 1]))
A_fft[2:(d÷2+1), 2:end] .*= B_fft[2:end, 2:end]
A_fft[(d÷2+2):end, 2:end] .*= conj(reverse(B_fft[2:(end-iseven(d)), 2:end]))
return A_fft
end
@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, d::Int) = _stretch_mul(BT, B_fft, AT, A_fft, d)
@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft
@inline _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft
function filtfft(A::AbstractArray{ST}, krn::AbstractArray{KT}) where {ST<:Union{Real,Complex},KT<:Union{Real,Complex}}
CT = promote_type(ST, KT)
B = _fft(A)
fft_out = _stretch_mul(ST, B, KT, conj!(_fft(krn)), length(axes(A, 1)))
_ifft(CT, fft_out, length(axes(A, 1)))
end

function filtfft(A::AbstractArray{CT}, krn) where {CT<:Colorant}
Av, dims = channelview_dims(A)
kernrs = kreshape(C, krn)
kernrs = kreshape(CT, krn)
B = rfft(Av, dims)
B .*= conj!(rfft(kernrs, dims))
Avf = irfft(B, length(axes(Av, dims[1])), dims)
colorview(base_colorant_type(C){eltype(Avf)}, Avf)
colorview(base_colorant_type(CT){eltype(Avf)}, Avf)
end
channelview_dims(A::AbstractArray{C,N}) where {C<:Colorant,N} = channelview(A), ntuple(d -> d + 1, Val(N))
channelview_dims(A::AbstractArray{C,N}) where {C<:ImageCore.Color1,N} = channelview(A), ntuple(identity, Val(N))
Expand Down
21 changes: 18 additions & 3 deletions test/2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,22 @@ end
@test_throws err imfilter(CPU1(), A, kern, Fill(0, (3,)))
kernf = ImageFiltering.factorkernel(kern)
err = DimensionMismatch("output indices (OffsetArrays.IdOffsetRange(values=0:9, indices=0:9), OffsetArrays.IdOffsetRange(values=1:8, indices=1:8)) disagree with requested indices (1:8, 0:9)")
@test_throws err imfilter(CPU1(), A, kern, Fill(0, (1,0)))
@test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,1)))
@test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,0)))
@test_throws err imfilter(CPU1(), A, kern, Fill(0, (1, 0)))
@test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 1)))
@test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 0)))
end

@testset "Complex FFT" begin

A = rand(10, 10)
B = rand(10, 10)
@test filtfft(A, B) filtfft(ComplexF32.(A), B)
@test filtfft(A, B) filtfft(A, ComplexF32.(B))
@test filtfft(A, B) filtfft(ComplexF32.(A), ComplexF32.(B))

C = rand(9, 9)
D = rand(9, 9)
@test filtfft(C, D) filtfft(ComplexF32.(C), D)
@test filtfft(C, D) filtfft(C, ComplexF32.(D))
@test filtfft(C, D) filtfft(ComplexF32.(C), ComplexF32.(D))
end

0 comments on commit ab90e57

Please sign in to comment.