diff --git a/src/imfilter.jl b/src/imfilter.jl index 283352b..4def267 100644 --- a/src/imfilter.jl +++ b/src/imfilter.jl @@ -840,18 +840,38 @@ function _imfilter_fft!(r::AbstractCPU{FFT}, out end -function filtfft(A, krn) - B = rfft(A) - B .*= conj!(rfft(krn)) - irfft(B, length(axes(A, 1))) -end -function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant} +# NOTE: FFT followed by IFFT can be optimized using conjugate symmetry for real arrays +@inline _fft(A::AbstractArray{T}) where {T<:Real} = rfft(A) +@inline _fft(A::AbstractArray{T}) where {T<:Complex} = fft(A) +@inline _ifft(_::Type{<:Real}, A, d::Int) = irfft(A, d) +@inline _ifft(_::Type{<:Complex}, A, _::Int) = ifft(A) +# NOTE: If for one array, the optimization is used and not for the other, the two arrays do not have the same sizes +# which needs to be dealt with in the element-wise multiplication +@inline function _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, d::Int) + A_fft[1, :] .*= B_fft[1, :] + A_fft[2:(d÷2+1), 1] .*= B_fft[2:end, 1] + A_fft[(d÷2+2):end, 1] .*= conj(reverse(B_fft[2:(end-iseven(d)), 1])) + A_fft[2:(d÷2+1), 2:end] .*= B_fft[2:end, 2:end] + A_fft[(d÷2+2):end, 2:end] .*= conj(reverse(B_fft[2:(end-iseven(d)), 2:end])) + return A_fft +end +@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, d::Int) = _stretch_mul(BT, B_fft, AT, A_fft, d) +@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft +@inline _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft +function filtfft(A::AbstractArray{ST}, krn::AbstractArray{KT}) where {ST<:Union{Real,Complex},KT<:Union{Real,Complex}} + CT = promote_type(ST, KT) + B = _fft(A) + fft_out = _stretch_mul(ST, B, KT, conj!(_fft(krn)), length(axes(A, 1))) + _ifft(CT, fft_out, length(axes(A, 1))) +end + +function filtfft(A::AbstractArray{CT}, krn) where {CT<:Colorant} Av, dims = channelview_dims(A) - kernrs = kreshape(C, krn) + kernrs = kreshape(CT, krn) B = rfft(Av, dims) B .*= conj!(rfft(kernrs, dims)) Avf = irfft(B, length(axes(Av, dims[1])), dims) - colorview(base_colorant_type(C){eltype(Avf)}, Avf) + colorview(base_colorant_type(CT){eltype(Avf)}, Avf) end channelview_dims(A::AbstractArray{C,N}) where {C<:Colorant,N} = channelview(A), ntuple(d -> d + 1, Val(N)) channelview_dims(A::AbstractArray{C,N}) where {C<:ImageCore.Color1,N} = channelview(A), ntuple(identity, Val(N)) diff --git a/test/2d.jl b/test/2d.jl index 37d8fd2..b346250 100644 --- a/test/2d.jl +++ b/test/2d.jl @@ -337,7 +337,22 @@ end @test_throws err imfilter(CPU1(), A, kern, Fill(0, (3,))) kernf = ImageFiltering.factorkernel(kern) err = DimensionMismatch("output indices (OffsetArrays.IdOffsetRange(values=0:9, indices=0:9), OffsetArrays.IdOffsetRange(values=1:8, indices=1:8)) disagree with requested indices (1:8, 0:9)") - @test_throws err imfilter(CPU1(), A, kern, Fill(0, (1,0))) - @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,1))) - @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,0))) + @test_throws err imfilter(CPU1(), A, kern, Fill(0, (1, 0))) + @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 1))) + @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 0))) +end + +@testset "Complex FFT" begin + + A = rand(10, 10) + B = rand(10, 10) + @test filtfft(A, B) ≈ filtfft(ComplexF32.(A), B) + @test filtfft(A, B) ≈ filtfft(A, ComplexF32.(B)) + @test filtfft(A, B) ≈ filtfft(ComplexF32.(A), ComplexF32.(B)) + + C = rand(9, 9) + D = rand(9, 9) + @test filtfft(C, D) ≈ filtfft(ComplexF32.(C), D) + @test filtfft(C, D) ≈ filtfft(C, ComplexF32.(D)) + @test filtfft(C, D) ≈ filtfft(ComplexF32.(C), ComplexF32.(D)) end