Skip to content

Commit

Permalink
Merge pull request #120 from QuanEstimation/test-adpt-online
Browse files Browse the repository at this point in the history
update tests for adaptmzi online
  • Loading branch information
hmyuuu authored Jan 21, 2025
2 parents 74334dd + 46eafaf commit 44401a3
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 88 deletions.
4 changes: 2 additions & 2 deletions lib/QuanEstimationBase/ext/QuanEstimationBasePyExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ function QuanEstimationBase.ode_py(
ctrl_num = length(Hc)
ctrl_interval = ((length(tspan) - 1) / length(ctrl[1])) |> Int
ctrl =
[repeat(ctrl[i], 1, ctrl_interval) |> transpose |> vec |> Array for i = 1:ctrl_num]
[repeat_copy(ctrl[i], 1, ctrl_interval) |> transpose |> vec |> Array for i = 1:ctrl_num]
push!.(ctrl, [0.0 for i = 1:ctrl_num])
H(ctrl) = Htot(H0, Hc, ctrl)
dt = tspan[2] - tspan[1]
Expand Down Expand Up @@ -252,7 +252,7 @@ function QuanEstimationBase.ode_py(
ctrl_num = length(Hc)
ctrl_interval = ((length(tspan) - 1) / length(ctrl[1])) |> Int
ctrl =
[repeat(ctrl[i], 1, ctrl_interval) |> transpose |> vec |> Array for i = 1:ctrl_num]
[repeat_copy(ctrl[i], 1, ctrl_interval) |> transpose |> vec |> Array for i = 1:ctrl_num]
push!.(ctrl, [0.0 for i = 1:ctrl_num])
H(ctrl) = Htot(H0, Hc, ctrl)
dt = tspan[2] - tspan[1]
Expand Down
12 changes: 6 additions & 6 deletions lib/QuanEstimationBase/src/Algorithm/DE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function optimize!(opt::ControlOpt, alg::DE, obj, scheme, output)
ini_population = ini_population[1]
ctrl_length = get_ctrl_length(scheme)
ctrl_num = get_ctrl_num(scheme)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)

# initialization
initial_ctrl!(opt, ini_population, populations, p_num, opt.rng)
Expand Down Expand Up @@ -84,7 +84,7 @@ function optimize!(opt::StateOpt, alg::DE, obj, scheme, output)
end
ini_population = ini_population[1]
dim = get_dim(scheme)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)
# initialization
initial_state!(ini_population, populations, p_num, opt.rng)

Expand Down Expand Up @@ -426,7 +426,7 @@ function optimize!(opt::StateControlOpt, alg::DE, obj, scheme, output)
ctrl_length = get_ctrl_length(scheme)
ctrl_num = get_ctrl_num(scheme)
dim = get_dim(scheme)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)

# initialization
initial_state!(psi0, populations, p_num, opt.rng)
Expand Down Expand Up @@ -533,7 +533,7 @@ function optimize!(opt::StateMeasurementOpt, alg::DE, obj, scheme, output)
psi0, measurement0 = ini_population
dim = get_dim(scheme)
M_num = length(opt.M)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)

# initialization
initial_state!(psi0, populations, p_num, opt.rng)
Expand Down Expand Up @@ -645,7 +645,7 @@ function optimize!(opt::ControlMeasurementOpt, alg::DE, obj, scheme, output)
ctrl_num = get_ctrl_num(scheme)

M_num = length(opt.M)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)

# initialization
initial_ctrl!(opt, ctrl0, populations, p_num, opt.rng)
Expand Down Expand Up @@ -763,7 +763,7 @@ function optimize!(opt::StateControlMeasurementOpt, alg::DE, obj, scheme, output
ctrl_length = get_ctrl_length(scheme)
ctrl_num = get_ctrl_num(scheme)
M_num = length(opt.M)
populations = repeat(scheme, p_num)
populations = repeat_copy(scheme, p_num)

# initialization
initial_state!(psi0, populations, p_num, opt.rng)
Expand Down
2 changes: 1 addition & 1 deletion lib/QuanEstimationBase/src/Algorithm/NM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function optimize!(opt::StateOpt, alg::NM, obj, scheme, output)
ini_state = [opt.psi]
end
dim = get_dim(scheme)
nelder_mead = repeat(scheme, p_num)
nelder_mead = repeat_copy(scheme, p_num)

# initialize
if length(ini_state) > p_num
Expand Down
18 changes: 9 additions & 9 deletions lib/QuanEstimationBase/src/Algorithm/PSO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function optimize!(opt::ControlOpt, alg::PSO, obj, scheme, output)
ini_particle = ini_particle[1]
ctrl_length = get_ctrl_length(scheme)
ctrl_num = get_ctrl_num(scheme)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -94,7 +94,7 @@ function optimize!(opt::ControlOpt, alg::PSO, obj, scheme, output)
end
if ei % max_episode[2] == 0
pdata.ctrl = [gbest[k, :] for k = 1:ctrl_num]
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)
end

set_f!(output, fit_out)
Expand All @@ -116,7 +116,7 @@ function optimize!(opt::StateOpt, alg::PSO, obj, scheme, output)
end
ini_particle = ini_particle[1]
dim = get_dim(scheme)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -178,7 +178,7 @@ function optimize!(opt::StateOpt, alg::PSO, obj, scheme, output)
end
if ei % max_episode[2] == 0
sdata = [gbest[i] for i = 1:dim]
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)
end
set_f!(output, fit_out)
set_buffer!(output, [gbest[i] for i = 1:dim])
Expand Down Expand Up @@ -407,7 +407,7 @@ function optimize!(opt::Mopt_Rotation, alg::PSO, obj, scheme, output)
# append!(Lambda, [suN[i] for i in eachindex(suN)])
# end

particles = repeat(s, p_num)
particles = repeat_copy(s, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -502,7 +502,7 @@ function optimize!(opt::StateControlOpt, alg::PSO, obj, scheme, output)
dim = get_dim(scheme)
ctrl_length = get_ctrl_length(scheme)
ctrl_num = get_ctrl_num(scheme)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -628,7 +628,7 @@ function optimize!(opt::StateMeasurementOpt, alg::PSO, obj, scheme, output)
psi0, measurement0 = ini_particle
dim = get_dim(scheme)
M_num = length(opt.M)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -750,7 +750,7 @@ function optimize!(opt::ControlMeasurementOpt, alg::PSO, obj, scheme, output)
ctrl_num = get_ctrl_num(scheme)
dim = get_dim(scheme)
M_num = length(opt.M)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down Expand Up @@ -888,7 +888,7 @@ function optimize!(opt::StateControlMeasurementOpt, alg::PSO, obj, scheme, outpu
ctrl_num = get_ctrl_num(scheme)
dim = get_dim(scheme)
M_num = length(opt.M)
particles = repeat(scheme, p_num)
particles = repeat_copy(scheme, p_num)

if typeof(max_episode) == Int
max_episode = [max_episode, max_episode]
Expand Down
16 changes: 3 additions & 13 deletions lib/QuanEstimationBase/src/Common/Common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
include("mintime.jl")
# include("mintime.jl")
include("BayesEstimation.jl")

destroy(N) = diagm(1 => [sqrt(n) + 0.0im for n = 1:N-1])
Expand All @@ -17,18 +17,8 @@ function vec2mat(x)
vec2mat.(x)
end

function vec2mat(x::Matrix)
throw(ErrorException("vec2mating a matrix of size $(size(x))"))
end

unzip(X) = map(x -> getfield.(X, x), fieldnames(eltype(X)))

function Base.repeat(system, N)
[deepcopy(system) for i = 1:N]
end

function Base.repeat(system, M, N)
reshape(repeat(system, M * N), M, N)
function repeat_copy(scheme, N)
[deepcopy(scheme) for _ = 1:N]
end

function filterZeros!(x::Matrix{T}) where {T<:Complex}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ Online adaptive phase estimation in the MZI.
- `target`: Setting the target function for calculating the tunable phase. Options are: "sharpness" and "MI".
- `output`: Choose the output variables. Options are: "phi" and "dphi".
"""
function online(apt::Adapt_MZI; target::Symbol = :sharpness, output::String = "phi", res=nothing)
function online(apt::Adapt_MZI; target::String = "sharpness", output::String = "phi", res=nothing)
(; x, p, rho0) = apt
adaptMZI_online(x, p, rho0, Symbol(output), target; res=res)
adaptMZI_online(x, p, rho0, Symbol(output), Symbol(target); res=res)
end

function adaptMZI_online(x, p, rho0, output, target::Symbol; res=nothing)
function adaptMZI_online(x, p, rho0, output, target; res=nothing)
N = Int(sqrt(size(rho0, 1))) - 1
a = destroy(N + 1) |> sparse
exp_ix = [exp(1.0im * xi) for xi in x]
Expand Down Expand Up @@ -120,9 +120,6 @@ function adaptMZI_online(x, p, rho0, output, target::Symbol; res=nothing)
end
end

adaptMZI_online(x, p, rho0, output::String, target::String) =
adaptMZI_online(x, p, rho0, Symbol(output), Symbol(target))

function calculate_online{sharpness}(x, p, pyx, a_res, a, rho0, N, ei, phi_span, exp_ix)

M_res = zeros(length(phi_span))
Expand Down Expand Up @@ -320,34 +317,6 @@ function DE_deltaphiOpt(
return deltaphi[findmax(p_fit)[2]]
end

DE_deltaphiOpt(
x,
p,
rho0,
comb,
p_num,
ini_population,
c,
cr,
seed::Number,
max_episode,
target::String,
eps,
) = DE_deltaphiOpt(
x,
p,
rho0,
comb,
p_num,
ini_population,
c,
cr,
MersenneTwister(seed),
max_episode,
Symbol(target),
eps,
)

function PSO_deltaphiOpt(
x,
p,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@ function dissipation(Γ::V, γ::Vector{R}, t::Int = 1) where {V<:AbstractVector,
[γ[i] * liouville_dissip(Γ[i]) for i in eachindex(Γ)] |> sum
end

function dissipation(
Γ::V,
γ::Vector{Vector{R}},
t::Int = 1,
) where {V<:AbstractVector,R<:Real}
[γ[i][t] * liouville_dissip(Γ[i]) for i in eachindex(Γ)] |> sum
end

function free_evolution(H0)
-1.0im * liouville_commu(H0)
end

function liouvillian(H::Matrix{T}, decay_opt::AbstractVector, γ, t = 1) where {T<:Complex}
freepart = liouville_commu(H)
dissp = norm(γ) + 1 1 ? freepart |> zero : dissipation(decay_opt, γ, t)
Expand All @@ -33,13 +21,13 @@ function Htot(H0::T, Hc::V, ctrl) where {T<:Matrix{ComplexF64},V<:AbstractVector
[H0 + sum([ctrl[i][t] * Hc[i] for i in eachindex(Hc)]) for t in eachindex(ctrl[1])]
end

function Htot(
H0::T,
Hc::V,
ctrl::Vector{R},
) where {T<:AbstractArray,V<:AbstractVector,R<:Real}
H0 + ([ctrl[i] * Hc[i] for i in eachindex(ctrl)] |> sum)
end
# function Htot(
# H0::T,
# Hc::V,
# ctrl::Vector{R},
# ) where {T<:AbstractArray,V<:AbstractVector,R<:Real}
# H0 + ([ctrl[i] * Hc[i] for i in eachindex(ctrl)] |> sum)
# end


# function Htot(H0::V1, Hc::V2, ctrl) where {V1<:AbstractVector,V2<:AbstractVector}
Expand Down
1 change: 1 addition & 0 deletions test/objective/test_cramer_rao_bound.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ function test_cramer_rao_bound_multi_param()
QFIM_LLD(rho[end], drho[end][1])
QFIM_pure(rho0, [zero(rho0) for _ = 1:2])
QFIM_pure(rho0, zero(rho0))
NHB(rho[end], drho[end], one(zeros(2, 2)))

@test all([tr(pinv(i)) >= 0 for i in Im])
@test all([tr(pinv(f)) >= 0 for f in F])
Expand Down
4 changes: 2 additions & 2 deletions test/optimization/test_state_optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ function test_sopt_qfim(; savefile = false)
end

function test_sopt_cfi(; savefile = false)
(; tspan, psi, H0, dH, decay) = generate_LMG1_dynamics()
(; tspan, psi, H0, dH) = generate_LMG1_dynamics()

dynamics = Lindblad(H0, dH, tspan, decay; dyn_method = :Expm)
dynamics = Lindblad(H0, dH, tspan; dyn_method = :Expm)
scheme = GeneralScheme(; probe = psi, param = dynamics)

obj = CFIM_obj()
Expand Down
7 changes: 5 additions & 2 deletions test/test_adaptive_estimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ function test_adaptive_estimation_MZI()
apt = Adapt_MZI(x, p, rho0)

#================online strategy=========================#
online(apt; target=:sharpness, output="phi", res=zeros(2))
online(apt; target="sharpness", output="phi", res=zeros(2))
online(apt; target="MI", output="phi", res=zeros(2))
online(apt; target="sharpness", output="dphi", res=zeros(2))
online(apt; target="MI", output="dphi", res=zeros(2))

#================offline strategy=========================#
# algorithm: DE
Expand All @@ -23,7 +26,7 @@ function test_adaptive_estimation_MZI()
offline(apt, alg, target = :MI, seed = 1234)

# # algorithm: PSO
PSO(p_num=3, ini_particle=nothing, max_episode=[10,10], c0=1.0, c1=2.0, c2=2.0)
alg = PSO(p_num=3, ini_particle=nothing, max_episode=[10,10], c0=1.0, c1=2.0, c2=2.0)
offline(apt, alg, target=:sharpness, seed=1234)


Expand Down

0 comments on commit 44401a3

Please sign in to comment.