From 5417510c886106d238120416cd3d2b79dd2eea11 Mon Sep 17 00:00:00 2001 From: Martin Kunz Date: Wed, 21 Aug 2024 22:47:47 +0200 Subject: [PATCH 1/2] feat: Allow Complex valued FFT Previously, the FFT algorithm only had a method that used `rfft` and `irfft`, which rely on getting a Real-valued array as input. Added a specialization that uses the old method for Real-valued arrays and a new generic method that calls `fft` and `ifft` for other arrays. --- src/imfilter.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/imfilter.jl b/src/imfilter.jl index 283352b..1ddd713 100644 --- a/src/imfilter.jl +++ b/src/imfilter.jl @@ -841,6 +841,11 @@ function _imfilter_fft!(r::AbstractCPU{FFT}, end function filtfft(A, krn) + B = fft(A) + B .*= conj!(fft(krn)) + ifft(B) +end +function filtfft(A::AbstractArray{AT}, krn::AbstractArray{KT}) where {AT<:Real,KT<:Real} B = rfft(A) B .*= conj!(rfft(krn)) irfft(B, length(axes(A, 1))) From 20b406a3d20ae70c50d731bc1526243afe261311 Mon Sep 17 00:00:00 2001 From: Martin Kunz Date: Sun, 20 Oct 2024 15:10:45 +0200 Subject: [PATCH 2/2] feat: `rfft` optimization for `FFT` filtering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an optimization for filtering a complex array with a real array so that the real array can use `rfft` instead of `fft`, saving half the computation. The benchmarks do not show the speed-up of 1/4 that would theoretically be expected, but a lower one. For some reason, they also do not show lower memory allocation. 
``` In [2]: A = rand(3000, 3000) 3000×3000 Matrix{Float64}: 0.285528 0.708856 0.620611 0.045052 0.330913 0.213895 0.873503 0.656959 0.0861678 0.938603 … 0.463651 0.491763 0.523761 0.987486 0.707374 0.118724 0.523314 0.278283 0.040879 0.546583 0.797982 0.419506 0.16004 0.134307 0.341135 0.767713 0.694436 0.465751 0.081541 0.636672 0.493196 0.212694 0.726784 0.137611 0.153863 0.102901 0.623923 0.285838 0.667936 0.985847 0.316288 0.496242 0.330982 0.802932 0.759763 0.393804 0.27839 0.613204 0.917279 0.488461 0.208559 0.139706 0.013461 0.29185 0.893805 0.198656 0.260889 0.541975 0.79197 0.593314 0.563877 0.668819 0.835136 0.814535 0.780967 0.402207 0.309717 0.535379 0.800072 0.261613 0.288121 0.743066 0.474665 0.104051 0.910556 0.660157 0.188497 0.434635 0.464992 0.156539 0.712517 0.229329 0.0900374 0.157132 0.0464563 0.94832 0.594556 0.239523 0.0353333 0.628295 0.974684 0.293648 0.767198 0.00595716 0.917384 ⋮ ⋮ ⋱ ⋮ 0.342483 0.849301 0.522493 0.147907 0.657853 0.316614 0.330971 0.960895 0.790472 0.570171 … 0.805825 0.575838 0.92688 0.545242 0.849474 0.553651 0.347863 0.284356 0.0710359 0.470919 0.370001 0.17453 0.757803 0.47875 0.514264 0.965828 0.161526 0.490146 0.586126 0.886021 0.426073 0.992352 0.686143 0.277794 0.93995 0.311383 0.394133 0.925497 0.159736 0.706443 0.0387112 0.219725 0.604491 0.827922 0.0177041 0.993723 0.722314 0.794828 0.661236 0.690974 0.24774 0.0865962 0.955458 0.542212 0.690087 0.779484 0.721536 0.0613569 0.584916 0.22701 0.0871801 0.285467 0.214378 0.016846 0.0159929 0.0941595 0.814224 0.194526 0.0966802 0.885505 0.315272 0.896789 0.214747 0.0933588 0.803658 0.951828 0.68916 0.891683 0.223701 0.573007 0.359542 0.151604 0.410088 0.298216 0.371781 0.609701 0.358088 0.0729099 0.142997 0.439933 0.161952 0.350227 0.326087 0.477912 0.386247 In [3]: B = rand(ComplexF32, 3000, 3000) 3000×3000 Matrix{ComplexF32}: 0.314435+0.350523im 0.31276+0.000674367im 0.403332+0.0267981im 0.836204+0.5429im … 0.352876+0.946266im 0.938128+0.477059im 
0.958154+0.541663im 0.658549+0.493835im 0.208547+0.811982im 0.506163+0.888442im 0.720267+0.775364im 0.0987085+0.34869im 0.114646+0.363576im 0.0707946+0.951316im 0.897746+0.781405im 0.446996+0.679693im 0.166268+0.781975im 0.51661+0.284825im 0.488436+0.997122im 0.39242+0.633143im 0.852163+0.575789im 0.95094+0.867025im 0.963788+0.0771892im 0.810548+0.851137im 0.596596+0.24115im 0.229569+0.74279im 0.716469+0.953966im 0.471434+0.197209im 0.0622576+0.935325im 0.745592+0.505512im 0.424859+0.444605im 0.727532+0.0781438im 0.225914+0.979836im 0.988875+0.0307086im 0.686103+0.611004im 0.0149557+0.135689im 0.0653017+0.159109im 0.404576+0.134711im 0.220633+0.814599im 0.0810269+0.77369im ⋮ ⋱ 0.999065+0.361406im 0.986685+0.90664im 0.512323+0.586943im 0.790515+0.239983im … 0.58594+0.0124457im 0.608344+0.886342im 0.213773+0.25945im 0.339364+0.416441im 0.458132+0.558931im 0.36812+0.744078im 0.0747138+0.843561im 0.244537+0.998928im 0.307395+0.996972im 0.220357+0.665658im 0.794098+0.536488im 0.945331+0.81305im 0.632005+0.272713im 0.521595+0.943606im 0.684861+0.998334im 0.0401222+0.158511im 0.47845+0.813752im 0.378089+0.538594im 0.345376+0.468067im 0.867263+0.361034im 0.785968+0.501483im 0.828497+0.626334im 0.603006+0.704717im 0.413903+0.0815281im 0.728322+0.529951im 0.227673+0.328137im 0.445953+0.994137im 0.953913+0.130831im 0.938755+0.593728im 0.404837+0.707017im 0.135038+0.498083im 0.874043+0.694357im 0.687566+0.768597im 0.455904+0.358583im 0.0127361+0.115687im 0.415703+0.949933im In [24]: function full_fft(A,B) A_fft = fft(A) B_fft = fft(B) A_fft .*= conj!(B) A_fft end full_fft (generic function with 1 method) In [30]: function real_fft(A,B) A_fft = rfft(A) B_fft = fft(B) _stretch_mul(Float64, A_fft, ComplexF32, conj!(B), length(axes(A,1))) end real_fft (generic function with 1 method) In [32]: @benchmark full_fft($A,$B) BenchmarkTools.Trial: 15 samples with 1 evaluation. 
Range (min … max): 330.713 ms … 346.502 ms ┊ GC (min … max): 0.72% … 0.98% Time (median): 337.403 ms ┊ GC (median): 0.94% Time (mean ± σ): 338.132 ms ± 4.592 ms ┊ GC (mean ± σ): 0.81% ± 0.34% ▁ ▁ ▁ █ ▁ ▁ ▁ ▁ ▁ █ ▁ ▁ ▁ █▁▁▁▁▁▁█▁█▁▁▁█▁▁▁▁▁▁▁█▁█▁█▁▁▁▁▁▁▁▁▁▁█▁█▁▁█▁▁▁█▁▁▁█▁▁▁▁▁▁▁▁▁▁█ ▁ 331 ms Histogram: frequency by time 347 ms < Memory estimate: 343.32 MiB, allocs estimate: 19. In [33]: @benchmark real_fft($A,$B) BenchmarkTools.Trial: 18 samples with 1 evaluation. Range (min … max): 261.506 ms … 467.978 ms ┊ GC (min … max): 0.00% … 43.72% Time (median): 267.245 ms ┊ GC (median): 1.90% Time (mean ± σ): 291.212 ms ± 57.510 ms ┊ GC (mean ± σ): 9.90% ± 12.93% █ ▅█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃ ▁ 262 ms Histogram: frequency by time 468 ms < Memory estimate: 480.61 MiB, allocs estimate: 62. ``` [fixes: #275] --- src/imfilter.jl | 41 ++++++++++++++++++++++++++++------------- test/2d.jl | 21 ++++++++++++++++++--- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/src/imfilter.jl b/src/imfilter.jl index 1ddd713..4def267 100644 --- a/src/imfilter.jl +++ b/src/imfilter.jl @@ -840,23 +840,38 @@ function _imfilter_fft!(r::AbstractCPU{FFT}, out end -function filtfft(A, krn) - B = fft(A) - B .*= conj!(fft(krn)) - ifft(B) -end -function filtfft(A::AbstractArray{AT}, krn::AbstractArray{KT}) where {AT<:Real,KT<:Real} - B = rfft(A) - B .*= conj!(rfft(krn)) - irfft(B, length(axes(A, 1))) -end -function filtfft(A::AbstractArray{C}, krn) where {C<:Colorant} +# NOTE: FFT followed by IFFT can be optimized using conjugate symmetry for real arrays +@inline _fft(A::AbstractArray{T}) where {T<:Real} = rfft(A) +@inline _fft(A::AbstractArray{T}) where {T<:Complex} = fft(A) +@inline _ifft(_::Type{<:Real}, A, d::Int) = irfft(A, d) +@inline _ifft(_::Type{<:Complex}, A, _::Int) = ifft(A) +# NOTE: If for one array, the optimization is used and not for the other, the two arrays do not have the same sizes +# which needs to be dealt with in the element-wise 
multiplication +@inline function _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, d::Int) + A_fft[1, :] .*= B_fft[1, :] + A_fft[2:(d÷2+1), 1] .*= B_fft[2:end, 1] + A_fft[(d÷2+2):end, 1] .*= conj(reverse(B_fft[2:(end-iseven(d)), 1])) + A_fft[2:(d÷2+1), 2:end] .*= B_fft[2:end, 2:end] + A_fft[(d÷2+2):end, 2:end] .*= conj(reverse(B_fft[2:(end-iseven(d)), 2:end])) + return A_fft +end +@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, d::Int) = _stretch_mul(BT, B_fft, AT, A_fft, d) +@inline _stretch_mul(AT::Type{<:Real}, A_fft::AbstractArray, BT::Type{<:Real}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft +@inline _stretch_mul(AT::Type{<:Complex}, A_fft::AbstractArray, BT::Type{<:Complex}, B_fft::AbstractArray, _::Int) = A_fft .* B_fft +function filtfft(A::AbstractArray{ST}, krn::AbstractArray{KT}) where {ST<:Union{Real,Complex},KT<:Union{Real,Complex}} + CT = promote_type(ST, KT) + B = _fft(A) + fft_out = _stretch_mul(ST, B, KT, conj!(_fft(krn)), length(axes(A, 1))) + _ifft(CT, fft_out, length(axes(A, 1))) +end + +function filtfft(A::AbstractArray{CT}, krn) where {CT<:Colorant} Av, dims = channelview_dims(A) - kernrs = kreshape(C, krn) + kernrs = kreshape(CT, krn) B = rfft(Av, dims) B .*= conj!(rfft(kernrs, dims)) Avf = irfft(B, length(axes(Av, dims[1])), dims) - colorview(base_colorant_type(C){eltype(Avf)}, Avf) + colorview(base_colorant_type(CT){eltype(Avf)}, Avf) end channelview_dims(A::AbstractArray{C,N}) where {C<:Colorant,N} = channelview(A), ntuple(d -> d + 1, Val(N)) channelview_dims(A::AbstractArray{C,N}) where {C<:ImageCore.Color1,N} = channelview(A), ntuple(identity, Val(N)) diff --git a/test/2d.jl b/test/2d.jl index 37d8fd2..b346250 100644 --- a/test/2d.jl +++ b/test/2d.jl @@ -337,7 +337,22 @@ end @test_throws err imfilter(CPU1(), A, kern, Fill(0, (3,))) kernf = ImageFiltering.factorkernel(kern) err = DimensionMismatch("output indices 
(OffsetArrays.IdOffsetRange(values=0:9, indices=0:9), OffsetArrays.IdOffsetRange(values=1:8, indices=1:8)) disagree with requested indices (1:8, 0:9)") - @test_throws err imfilter(CPU1(), A, kern, Fill(0, (1,0))) - @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,1))) - @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0,0))) + @test_throws err imfilter(CPU1(), A, kern, Fill(0, (1, 0))) + @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 1))) + @test_throws DimensionMismatch imfilter(CPU1(), A, kern, Fill(0, (0, 0))) +end + +@testset "Complex FFT" begin + + A = rand(10, 10) + B = rand(10, 10) + @test filtfft(A, B) ≈ filtfft(ComplexF32.(A), B) + @test filtfft(A, B) ≈ filtfft(A, ComplexF32.(B)) + @test filtfft(A, B) ≈ filtfft(ComplexF32.(A), ComplexF32.(B)) + + C = rand(9, 9) + D = rand(9, 9) + @test filtfft(C, D) ≈ filtfft(ComplexF32.(C), D) + @test filtfft(C, D) ≈ filtfft(C, ComplexF32.(D)) + @test filtfft(C, D) ≈ filtfft(ComplexF32.(C), ComplexF32.(D)) end