Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate 5-args constructor #282

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions src/TimedOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ mutable struct TimedLinearOperator{T, OP <: AbstractLinearOperator{T}, F, Ft, Fc
ctprod!::Fct
end

TimedLinearOperator{T}(
timer::TimerOutput,
op::AbstractLinearOperator{T},
prod!::F,
tprod!::Ft,
ctprod!::Fct,
) where {T, F, Ft, Fct} =
TimedLinearOperator{T, typeof(op), F, Ft, Fct}(timer, op, prod!, tprod!, ctprod!)

"""
TimedLinearOperator(op)
Creates a linear operator instrumented with timers from TimerOutputs.
Expand All @@ -29,7 +20,7 @@ function TimedLinearOperator(op::AbstractLinearOperator{T}) where {T}
prod!(res, x, α, β) = @timeit timer "prod" op.prod!(res, x, α, β)
tprod!(res, x, α, β) = @timeit timer "tprod" op.tprod!(res, x, α, β)
ctprod!(res, x, α, β) = @timeit timer "ctprod" op.ctprod!(res, x, α, β)
TimedLinearOperator{T}(timer, op, prod!, tprod!, ctprod!)
TimedLinearOperator(timer, op, prod!, tprod!, ctprod!)
end

TimedLinearOperator(op::AdjointLinearOperator) = adjoint(TimedLinearOperator(op.parent))
Expand Down
53 changes: 39 additions & 14 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export AbstractLinearOperator,
AbstractQuasiNewtonOperator,
AbstractDiagonalQuasiNewtonOperator,
LinearOperator,
LinearOperator5,
LinearOperatorException,
hermitian,
ishermitian,
Expand Down Expand Up @@ -61,7 +62,8 @@ mutable struct LinearOperator{T, I <: Integer, F, Ft, Fct, S} <: AbstractLinearO
allocated5::Bool # true for 5-args mul!, false for 3-args mul! until the vectors are allocated
end

function LinearOperator{T}(
function LinearOperator(
::Type{T},
nrow::I,
ncol::I,
symmetric::Bool,
Expand All @@ -75,11 +77,9 @@ function LinearOperator{T}(
S::DataType = Vector{T},
) where {T, I <: Integer, F, Ft, Fct}
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
nargs = get_nargs(prod!)
args5 = (nargs == 4)
(args5 == false) || (nargs != 2) || throw(LinearOperatorException("Invalid number of arguments"))
allocated5 = args5 ? true : false
use_prod5! = args5 ? true : false
args5 = false
allocated5 = false
use_prod5! = false
return LinearOperator{T, I, F, Ft, Fct, S}(
nrow,
ncol,
Expand All @@ -99,21 +99,46 @@ function LinearOperator{T}(
)
end

LinearOperator{T}(
function LinearOperator5(
::Type{T},
nrow::I,
ncol::I,
symmetric::Bool,
hermitian::Bool,
prod!,
tprod!,
ctprod!;
prod!::F,
tprod!::Ft,
ctprod!::Fct,
nprod::I,
ntprod::I,
nctprod::I;
S::DataType = Vector{T},
) where {T, I <: Integer} =
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
) where {T, I <: Integer, F, Ft, Fct}
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
args5 = true
allocated5 = true
use_prod5! = true
return LinearOperator{T, I, F, Ft, Fct, S}(
nrow,
ncol,
symmetric,
hermitian,
prod!,
tprod!,
ctprod!,
nprod,
ntprod,
nctprod,
args5,
use_prod5!,
Mv5,
Mtu5,
allocated5,
)
end

# create operator from other operators with +, *, vcat,...
function CompositeLinearOperator(
T::DataType,
::Type{T},
nrow::I,
ncol::I,
symmetric::Bool,
Expand All @@ -123,7 +148,7 @@ function CompositeLinearOperator(
ctprod!::Fct,
args5::Bool;
S::DataType = Vector{T},
) where {I <: Integer, F, Ft, Fct}
) where {T, I <: Integer, F, Ft, Fct}
Mv5, Mtu5 = S(undef, 0), S(undef, 0)
allocated5 = true
use_prod5! = true
Expand Down
65 changes: 47 additions & 18 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function LinearOperator(
prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β)
tprod! = @closure (res, u, α, β) -> mul!(res, transpose(M), u, α, β)
ctprod! = @closure (res, w, α, β) -> mul!(res, adjoint(M), w, α, β)
LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
end

"""
Expand Down Expand Up @@ -58,6 +58,47 @@ end
[tprod!=nothing, ctprod!=nothing],
S = Vector{T}) where {T}

Construct a linear operator from functions where the type is specified as the first argument.
Change `S` to use LinearOperators on GPU.
```
A = rand(2, 2)
op = LinearOperator(Float64, 2, 2, false, false,
(res, v) -> mul!(res, A, v),
(res, w) -> mul!(res, A', w))
```

Notice that the linear operator does not enforce the type, so using a wrong type can
result in errors. For instance,
```
A = [im 1.0; 0.0 1.0] # Complex matrix
op = LinearOperator5(Float64, 2, 2, false, false,
(res, v) -> mul!(res, A, v),
(res, u) -> mul!(res, transpose(A), u),
(res, w) -> mul!(res, A', w))
Matrix(op) # InexactError
```
The error is caused because `Matrix(op)` tries to create a Float64 matrix with the
contents of the complex matrix `A`.
"""
function LinearOperator(
::Type{T},
nrow::I,
ncol::I,
symmetric::Bool,
hermitian::Bool,
prod!,
tprod! = nothing,
ctprod! = nothing;
S = Vector{T},
) where {T, I <: Integer}
return LinearOperator(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
end

"""
LinearOperator5(::Type{T}, nrow, ncol, symmetric, hermitian, prod!,
[tprod!=nothing, ctprod!=nothing],
S = Vector{T}) where {T}

Construct a linear operator from functions where the type is specified as the first argument.
Change `S` to use LinearOperators on GPU.
Notice that the linear operator does not enforce the type, so using a wrong type can
Expand All @@ -67,7 +108,7 @@ A = [im 1.0; 0.0 1.0] # Complex matrix
function mulOp!(res, M, v, α, β)
mul!(res, M, v, α, β)
end
op = LinearOperator(Float64, 2, 2, false, false,
op = LinearOperator5(Float64, 2, 2, false, false,
(res, v, α, β) -> mulOp!(res, A, v, α, β),
(res, u, α, β) -> mulOp!(res, transpose(A), u, α, β),
(res, w, α, β) -> mulOp!(res, A', w, α, β))
Expand All @@ -77,8 +118,6 @@ The error is caused because `Matrix(op)` tries to create a Float64 matrix with t
contents of the complex matrix `A`.

Using `*` may generate a vector that contains `NaN` values.
This can also happen if you use the 3-args `mul!` function with a preallocated vector such as
`Vector{Float64}(undef, n)`.
To fix this issue you will have to deal with the cases `β == 0` and `β != 0` separately:
```
d1 = [2.0; 3.0]
Expand All @@ -89,21 +128,11 @@ function mulSquareOpDiagonal!(res, d, v, α, β::T) where T
res .= α .* d .* v .+ β .* res
end
end
op = LinearOperator(Float64, 2, 2, true, true,
op = LinearOperator5(Float64, 2, 2, true, true,
(res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β))
```

It is possible to create an operator with the 3-args `mul!`.
In this case, using the 5-args `mul!` will generate storage vectors.

```
A = rand(2, 2)
op = LinearOperator(Float64, 2, 2, false, false,
(res, v) -> mul!(res, A, v),
(res, w) -> mul!(res, A', w))
```
"""
function LinearOperator(
function LinearOperator5(
::Type{T},
nrow::I,
ncol::I,
Expand All @@ -114,5 +143,5 @@ function LinearOperator(
ctprod! = nothing;
S = Vector{T},
) where {T, I <: Integer}
return LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
end
return LinearOperator5(T, nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, 0, 0, 0, S = S)
end
2 changes: 1 addition & 1 deletion src/kron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function kron(A::AbstractLinearOperator, B::AbstractLinearOperator)
symm = issymmetric(A) && issymmetric(B)
herm = ishermitian(A) && ishermitian(B)
nrow, ncol = m * p, n * q
return LinearOperator{T}(nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
return LinearOperator5(T, nrow, ncol, symm, herm, prod!, tprod!, ctprod!)
end

kron(A::AbstractMatrix, B::AbstractLinearOperator) = kron(LinearOperator(A), B)
Expand Down
6 changes: 3 additions & 3 deletions src/lbfgs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ mutable struct LBFGSOperator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewton
nctprod::I
end

LBFGSOperator{T}(
LBFGSOperator(
nrow::I,
ncol::I,
symmetric::Bool,
Expand Down Expand Up @@ -149,7 +149,7 @@ function InverseLBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer}
end

prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β)
return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data)
return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, true, lbfgs_data)
end

InverseLBFGSOperator(n::Int; kwargs...) = InverseLBFGSOperator(Float64, n; kwargs...)
Expand Down Expand Up @@ -199,7 +199,7 @@ function LBFGSOperator(T::DataType, n::I; kwargs...) where {I <: Integer}
end

prod! = @closure (res, x, α, β) -> lbfgs_multiply(res, lbfgs_data, x, α, β)
return LBFGSOperator{T}(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data)
return LBFGSOperator(n, n, true, true, prod!, prod!, prod!, false, lbfgs_data)
end

LBFGSOperator(n::I; kwargs...) where {I <: Integer} = LBFGSOperator(Float64, n; kwargs...)
Expand Down
12 changes: 6 additions & 6 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function opInverse(M::AbstractMatrix{T}; symm = false, herm = false) where {T}
prod! = @closure (res, v, α, β) -> mulFact!(res, M, v, α, β)
tprod! = @closure (res, u, α, β) -> mulFact!(res, transpose(M), u, α, β)
ctprod! = @closure (res, w, α, β) -> mulFact!(res, adjoint(M), w, α, β)
LinearOperator{T}(size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
LinearOperator5(T, size(M, 2), size(M, 1), symm, herm, prod!, tprod!, ctprod!)
end

"""
Expand All @@ -53,7 +53,7 @@ function opCholesky(M::AbstractMatrix; check::Bool = false)
tprod! = @closure (res, u, α, β) -> tmulFact!(res, LL, u, α, β) # M.' = conj(M)
ctprod! = @closure (res, w, α, β) -> mulFact!(res, LL, w, α, β)
S = eltype(LL)
LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
#TODO: use iterative refinement.
end

Expand Down Expand Up @@ -82,7 +82,7 @@ function opLDL(M::AbstractMatrix; check::Bool = false)
tprod! = @closure (res, u, α, β) -> tmulFact!(res, LDL, u, α, β) # M.' = conj(M)
ctprod! = @closure (res, w, α, β) -> mulFact!(res, LDL, w, α, β)
S = eltype(LDL)
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
return LinearOperator5(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
#TODO: use iterative refinement.
end

Expand All @@ -97,7 +97,7 @@ function opLDL(M::Symmetric{T, SparseMatrixCSC{T, Int}}; check::Bool = false) wh
tprod! = @closure (res, u) -> ldiv!(res, LDL, u) # M.' = conj(M)
ctprod! = @closure (res, w) -> ldiv!(res, LDL, w)
S = eltype(LDL)
return LinearOperator{S}(m, m, isreal(M), true, prod!, tprod!, ctprod!)
return LinearOperator(S, m, m, isreal(M), true, prod!, tprod!, ctprod!)
end

function mulHouseholder!(res, h, v, α, β::T) where {T}
Expand All @@ -117,7 +117,7 @@ The result is `x -> (I - 2 h hᵀ) x`.
function opHouseholder(h::AbstractVector{T}) where {T}
n = length(h)
prod! = @closure (res, v, α, β) -> mulHouseholder!(res, h, v, α, β) # tprod will be inferred
LinearOperator{T}(n, n, isreal(h), true, prod!, nothing, prod!)
LinearOperator5(T, n, n, isreal(h), true, prod!, nothing, prod!)
end

function mulHermitian!(res, d, L, v, α, β::T) where {T}
Expand All @@ -139,7 +139,7 @@ function opHermitian(d::AbstractVector{S}, A::AbstractMatrix{T}) where {S, T}
L = tril(A, -1)
U = promote_type(S, T)
prod! = @closure (res, v, α, β) -> mulHermitian!(res, d, L, v, α, β)
LinearOperator{U}(m, m, isreal(A), true, prod!, nothing, nothing)
LinearOperator5(U, m, m, isreal(A), true, prod!, nothing, nothing)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions src/lsr1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ mutable struct LSR1Operator{T, I <: Integer, F, Ft, Fct} <: AbstractQuasiNewtonO
nctprod::I
end

LSR1Operator{T}(
LSR1Operator(
nrow::I,
ncol::I,
symmetric::Bool,
Expand Down Expand Up @@ -114,7 +114,7 @@ function LSR1Operator(T::DataType, n::I; kwargs...) where {I <: Integer}
end

prod! = @closure (res, x, α, β) -> lsr1_multiply(res, lsr1_data, x, α, β)
return LSR1Operator{T}(n, n, true, true, prod!, nothing, nothing, false, lsr1_data)
return LSR1Operator(n, n, true, true, prod!, nothing, nothing, false, lsr1_data)
end

LSR1Operator(n::I; kwargs...) where {I <: Integer} = LSR1Operator(Float64, n; kwargs...)
Expand Down
14 changes: 7 additions & 7 deletions src/special-operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Change `S` to use LinearOperators on GPU.
"""
function opEye(T::DataType, n::Int; S = Vector{T})
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, n)
LinearOperator{T}(n, n, true, true, prod!, prod!, prod!, S = S)
LinearOperator5(T, n, n, true, true, prod!, prod!, prod!, S = S)
end

opEye(n::Int) = opEye(Float64, n)
Expand All @@ -67,7 +67,7 @@ function opEye(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer
return opEye(T, nrow)
end
prod! = @closure (res, v, α, β) -> mulOpEye!(res, v, α, β, min(nrow, ncol))
return LinearOperator{T}(nrow, ncol, false, false, prod!, prod!, prod!, S = S)
return LinearOperator5(T, nrow, ncol, false, false, prod!, prod!, prod!, S = S)
end

opEye(nrow::I, ncol::I) where {I <: Integer} = opEye(Float64, nrow, ncol)
Expand All @@ -90,7 +90,7 @@ Change `S` to use LinearOperators on GPU.
"""
function opOnes(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
prod! = @closure (res, v, α, β) -> mulOpOnes!(res, v, α, β)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
end

opOnes(nrow::I, ncol::I) where {I <: Integer} = opOnes(Float64, nrow, ncol)
Expand All @@ -113,7 +113,7 @@ Change `S` to use LinearOperators on GPU.
"""
function opZeros(T::DataType, nrow::I, ncol::I; S = Vector{T}) where {I <: Integer}
prod! = @closure (res, v, α, β) -> mulOpZeros!(res, v, α, β)
LinearOperator{T}(nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
LinearOperator5(T, nrow, ncol, nrow == ncol, nrow == ncol, prod!, prod!, prod!, S = S)
end

opZeros(nrow::I, ncol::I) where {I <: Integer} = opZeros(Float64, nrow, ncol)
Expand All @@ -134,7 +134,7 @@ Diagonal operator with the vector `d` on its main diagonal.
function opDiagonal(d::AbstractVector{T}) where {T}
prod! = @closure (res, v, α, β) -> mulSquareOpDiagonal!(res, d, v, α, β)
ctprod! = @closure (res, w, α, β) -> mulSquareOpDiagonal!(res, conj.(d), w, α, β)
LinearOperator{T}(length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d))
LinearOperator5(T, length(d), length(d), true, isreal(d), prod!, prod!, ctprod!, S = typeof(d))
end

function mulOpDiagonal!(res, d, v, α, β::T, n_min) where {T}
Expand All @@ -157,7 +157,7 @@ function opDiagonal(nrow::I, ncol::I, d::AbstractVector{T}) where {T, I <: Integ
prod! = @closure (res, v, α, β) -> mulOpDiagonal!(res, d, v, α, β, n_min)
tprod! = @closure (res, u, α, β) -> mulOpDiagonal!(res, d, u, α, β, n_min)
ctprod! = @closure (res, w, α, β) -> mulOpDiagonal!(res, conj.(d), w, α, β, n_min)
LinearOperator{T}(nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d))
LinearOperator5(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, S = typeof(d))
end

function mulRestrict!(res, I, v, α, β)
Expand Down Expand Up @@ -185,7 +185,7 @@ function opRestriction(Idx::LinearOperatorIndexType{I}, ncol::I) where {I <: Int
nrow = length(Idx)
prod! = @closure (res, v, α, β) -> mulRestrict!(res, Idx, v, α, β)
tprod! = @closure (res, u, α, β) -> multRestrict!(res, Idx, u, α, β)
return LinearOperator{I}(nrow, ncol, false, false, prod!, tprod!, tprod!)
return LinearOperator5(I, nrow, ncol, false, false, prod!, tprod!, tprod!)
end

opRestriction(::Colon, ncol::I) where {I <: Integer} = opEye(I, ncol)
Expand Down
Loading
Loading