Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for immutable states/operators #76

Merged
merged 1 commit into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
using Pkg

DOCUMENTER_VERSION =
[p for (uuid, p) in Pkg.dependencies() if p.name == "Documenter"][1].version
if DOCUMENTER_VERSION <= v"1.3.0"
Pkg.develop("Documenter")
end

using Documenter
using QuantumPropagators
using QuantumPropagators: AbstractPropagator, set_t!, set_state!
Expand Down Expand Up @@ -35,10 +29,6 @@ links = InterLinks(
"https://github.com/KristofferC/TimerOutputs.jl",
joinpath(@__DIR__, "src", "inventories", "TimerOutputs.toml"),
),
# We'll use `@extref` for links from docstrings to sections so that the
# docstrings can also be rendered as part of the QuantumControl
# documentation.
"QuantumPropagators" => "https://juliaquantumcontrol.github.io/QuantumPropagators.jl/$DEV_OR_STABLE",
"QuantumControlBase" => "https://juliaquantumcontrol.github.io/QuantumControlBase.jl/$DEV_OR_STABLE",
"ComponentArrays" => (
"https://jonniedie.github.io/ComponentArrays.jl/stable/",
Expand Down Expand Up @@ -89,7 +79,8 @@ makedocs(;
"https://juliaquantumcontrol.github.io/QuantumControl.jl/dev/assets/topbar/topbar.js"
),
],
footer="[$NAME.jl]($GITHUB) v$VERSION docs powered by [Documenter.jl](https://github.com/JuliaDocs/Documenter.jl)."
footer="[$NAME.jl]($GITHUB) v$VERSION docs powered by [Documenter.jl](https://github.com/JuliaDocs/Documenter.jl).",
size_threshold_ignore=["api/quantumpropagators.md",]
),
pages=[
"Home" => "index.md",
Expand All @@ -101,7 +92,7 @@ makedocs(;
hide("Benchmarks" => "benchmarks.md", [joinpath("benchmarks", "profiling.md")]),
"API" => "api/quantumpropagators.md",
"References" => "references.md",
]
],
)

println("Finished makedocs")
Expand Down
1 change: 1 addition & 0 deletions ext/QuantumPropagatorsODEExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ function set_state!(propagator::ODEPropagator, state)
ODE.set_u!(propagator.integrator, state)
end
end
return propagator.state
end


Expand Down
14 changes: 6 additions & 8 deletions src/cheby.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function cheby!(Ψ, H, dt, wrk; kwargs...)

Δ = wrk.Δ
β::Float64 = (Δ / 2) + E_min # "normfactor"
@assert abs(dt) ≈ abs(wrk.dt) "wrk was initialized for dt=$(wrk.dt), not dt=$dt"
@assert abs(dt) ≈ abs(wrk.dt) "wrk was initialized for dt=$(wrk.dt), not dt=abs($dt)"
if dt > 0
c = -2im / Δ
else
Expand Down Expand Up @@ -228,7 +228,7 @@ function cheby(Ψ, H, dt, wrk; kwargs...)

Δ = wrk.Δ
β::Float64 = (Δ / 2) + E_min # "normfactor"
@assert dt ≈ wrk.dt "wrk was initialized for dt=$(wrk.dt), not dt=$dt"
@assert abs(dt) ≈ wrk.dt "wrk was initialized for dt=$(wrk.dt), not dt=abs($dt)"
if dt > 0
c = -2im / Δ
else
Expand All @@ -237,9 +237,6 @@ function cheby(Ψ, H, dt, wrk; kwargs...)
a = wrk.coeffs
ϵ = wrk.limit
@assert length(a) > 1 "Need at least 2 Chebychev coefficients"
v0 = wrk.v0
v1 = wrk.v1
v2 = wrk.v2

v0 = Ψ
Ψ = a[1] * v0
Expand All @@ -256,16 +253,17 @@ function cheby(Ψ, H, dt, wrk; kwargs...)
@timeit_debug wrk.timing_data "matrix-vector product" begin
v2 = H * v1
end
v2 += -v1 * β
v2 = c * v2
if check_normalization
v2 = c * (v2 - v1 * β)
map_norm = abs(dot(v1, v2)) / (2 * norm(v1)^2)
@assert(
map_norm <= (1.0 + ϵ),
"Incorrect normalization (E_min=$(E_min), Δ=$(Δ))"
)
v2 += v0
else
v2 = c * (v2 - β * v1) + v0
end
v2 += v0

Ψ += a[i] * v2

Expand Down
5 changes: 3 additions & 2 deletions src/cheby_propagator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function reinit_prop!(
transform_control_ranges=_transform_control_ranges,
_...
)
set_state!(propagator, state)
state = set_state!(propagator, state)

wrk = propagator.wrk
need_to_recalculate_cheby_coeffs = false
Expand Down Expand Up @@ -355,8 +355,8 @@ function prop_step!(propagator::ChebyPropagator)
end
tlist = getfield(propagator, :tlist)
(0 < n < length(tlist)) || return nothing
_pwc_set_genop!(propagator, n)
if propagator.inplace
H = _pwc_set_genop!(propagator, n)
Cheby.cheby!(
Ψ,
H,
Expand All @@ -365,6 +365,7 @@ function prop_step!(propagator::ChebyPropagator)
check_normalization=propagator.check_normalization
)
else
H = _pwc_get_genop(propagator, n)
Ψ = Cheby.cheby(
Ψ,
H,
Expand Down
103 changes: 55 additions & 48 deletions src/interfaces/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ using ..Generators: Generator
"""Check the dynamical `generator` for propagating `state` over `tlist`.

```julia
@test check_generator(generator; state, tlist,
for_mutable_state=true, for_immutable_state=true,
for_pwc=true, for_time_continuous=false,
for_expval=true, for_parameterization=false,
atol=1e-14, quiet=false)
@test check_generator(
generator; state, tlist,
for_mutable_operator=true, for_immutable_operator=true,
for_mutable_state=true, for_immutable_state=true,
for_pwc=true, for_time_continuous=false,
for_expval=true, for_parameterization=false,
atol=1e-14, quiet=false)
```

verifies the given `generator`:
Expand All @@ -27,14 +29,16 @@ If `for_pwc` (default):
* [`evaluate(generator, tlist, n)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* [`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined
* If `for_mutable_operator`,
[`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined

If `for_time_continuous`:

* [`evaluate(generator, t)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* [`evaluate!(op, generator, t)`](@ref evaluate!) must be defined
* If `for_mutable_operator`, [`evaluate!(op, generator, t)`](@ref evaluate!)
must be defined

If `for_parameterization` (may require the `RecursiveArrayTools` package to be
loaded):
Expand All @@ -51,6 +55,8 @@ function check_generator(
generator;
state,
tlist,
for_mutable_operator=true,
for_immutable_operator=true,
for_mutable_state=true,
for_immutable_state=true,
for_expval=true,
Expand Down Expand Up @@ -208,26 +214,27 @@ function check_generator(
success = false
end

try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
if for_mutable_operator
try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end
try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end
end

end
Expand Down Expand Up @@ -284,28 +291,28 @@ function check_generator(
success = false
end

try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end])
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end]; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
if for_mutable_operator
try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end])
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end
try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end]; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end
end

end


Expand Down
8 changes: 7 additions & 1 deletion src/interfaces/propagator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ initialized with [`init_prop`](@ref).
`propagator.state`.
* [`set_state!(propagator, state)`](@ref set_state!) for an in-place propagator
must overwrite `propagator.state` in-place.
* [`set_state!`](@ref) must return the set `propagator.state`
* In a [`PiecewisePropagator`](@ref), `propagator.parameters` must be a dict
mapping controls to a vector of values, one for each interval on
`propagator.tlist`
Expand Down Expand Up @@ -236,7 +237,7 @@ function check_propagator(
end

try
set_state!(propagator, Ψ₀)
Ψ = set_state!(propagator, Ψ₀)
if norm(propagator.state - Ψ₀) > atol
quiet ||
@error "$(px)`set_state!(propagator, state)` must set `propagator.state`"
Expand All @@ -249,6 +250,11 @@ function check_propagator(
success = false
end
end
if propagator.state ≢ Ψ
quiet ||
@error "$(px)`set_state!(propagator, state)` must return `propagator.state`."
success = false
end
catch exc
quiet || @error(
"$(px)`set_state!(propagator, state)` must be defined.",
Expand Down
3 changes: 1 addition & 2 deletions src/newton_propagator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,15 @@ init_prop(state, generator, tlist, method::Val{:newton}; kwargs...) =
function prop_step!(propagator::NewtonPropagator)
@timeit_debug propagator.timing_data "prop_step!" begin
Ψ = propagator.state
H = propagator.genop
n = propagator.n # index of interval we're going to propagate
tlist = getfield(propagator, :tlist)
(0 < n < length(tlist)) || return nothing
dt = tlist[n+1] - tlist[n]
if propagator.backward
dt = -dt
end
_pwc_set_genop!(propagator, n)
if propagator.inplace
H = _pwc_set_genop!(propagator, n)
Newton.newton!(
Ψ,
H,
Expand Down
7 changes: 4 additions & 3 deletions src/ode_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ H = evaluate(f.generator, t; vals_dict=vals_dict)
where `vals_dict` may be a dictionary mapping controls to values (set as the
parameters `p` of the underlying ODE solver).

If [`QuantumPropagators.enable_timings()` has been called](@extref QuantumPropagators TimerOutputs),
If [`QuantumPropagators.enable_timings()`](@ref
QuantumPropagators.enable_timings) has been called,
profiling data is collected in `f.timing_data`.
"""
function ode_function(generator::GT, tlist; c=-1im, _timing_data=TimerOutput()) where {GT}
Expand Down Expand Up @@ -74,9 +75,9 @@ end

function (f::QuantumODEFunction)(u, p, t)
@timeit_debug f.timing_data "operator evaluation" begin
evaluate!(f.operator, f.generator, t; vals_dict=p)
H = evaluate(f.generator, t; vals_dict=p)
end
@timeit_debug f.timing_data "matrix-vector product" begin
return f.c * f.operator * u
return f.c * H * u
end
end
2 changes: 2 additions & 0 deletions src/propagate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ function propagate(
inplace=true, # cf. default of init_prop
for_expval=true, # undocumented
for_immutable_state=true, # undocumented
for_mutable_operator=inplace, # undocumented
for_mutable_state=inplace, # undocumented
kwargs...
)
Expand Down Expand Up @@ -208,6 +209,7 @@ function propagate(
state=state,
tlist=tlist,
for_immutable_state,
for_mutable_operator,
for_mutable_state,
for_expval,
atol,
Expand Down
13 changes: 9 additions & 4 deletions src/propagator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ implement the following interface.
* `backward`: Boolean flag to indicate whether the propagation moves forward or
backward in time
* `inplace`: Boolean flag to indicate whether `propagator.state` is modified
in-place or is recreated by every call of `prop_step!` or `set_state!`.
in-place or is recreated by every call of `prop_step!` or `set_state!`. With
`inplace=false`, the propagator should generally avoid in-place operations,
such as calls to [`QuantumPropagators.Controls.evaluate!`](@ref).

Concrete `Propagator` types may have additional properties or fields, but these
should be considered private.
Expand Down Expand Up @@ -332,8 +334,9 @@ function prop_step! end
set_state!(propagator, state)
```

sets the `propagator.state` property. In order to mutate the current state
after a call to [`prop_step!`](@ref), the following pattern is recommended:
sets the `propagator.state` property and returns `propagator.state`. In order
to mutate the current state after a call to [`prop_step!`](@ref), the following
pattern is recommended:

```
Ψ = propagator.state
Expand Down Expand Up @@ -366,9 +369,11 @@ function set_state!(propagator::AbstractPropagator, state)
if propagator.inplace
copyto!(propagator.state, state)
else
setfield!(propagator, :state, state)
T = typeof(propagator.state)
setfield!(propagator, :state, convert(T, state))
end
end
return propagator.state
end


Expand Down
Loading
Loading