Skip to content

Commit

Permalink
Generalize get_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
goerz committed Apr 3, 2024
1 parent 9b2ffc5 commit c2abe41
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 49 deletions.
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Pkg

DOCUMENTER_VERSION = [p for (uuid, p) in Pkg.dependencies() if p.name == "Documenter"][1].version
DOCUMENTER_VERSION =
[p for (uuid, p) in Pkg.dependencies() if p.name == "Documenter"][1].version
if DOCUMENTER_VERSION <= v"1.3.0"
Pkg.develop("Documenter")
end
Expand Down
20 changes: 12 additions & 8 deletions src/controls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ For generators without any explicit time dependence,
op = evaluate(generator; vals_dict)
```
can be used. The `vals_dict` in this case must contina values for all controls
can be used. The `vals_dict` in this case must contain values for all controls
in `generator`.
# See also:
Expand Down Expand Up @@ -544,13 +544,17 @@ the mid-points of the time grid, as obtained by
[`discretize_on_midpoints`](@ref), and `get_parameters` is ignored.
"""
function get_parameters(object)
return _get_parameters(object; via=get_controls)
end

function _get_parameters(object; via=get_controls)
parameter_arrays = []
seen_parameter_array = IdDict{Any,Bool}()
seen_control = IdDict{Any,Bool}()
for control in get_controls(object)
if control !== object
if !haskey(seen_control, control)
parameter_array = get_parameters(control)
seen_component = IdDict{Any,Bool}()
for component in via(object)
if component !== object # E.g., get_controls(control) -> control
if !haskey(seen_component, component)
parameter_array = get_parameters(component)
if isempty(parameter_array)
continue
end
Expand All @@ -559,7 +563,7 @@ function get_parameters(object)
seen_parameter_array[parameter_array] = true
end
end
seen_control[control] = true
seen_component[component] = true
end
end
if isempty(parameter_arrays)
Expand All @@ -571,7 +575,7 @@ function get_parameters(object)
return _combine_parameter_arrays(parameter_arrays)
catch exception
if exception isa MethodError
msg = "In order for parameter arrays to be combined from multiple controls, the `RecursiveArrayTools` package must be loaded"
msg = "In order for parameter arrays to be combined from multiple components, the `RecursiveArrayTools` package must be loaded"
@error msg exception
else
rethrow()
Expand Down
233 changes: 193 additions & 40 deletions src/interfaces/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ..Generators: Generator
```julia
@test check_generator(generator; state, tlist,
for_mutable_state=true, for_immutable_state=true,
for_pwc=true, for_time_continuous=false,
for_expval=true, for_parameterization=false,
atol=1e-14, quiet=false)
```
Expand All @@ -16,15 +17,25 @@ verifies the given `generator`:
tuple
* all controls returned by [`get_controls(generator)`](@ref get_controls) must
pass [`check_control`](@ref)
* [`evaluate(generator, tlist, n)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* [`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined
* [`substitute(generator, replacements)`](@ref substitute) must be defined
* If `generator` is a [`Generator`](@ref) instance, all elements of
`generator.amplitudes` must pass [`check_amplitude`](@ref) with
`for_parameterization`.
If `for_pwc` (default):
* [`evaluate(generator, tlist, n)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* [`evaluate!(op, generator, tlist, n)`](@ref evaluate!) must be defined
If `for_time_continuous`:
* [`evaluate(generator, t)`](@ref evaluate) must return a valid
operator ([`check_operator`](@ref)), with forwarded keyword arguments
(including `for_expval`)
* [`evaluate!(op, generator, t)`](@ref evaluate!) must be defined
If `for_parameterization` (may require the `RecursiveArrayTools` package to be
loaded):
Expand All @@ -43,6 +54,8 @@ function check_generator(
for_mutable_state=true,
for_immutable_state=true,
for_expval=true,
for_pwc=true,
for_time_continuous=false,
for_parameterization=false,
atol=1e-14,
quiet=false,
Expand Down Expand Up @@ -76,42 +89,6 @@ function check_generator(
success &= check_parameterized(generator; _message_prefix=px)
end

try
op = evaluate(generator, tlist, 1)
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
_message_prefix="On `op = evaluate(generator, tlist, 1)`: "
)
quiet ||
@error "$(px)`evaluate(generator, tlist, n)` must return an operator that passes `check_operator`"
success = false
end
catch exc
quiet || @error(
"$(px)`evaluate(generator, tlist, n)` must return a valid operator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
controls = get_controls(generator)
for (i, control) enumerate(controls)
Expand Down Expand Up @@ -156,6 +133,182 @@ function check_generator(
success = false
end

vals_dict = IdDict()
try
if success
if for_pwc
vals_dict = IdDict(
control => evaluate(control, tlist, 1) for
control in get_controls(generator)
)
elseif for_time_continuous
vals_dict = IdDict(
control => evaluate(control, tlist[1]) for
control in get_controls(generator)
)
end
end
catch exc
quiet || @error(
"$(px)`evaluate(control, …)` must be defined for all controls in generator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

if for_pwc

try
op = evaluate(generator, tlist, 1)
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
_message_prefix="On `op = evaluate(generator, tlist, 1)`: "
)
quiet ||
@error "$(px)`evaluate(generator, tlist, n)` must return an operator that passes `check_operator`"
success = false
end
catch exc
quiet || @error(
"$(px)`evaluate(generator, tlist, n)` must return a valid operator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist, 1; vals_dict)
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
_message_prefix="On `op = evaluate(generator, tlist, 1; vals_dict)`: "
)
quiet ||
@error "$(px)`evaluate(generator, tlist, n; vals_dict)` must return an operator that passes `check_operator`"
success = false
end
catch exc
quiet || @error(
"$(px)`evaluate(generator, tlist, n; vals_dict)` must return a valid operator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist, 1)
evaluate!(op, generator, tlist, length(tlist) - 1; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, tlist, n; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

end

if for_time_continuous

try
op = evaluate(generator, tlist[1])
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
_message_prefix="On `op = evaluate(generator, tlist[1])`: "
)
quiet ||
@error "$(px)`evaluate(generator, t)` must return an operator that passes `check_operator`"
success = false
end
catch exc
quiet || @error(
"$(px)`evaluate(generator, t)` must return a valid operator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist[1]; vals_dict)
if !check_operator(
op;
state,
tlist,
for_mutable_state,
for_immutable_state,
for_expval,
atol,
quiet,
_message_prefix="On `op = evaluate(generator, tlist[1]; vals_dict)`: "
)
quiet ||
@error "$(px)`evaluate(generator, t; vals_dict)` must return an operator that passes `check_operator`"
success = false
end
catch exc
quiet || @error(
"$(px)`evaluate(generator, t; vals_dict)` must return a valid operator.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end])
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

try
op = evaluate(generator, tlist[begin])
evaluate!(op, generator, tlist[end]; vals_dict)
catch exc
quiet || @error(
"$(px)`evaluate!(op, generator, t; vals_dict)` must be defined.",
exception = (exc, catch_abbreviated_backtrace())
)
success = false
end

end


if (generator isa Generator) && _check_amplitudes
try
for (i, ampl) in enumerate(generator.amplitudes)
Expand Down

0 comments on commit c2abe41

Please sign in to comment.