diff --git a/src/TimedOperators.jl b/src/TimedOperators.jl index f3a3c14c..f7afbd36 100644 --- a/src/TimedOperators.jl +++ b/src/TimedOperators.jl @@ -11,15 +11,6 @@ mutable struct TimedLinearOperator{T, OP <: AbstractLinearOperator{T}, F, Ft, Fc ctprod!::Fct end -TimedLinearOperator{T}( - timer::TimerOutput, - op::AbstractLinearOperator{T}, - prod!::F, - tprod!::Ft, - ctprod!::Fct, -) where {T, F, Ft, Fct} = - TimedLinearOperator{T, typeof(op), F, Ft, Fct}(timer, op, prod!, tprod!, ctprod!) - """ TimedLinearOperator(op) Creates a linear operator instrumented with timers from TimerOutputs. @@ -29,7 +20,7 @@ function TimedLinearOperator(op::AbstractLinearOperator{T}) where {T} prod!(res, x, α, β) = @timeit timer "prod" op.prod!(res, x, α, β) tprod!(res, x, α, β) = @timeit timer "tprod" op.tprod!(res, x, α, β) ctprod!(res, x, α, β) = @timeit timer "ctprod" op.ctprod!(res, x, α, β) - TimedLinearOperator{T}(timer, op, prod!, tprod!, ctprod!) + TimedLinearOperator(timer, op, prod!, tprod!, ctprod!) end TimedLinearOperator(op::AdjointLinearOperator) = adjoint(TimedLinearOperator(op.parent)) diff --git a/src/abstract.jl b/src/abstract.jl index 4e4fb908..1c41b882 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -2,6 +2,7 @@ export AbstractLinearOperator, AbstractQuasiNewtonOperator, AbstractDiagonalQuasiNewtonOperator, LinearOperator, + LinearOperator5, LinearOperatorException, hermitian, ishermitian, @@ -61,7 +62,8 @@ mutable struct LinearOperator{T, I <: Integer, F, Ft, Fct, S} <: AbstractLinearO allocated5::Bool # true for 5-args mul!, false for 3-args mul! until the vectors are allocated end -function LinearOperator{T}( +function LinearOperator( + ::Type{T}, nrow::I, ncol::I, symmetric::Bool, @@ -75,11 +77,9 @@ function LinearOperator{T}( S::DataType = Vector{T}, ) where {T, I <: Integer, F, Ft, Fct} Mv5, Mtu5 = S(undef, 0), S(undef, 0) - nargs = get_nargs(prod!) 
- args5 = (nargs == 4) - (args5 == false) || (nargs != 2) || throw(LinearOperatorException("Invalid number of arguments")) - allocated5 = args5 ? true : false - use_prod5! = args5 ? true : false + args5 = false + allocated5 = false + use_prod5! = false return LinearOperator{T, I, F, Ft, Fct, S}( nrow, ncol, @@ -99,21 +99,46 @@ function LinearOperator{T}( ) end -LinearOperator{T}( +function LinearOperator5( + ::Type{T}, nrow::I, ncol::I, symmetric::Bool, hermitian::Bool, - prod!, - tprod!, - ctprod!; + prod!::F, + tprod!::Ft, + ctprod!::Fct, + nprod::I, + ntprod::I, + nctprod::I; S::DataType = Vector{T}, -) where {T, I <: Integer} = - LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S) +) where {T, I <: Integer, F, Ft, Fct} + Mv5, Mtu5 = S(undef, 0), S(undef, 0) + args5 = true + allocated5 = true + use_prod5! = true + return LinearOperator{T, I, F, Ft, Fct, S}( + nrow, + ncol, + symmetric, + hermitian, + prod!, + tprod!, + ctprod!, + nprod, + ntprod, + nctprod, + args5, + use_prod5!, + Mv5, + Mtu5, + allocated5, + ) +end # create operator from other operators with +, *, vcat,... function CompositeLinearOperator( - T::DataType, + ::Type{T}, nrow::I, ncol::I, symmetric::Bool, @@ -123,7 +148,7 @@ function CompositeLinearOperator( ctprod!::Fct, args5::Bool; S::DataType = Vector{T}, -) where {I <: Integer, F, Ft, Fct} +) where {T, I <: Integer, F, Ft, Fct} Mv5, Mtu5 = S(undef, 0), S(undef, 0) allocated5 = true use_prod5! = true diff --git a/src/constructors.jl b/src/constructors.jl index 41030e3e..a4c138a9 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -16,7 +16,7 @@ function LinearOperator( prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β) tprod! = @closure (res, u, α, β) -> mul!(res, transpose(M), u, α, β) ctprod! 
= @closure (res, w, α, β) -> mul!(res, adjoint(M), w, α, β) - LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S) + LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S) end """ @@ -58,6 +58,47 @@ end [tprod!=nothing, ctprod!=nothing], S = Vector{T}) where {T} +Construct a linear operator from functions where the type is specified as the first argument. +Change `S` to use LinearOperators on GPU. +``` +A = rand(2, 2) +op = LinearOperator(Float64, 2, 2, false, false, + (res, v) -> mul!(res, A, v), + (res, w) -> mul!(res, A', w)) +``` + +Notice that the linear operator does not enforce the type, so using a wrong type can +result in errors. For instance, +``` +A = [im 1.0; 0.0 1.0] # Complex matrix +op = LinearOperator(Float64, 2, 2, false, false, + (res, v) -> mul!(res, A, v), + (res, u) -> mul!(res, transpose(A), u), + (res, w) -> mul!(res, A', w)) +Matrix(op) # InexactError +``` +The error is caused because `Matrix(op)` tries to create a Float64 matrix with the +contents of the complex matrix `A`. +""" +function LinearOperator( + ::Type{T}, + nrow::I, + ncol::I, + symmetric::Bool, + hermitian::Bool, + prod!, + tprod! = nothing, + ctprod! = nothing; + S = Vector{T}, +) where {T, I <: Integer} + return LinearOperator(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S) +end + +""" + LinearOperator5(::Type{T}, nrow, ncol, symmetric, hermitian, prod!, + [tprod!=nothing, ctprod!=nothing], + S = Vector{T}) where {T} + Construct a linear operator from functions where the type is specified as the first argument. Change `S` to use LinearOperators on GPU. 
Notice that the linear operator does not enforce the type, so using a wrong type can @@ -67,7 +108,7 @@ A = [im 1.0; 0.0 1.0] # Complex matrix function mulOp!(res, M, v, α, β) mul!(res, M, v, α, β) end -op = LinearOperator(Float64, 2, 2, false, false, +op = LinearOperator5(Float64, 2, 2, false, false, (res, v, α, β) -> mulOp!(res, A, v, α, β), (res, u, α, β) -> mulOp!(res, transpose(A), u, α, β), (res, w, α, β) -> mulOp!(res, A', w, α, β)) @@ -77,8 +118,6 @@ The error is caused because `Matrix(op)` tries to create a Float64 matrix with t contents of the complex matrix `A`. Using `*` may generate a vector that contains `NaN` values. -This can also happen if you use the 3-args `mul!` function with a preallocated vector such as -`Vector{Float64}(undef, n)`. To fix this issue you will have to deal with the cases `β == 0` and `β != 0` separately: ``` d1 = [2.0; 3.0] @@ -89,21 +128,11 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where T res .= α .* d .* v .+ β .* res end end -op = LinearOperator(Float64, 2, 2, true, true, +op = LinearOperator5(Float64, 2, 2, true, true, (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β)) ``` - -It is possible to create an operator with the 3-args `mul!`. -In this case, using the 5-args `mul!` will generate storage vectors. - -``` -A = rand(2, 2) -op = LinearOperator(Float64, 2, 2, false, false, - (res, v) -> mul!(res, A, v), - (res, w) -> mul!(res, A', w)) -``` """ -function LinearOperator( +function LinearOperator5( ::Type{T}, nrow::I, ncol::I, @@ -114,5 +143,5 @@ function LinearOperator( ctprod! 
= nothing; S = Vector{T}, ) where {T, I <: Integer} - return LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S) -end + return LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S) +end \ No newline at end of file diff --git a/src/kron.jl b/src/kron.jl index 9be09d16..49d67be7 100644 --- a/src/kron.jl +++ b/src/kron.jl @@ -41,7 +41,7 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator) symm = issymmetric(A) && issymmetric(B) herm = ishermitian(A) && ishermitian(B) nrow, ncol = m * p, n * q - return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!) + return LinearOperator5(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!) end kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B) diff --git a/src/lbfgs.jl b/src/lbfgs.jl index f03cad23..db2ba28c 100644 --- a/src/lbfgs.jl +++ b/src/lbfgs.jl @@ -64,7 +64,7 @@ mutable struct LBFGSOperator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewton nctprod::I end -LBFGSOperator{T}( +LBFGSOperator( nrow::I, ncol::I, symmetric::Bool, @@ -149,7 +149,7 @@ function InverseLBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer} end prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β) - return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data) + return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data) end InverseLBFGSOperator(n::Int; kwargs...) = InverseLBFGSOperator(Float64, n; kwargs...) @@ -199,7 +199,7 @@ function LBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer} end prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β) - return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data) + return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data) end LBFGSOperator(n::I; kwargs...) where {I <: Integer} = LBFGSOperator(Float64, n; kwargs...) 
diff --git a/src/linalg.jl b/src/linalg.jl index 4ff48088..ab391fe3 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -28,7 +28,7 @@ function opInverse(M::AbstractMatrix{T}; symm = false, herm = false) where {T} prod! = @closure (res, v, α, β) -> mulFact!(res, M, v, α, β) tprod! = @closure (res, u, α, β) -> mulFact!(res, transpose(M), u, α, β) ctprod! = @closure (res, w, α, β) -> mulFact!(res, adjoint(M), w, α, β) - LinearOperator{T}(size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!) + LinearOperator5(T, size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!) end """ @@ -53,7 +53,7 @@ function opCholesky(M::AbstractMatrix; check::Bool = false) tprod! = @closure (res, u, α, β) -> tmulFact!(res, LL, u, α, β) # M.' = conj(M) ctprod! = @closure (res, w, α, β) -> mulFact!(res, LL, w, α, β) S = eltype(LL) - LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!) + LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!) #TODO: use iterative refinement. end @@ -82,7 +82,7 @@ function opLDL(M::AbstractMatrix; check::Bool = false) tprod! = @closure (res, u, α, β) -> tmulFact!(res, LDL, u, α, β) # M.' = conj(M) ctprod! = @closure (res, w, α, β) -> mulFact!(res, LDL, w, α, β) S = eltype(LDL) - return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!) + return LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!) #TODO: use iterative refinement. end @@ -97,7 +97,7 @@ function opLDL(M::Symmetric{T, SparseMatrixCSC{T, Int}}; check::Bool = false) wh tprod! = @closure (res, u) -> ldiv!(res, LDL, u) # M.' = conj(M) ctprod! = @closure (res, w) -> ldiv!(res, LDL, w) S = eltype(LDL) - return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!) + return LinearOperator(S, m, m, isreal(M), true, prod!, tprod!, ctprod!) end function mulHouseholder!(res, h, v, α, β::T) where {T} @@ -117,7 +117,7 @@ The result is `x -> (I - 2 h hᵀ) x`. function opHouseholder(h::AbstractVector{T}) where {T} n = length(h) prod! 
= @closure (res, v, α, β) -> mulHouseholder!(res, h, v, α, β) # tprod will be inferred - LinearOperator{T}(n, n, isreal(h), true, prod!, nothing, prod!) + LinearOperator5(T, n, n, isreal(h), true, prod!, nothing, prod!) end function mulHermitian!(res, d, L, v, α, β::T) where {T} @@ -139,7 +139,7 @@ function opHermitian(d::AbstractVector{S}, A::AbstractMatrix{T}) where {S, T} L = tril(A, -1) U = promote_type(S, T) prod! = @closure (res, v, α, β) -> mulHermitian!(res, d, L, v, α, β) - LinearOperator{U}(m, m, isreal(A), true, prod!, nothing, nothing) + LinearOperator5(U, m, m, isreal(A), true, prod!, nothing, nothing) end """ diff --git a/src/lsr1.jl b/src/lsr1.jl index 8008bbbd..01753a44 100644 --- a/src/lsr1.jl +++ b/src/lsr1.jl @@ -54,7 +54,7 @@ mutable struct LSR1Operator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewtonO nctprod::I end -LSR1Operator{T}( +LSR1Operator( nrow::I, ncol::I, symmetric::Bool, @@ -114,7 +114,7 @@ function LSR1Operator(T::DataType, n::I; kwargs...) where {I <: Integer} end prod! = @closure (res, x, α, β) -> lsr1_multiply(res, lsr1_data, x, α, β) - return LSR1Operator{T}(n, n, true, true, prod!, nothing, nothing, false, lsr1_data) + return LSR1Operator(n, n, true, true, prod!, nothing, nothing, false, lsr1_data) end LSR1Operator(n::I; kwargs...) where {I <: Integer} = LSR1Operator(Float64, n; kwargs...) diff --git a/src/special-operators.jl b/src/special-operators.jl index 4f658d6a..efd71f89 100644 --- a/src/special-operators.jl +++ b/src/special-operators.jl @@ -48,7 +48,7 @@ Change `S` to use LinearOperators on GPU. """ function opEye(T::DataType, n::Int; S = Vector{T}) prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, n) - LinearOperator{T}(n, n, true, true, prod!, prod!, prod!, S = S) + LinearOperator5(T, n, n, true, true, prod!, prod!, prod!, S = S) end opEye(n::Int) = opEye(Float64, n) @@ -67,7 +67,7 @@ function opEye(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer return opEye(T, nrow) end prod! 
= @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, min(nrow, ncol)) - return LinearOperator{T}(nrow, ncol, false, false, prod!, prod!, prod!, S = S) + return LinearOperator5(T, nrow, ncol, false, false, prod!, prod!, prod!, S = S) end opEye(nrow::I, ncol::I) where {I <: Integer} = opEye(Float64, nrow, ncol) @@ -90,7 +90,7 @@ Change `S` to use LinearOperators on GPU. """ function opOnes(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer} prod! = @closure (res, v, α, β) -> mulOpOnes!(res, v, α, β) - LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S) + LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S) end opOnes(nrow::I, ncol::I) where {I <: Integer} = opOnes(Float64, nrow, ncol) @@ -113,7 +113,7 @@ Change `S` to use LinearOperators on GPU. """ function opZeros(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer} prod! = @closure (res, v, α, β) -> mulOpZeros!(res, v, α, β) - LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S) + LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S) end opZeros(nrow::I, ncol::I) where {I <: Integer} = opZeros(Float64, nrow, ncol) @@ -134,7 +134,7 @@ Diagonal operator with the vector `d` on its main diagonal. function opDiagonal(d::AbstractVector{T}) where {T} prod! = @closure (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β) ctprod! = @closure (res, w, α, β) -> mulSquareOpDiagonal!(res, conj.(d), w, α, β) - LinearOperator{T}(length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d)) + LinearOperator5(T, length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d)) end function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T} @@ -157,7 +157,7 @@ function opDiagonal(nrow::I, ncol::I, d::AbstractVector{T}) where {T, I <: Integ prod! = @closure (res, v, α, β) -> mulOpDiagonal!(res, d, v, α, β, n_min) tprod! 
= @closure (res, u, α, β) -> mulOpDiagonal!(res, d, u, α, β, n_min) ctprod! = @closure (res, w, α, β) -> mulOpDiagonal!(res, conj.(d), w, α, β, n_min) - LinearOperator{T}(nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d)) + LinearOperator5(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d)) end function mulRestrict!(res, I, v, α, β) @@ -185,7 +185,7 @@ function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I) where {I <: Int nrow = length(Idx) prod! = @closure (res, v, α, β) -> mulRestrict!(res, Idx, v, α, β) tprod! = @closure (res, u, α, β) -> multRestrict!(res, Idx, u, α, β) - return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!) + return LinearOperator5(I, nrow, ncol, false, false, prod!, tprod!, tprod!) end opRestriction(::Colon, ncol::I) where {I <: Integer} = opEye(I, ncol) diff --git a/test/test_callable.jl b/test/test_callable.jl index 87c5c4a8..37f6dd03 100644 --- a/test/test_callable.jl +++ b/test/test_callable.jl @@ -11,7 +11,7 @@ end function test_callable() @testset ExtendedTestSet "Test callable" begin Mv = ones(2) - op = LinearOperator(Float64, 2, 2, true, true, Flip()) + op = LinearOperator5(Float64, 2, 2, true, true, Flip()) @test op * ones(2) == -ones(2) @test op' * ones(2) == -ones(2) @test transpose(op) * ones(2) == -ones(2) diff --git a/test/test_chainrules.jl b/test/test_chainrules.jl index 461a4fad..e2645977 100644 --- a/test/test_chainrules.jl +++ b/test/test_chainrules.jl @@ -13,7 +13,7 @@ function matmulOp(mat::AbstractArray{T}) where {T} end end - return LinearOperator{T}(size(mat, 1), size(mat, 2), false, false, prod!, nothing, ctprod!) + return LinearOperator(T, size(mat, 1), size(mat, 2), false, false, prod!, nothing, ctprod!) 
end function test_chainrules() diff --git a/test/test_linop.jl b/test/test_linop.jl index f65f1bcd..86207896 100644 --- a/test/test_linop.jl +++ b/test/test_linop.jl @@ -370,7 +370,7 @@ function test_linop() res = copy(res_init) α, β = 2.0, -3.0 - op = LinearOperator( + op = LinearOperator5( ComplexF64, nrow, nrow, @@ -389,7 +389,7 @@ function test_linop() mul!(res, adjoint(op), v, α, β) @test(norm(α * A' * v + β * res_init - res) <= rtol * norm(v)) - op = LinearOperator( + op = LinearOperator5( ComplexF64, nrow, nrow, @@ -506,7 +506,7 @@ function test_linop() function test_func(res) res .= 1.0 .+ im * 1.0 end - op = LinearOperator(ComplexF64, 5, 3, false, false, (res, p, α, β) -> test_func(res)) + op = LinearOperator5(ComplexF64, 5, 3, false, false, (res, p, α, β) -> test_func(res)) @test eltype(op) == ComplexF64 v = rand(5) @test_throws LinearOperatorException transpose(op) * v # cannot be inferred @@ -528,7 +528,7 @@ function test_linop() # Adjoint of a symmetric non-hermitian A = simple_matrix(ComplexF64, 3, 3) A = A + transpose(A) - op = LinearOperator(ComplexF64, 3, 3, true, false, (res, v, α, β) -> mul!(res, A, v)) + op = LinearOperator5(ComplexF64, 3, 3, true, false, (res, v, α, β) -> mul!(res, A, v)) v = rand(3) @test op' * v ≈ A' * v end @@ -543,7 +543,7 @@ function test_linop() res[2] = v[1] + v[2] end for T in (Complex{Float64}, Complex{Float32}, BigFloat, Float64, Float32, Float16, Int32) - op = LinearOperator(T, 2, 2, false, false, prod!, nothing, ctprod!) + op = LinearOperator5(T, 2, 2, false, false, prod!, nothing, ctprod!) w = ones(T, 2) @test eltype(op) == T @test op * w == T[2; 1] @@ -560,10 +560,10 @@ function test_linop() function ctprod2!(res, w, α, β) mul!(res, A', w) end - opC = LinearOperator(ComplexF64, 2, 2, false, false, prod2!, tprod2!, ctprod2!) + opC = LinearOperator5(ComplexF64, 2, 2, false, false, prod2!, tprod2!, ctprod2!) 
v = simple_vector(ComplexF64, 2) @test A == Matrix(opC) - opF = LinearOperator(Float64, 2, 2, false, false, prod2!, tprod2!, ctprod2!) # The type is a lie + opF = LinearOperator5(Float64, 2, 2, false, false, prod2!, tprod2!, ctprod2!) # The type is a lie @test eltype(opF) == Float64 @test_throws InexactError Matrix(opF) # changed here TypeError to InexactError end