Skip to content

Commit

Permalink
Implement new Multiplier struct and its use with KrylovKit
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Jan 22, 2025
1 parent 5f48b35 commit dd436a6
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 18 deletions.
71 changes: 57 additions & 14 deletions ext/KrylovKitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ using CommonSolve: CommonSolve
using Setfield: Setfield, @set
using NamedTupleTools: NamedTupleTools, delete

using Rimu: Rimu, AbstractDVec, AbstractHamiltonian, IsDeterministic, PDVec, DVec,
PDWorkingMemory, scale!!, working_memory, zerovector, dimension, replace_keys
using Rimu: Rimu, AbstractDVec, AbstractHamiltonian, AbstractOperator, IsDeterministic,
starting_address, PDVec, DVec, PDWorkingMemory,
scale!!, working_memory, zerovector, dimension, replace_keys

using Rimu.ExactDiagonalization: MatrixEDSolver, KrylovKitSolver,
KrylovKitDirectEDSolver,
LazyDVecs, EDResult, LazyCoefficientVectorsDVecs
LazyDVecs, EDResult, LazyCoefficientVectorsDVecs, Multiplier

const U = Union{Symbol,EigSorter}

Expand Down Expand Up @@ -55,6 +56,45 @@ function KrylovKit.eigsolve(
)
end

function _prepare_multiplier(
ham, vec; basis=nothing, starting_address=starting_address(ham), full_basis=true
)
if issymmetric(ham) && (isnothing(vec) || isreal(vec))
eltype = Float64
else
eltype = ComplexF64
end
if isnothing(basis)
prop = Multiplier(ham, starting_address; full_basis, eltype)
else
prop = Multiplier(ham, basis; eltype)
end
end

function KrylovKit.eigsolve(
ham::AbstractOperator, vec::Vector, howmany::Int=1, which::U=:LR;
basis=nothing, starting_address=starting_address(ham), full_basis=true, kwargs...
)
# Change the type of `vec` to float, if needed.
v = scale!!(vec, 1.0)
prop = _prepare_multiplier(ham, v; basis, starting_address, full_basis)
return eigsolve(
prop, v, howmany, which;
ishermitian=ishermitian(ham), issymmetric=issymmetric(ham), kwargs...
)
end
function KrylovKit.eigsolve(
ham::AbstractOperator, howmany::Int=1, which::U=:LR;
basis=nothing, starting_address=starting_address(ham), full_basis=true, kwargs...
)
prop = _prepare_multiplier(ham, nothing; basis, starting_address, full_basis)
v = rand(eltype(prop), size(prop, 1))
return eigsolve(
prop, v, howmany, which;
ishermitian=ishermitian(ham), issymmetric=issymmetric(ham), kwargs...
)
end

# solve for KrylovKit solvers: prepare arguments for `KrylovKit.eigsolve`
function CommonSolve.solve(s::S; kwargs...
) where {S<:Union{MatrixEDSolver{<:KrylovKitSolver},KrylovKitDirectEDSolver}}
Expand Down Expand Up @@ -92,10 +132,6 @@ function _kk_eigsolve(s::MatrixEDSolver{<:KrylovKitSolver}, howmany, which, kw_n
# solve the problem
vals, vecs, info = eigsolve(s.basissetrep.sparse_matrix, x0, howmany, which; kw_nt...)
success = info.converged howmany
if !success
@warn "KrylovKit.eigsolve did not converge for all requested eigenvalues:" *
" $(info.converged) converged out of $howmany requested value(s)."
end

return EDResult(
s.algorithm,
Expand All @@ -113,13 +149,20 @@ end

# solve with KrylovKit direct
function _kk_eigsolve(s::KrylovKitDirectEDSolver, howmany, which, kw_nt)

vals, vecs, info = eigsolve(s.problem.hamiltonian, s.v0, howmany, which; kw_nt...)
success = info.converged howmany
if !success
@warn "KrylovKit.eigsolve did not converge for all requested eigenvalues:" *
" $(info.converged) converged out of $howmany requested value(s)."
prop = _prepare_multiplier(s.problem.hamiltonian, s.v0#=TODO: new args go here=#)
if isnothing(s.v0)
x0 = rand(size(prop, 1))
else
x0 = zeros(eltype(prop), size(prop, 1))
for (k, v) in pairs(s.v0)
x0[prop.mapping[k]] = v
end
end
vals, vecs, info = eigsolve(
prop, x0, howmany, which;
issymmetric=issymmetric(prop), ishermitian=ishermitian(prop), kw_nt...
)
success = info.converged howmany

basis = keys(vecs[1])

Expand All @@ -128,7 +171,7 @@ function _kk_eigsolve(s::KrylovKitDirectEDSolver, howmany, which, kw_nt)
s.problem,
vals,
vecs,
LazyCoefficientVectorsDVecs(vecs, basis),
LazyDVecs(vecs, basis),
basis,
info,
howmany,
Expand Down
8 changes: 4 additions & 4 deletions src/ExactDiagonalization/ExactDiagonalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ provided by external packages.
"""
module ExactDiagonalization

using LinearAlgebra: LinearAlgebra, eigen!, ishermitian, Matrix
using LinearAlgebra: LinearAlgebra, eigen!, issymmetric, ishermitian, Matrix, dot
using SparseArrays: SparseArrays, nnz, nzrange, sparse
using CommonSolve: CommonSolve, solve, init
using VectorInterface: VectorInterface, add
Expand All @@ -29,13 +29,12 @@ using StaticArrays: setindex

using Rimu: Rimu, DictVectors, Hamiltonians, Interfaces, BitStringAddresses, replace_keys,
clean_and_warn_if_others_present
using ..Interfaces: AbstractDVec, AbstractHamiltonian, AdjointUnknown,
using ..Interfaces: AbstractDVec, AbstractHamiltonian, AbstractOperator, AdjointUnknown,
diagonal_element, offdiagonals, starting_address, LOStructure, IsHermitian
using ..BitStringAddresses: AbstractFockAddress, BoseFS, FermiFS, CompositeFS, near_uniform
using ..DictVectors: FrozenDVec, PDVec, DVec
using ..Hamiltonians: allows_address_type, check_address_type, dimension,
ParitySymmetry, TimeReversalSymmetry

ParitySymmetry, TimeReversalSymmetry, AbstractOperator

export ExactDiagonalizationProblem, KrylovKitSolver, LinearAlgebraSolver
export ArpackSolver, LOBPCGSolver
Expand All @@ -47,6 +46,7 @@ export sparse # from SparseArrays
include("basis_breadth_first_search.jl")
include("basis_fock.jl")
include("basis_set_representation.jl")
include("multiplier.jl")
include("algorithms.jl")
include("exact_diagonalization_problem.jl")
include("init_and_solvers.jl")
Expand Down
162 changes: 162 additions & 0 deletions src/ExactDiagonalization/multiplier.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
Multiplier(::AbstractOperator{T}, basis; eltype=T)
Multiplier(::AbstractHamiltonian{T}, [address]; full_basis=true, eltype=T)
Wrapper for an [`AbstractOperator`](@ref) and a basis that allows multiplying regular Julia
vectors with the operator.
The `eltype` argument can be used to change the eltype of the internal buffers, e.g. for
multiplying complex vectors with real operators.
If an [`AbstractHamiltonian`](@ref) with no `basis` is passed, the basis is constructed
automatically. In that case, when `full_basis=true` the entire basis is constructed from an
address as [`build_basis`](@ref)`(address)`, otherwise it is constructed as
[`build_basis`](@ref)`(hamiltonian, address)`. You may want to set `full_basis=false` when
dealing with Hamiltonians that block, such as [`HubbardMom1D`](@ref).
Supports calling, `Base.:*`, `mul!` and the three-argument `dot`.
## Example
```julia
julia> H = HubbardReal1D(BoseFS(1, 1, 1, 1));
julia> bsr = BasisSetRepresentation(H);
julia> v = ones(length(bsr.basis));
julia> w1 = bsr.sparse_matrix * v;
julia> mul = ExactDiagonalization.Multiplier(H, bsr.basis);
julia> w2 = mul * v;
julia> w1 ≈ w2
true
julia> dot(w1, bsr.sparse_matrix, v) ≈ dot(w1, mul, v)
true
```
"""
struct Multiplier{T,H<:AbstractOperator,A,I}
hamiltonian::H
basis::Vector{A}
mapping::Dict{A,I}
size::Tuple{Int,Int}
buffer::Matrix{T}
indices::Vector{UnitRange{Int}}
end
function Multiplier(
hamiltonian::H, basis::Vector{A}; eltype=eltype(H)
) where {A,H<:AbstractOperator}
I = length(basis) > typemax(Int32) ? Int64 : Int32
T = eltype
mapping = Dict(b => I(i) for (b, i) in zip(basis, eachindex(basis)))
threads = Threads.nthreads()
buffer = zeros(T, (length(basis), threads))

chunk_size = length(basis) ÷ threads
prev = 0
indices = UnitRange{Int}[]
for t in 1:threads - 1
push!(indices, prev+1:prev+chunk_size)
prev += chunk_size
end
push!(indices, prev+1:length(basis))

return Multiplier{T,H,A,I}(
hamiltonian, basis, mapping, (length(basis), length(basis)), buffer, indices
)
end
function Multiplier(
hamiltonian::AbstractHamiltonian,
address::AbstractFockAddress=starting_address(hamiltonian);
full_basis=true, eltype=eltype(hamiltonian),
)
if full_basis
basis = build_basis(address)
else
basis = build_basis(hamiltonian, address)
end
return Multiplier(hamiltonian, basis)
end
function Base.show(io::IO, mul::Multiplier{T}) where {T}
print(io, "Multiplier{$T}($(mul.hamiltonian))")
end

Base.size(mul::Multiplier) = mul.size
Base.size(mul::Multiplier, i) = mul.size[i]
Base.eltype(::Type{Multiplier{T}}) where {T} = T
Base.eltype(::Multiplier{T}) where {T} = T
LinearAlgebra.issymmetric(mul::Multiplier) = issymmetric(mul.hamiltonian)
LinearAlgebra.ishermitian(mul::Multiplier) = ishermitian(mul.hamiltonian)

function Base.adjoint(mul::Multiplier{T,<:Any,A,I}) where {T,A,I}
hamiltonian = mul.hamiltonian'
return Multiplier{T,typeof(hamiltonian),A,I}(
hamiltonian, mul.basis, mul.mapping, mul.size,
)
end

function LinearAlgebra.mul!(dst, mul::Multiplier{T}, src) where {T}
@boundscheck begin
length(src) == size(mul, 2) || throw(DimensionMismatch("operator has size $(size(mul)), vector has length $(length(src))"))
length(dst) == size(mul, 1) || throw(DimensionMismatch("operator has size $(size(mul)), output vector has length $(length(dst))"))
@assert size(mul.buffer, 1) == length(src)
end
H = mul.hamiltonian
basis = mul.basis
mapping = mul.mapping
buffer = mul.buffer
indices = mul.indices

@inbounds Threads.@threads for t in 1:size(mul.buffer, 2)
buffer[:, t] .= zero(T)
for i in indices[t]
addr1 = mul.basis[i]
val1 = src[i]
buffer[i, t] += diagonal_element(H, addr1) * val1
for (addr2, elem) in offdiagonals(H, addr1)
j = get(mapping, addr2, 0)
!iszero(j) && (buffer[j, t] += elem * val1)
end
end
end
return sum!(dst, buffer)
end

function (mul::Multiplier)(src)
dst = zeros(length(src))
return LinearAlgebra.mul!(dst, mul, src)
end

Base.:*(mul, src) = mul(src)

function LinearAlgebra.dot(dst, mul::Multiplier, src)
@boundscheck begin
length(src) == size(mul, 2) || throw(DimensionMismatch("operator has size $(size(mul)), vector has length $(length(src))"))
length(dst) == size(mul, 1) || throw(DimensionMismatch("operator has size $(size(mul)), output vector has length $(length(dst))"))
@assert size(mul.buffer, 1) == length(src)
end

H = mul.hamiltonian
basis = mul.basis
mapping = mul.mapping
buffer = mul.buffer
indices = mul.indices

@inbounds Threads.@threads for t in 1:size(mul.buffer, 2)
buffer[1, t] = result = zero(eltype(buffer))
for i in indices[t]
addr1 = mul.basis[i]
val1 = src[i]
result += conj(dst[i]) * diagonal_element(H, addr1) * val1
for (addr2, elem) in offdiagonals(H, addr1)
j = get(mapping, addr2, 0)
result += conj(get(dst, j, 0.0)) * elem * val1
end
end
buffer[1, t] = result
end
return sum(buffer[1, t] for t in 1:size(mul.buffer, 2))
end

0 comments on commit dd436a6

Please sign in to comment.