Skip to content

Commit

Permalink
attempt at fixing indexing errors
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Jan 18, 2025
1 parent 1a9fa4f commit c195317
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 35 deletions.
2 changes: 1 addition & 1 deletion src/RimuIO/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ function load_state(::Type{D}, filename; style=nothing, kwargs...) where {D}
end

function load_state(filename; kwargs...)
if Threads.nthreads() == 1
if Threads.nthreads() == 1 && mpi_size() == 1
return load_state(DVec, filename; kwargs...)
else
return load_state(PDVec, filename; kwargs...)
Expand Down
18 changes: 13 additions & 5 deletions src/RimuIO/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@ table from Tables.jl. Constructed with `Tables.table(::DVec)`.
struct DVecAsTable{K,V}
dict::Dict{K,V}
end
function Base.iterate(tbl::DVecAsTable, st=0)
itr = iterate(tbl.dict, st)
function Base.iterate(tbl::DVecAsTable, st=nothing)
if isnothing(st)
itr = iterate(tbl.dict)
else
itr = iterate(tbl.dict, st)
end
if !isnothing(itr)
pair, st = itr
return (; key=pair[1], value=pair[2]), st
Expand Down Expand Up @@ -38,17 +42,21 @@ table from Tables.jl. Constructed with `Tables.table(::PDVec)`.
struct PDVecAsTable{K,V,N}
segments::NTuple{N,Dict{K,V}}
end
function Base.iterate(tbl::PDVecAsTable, (st,i)=(0, 1))
function Base.iterate(tbl::PDVecAsTable, (st,i)=(nothing, 1))
if i > length(tbl.segments)
return nothing
end

itr = iterate(tbl.segments[i], st)
if isnothing(st)
itr = iterate(tbl.segments[i])
else
itr = iterate(tbl.segments[i], st)
end
if !isnothing(itr)
pair, st = itr
return (; key=pair[1], value=pair[2]), (st, i)
else
return iterate(tbl, (0, i+1))
return iterate(tbl, (nothing, i+1))
end
end

Expand Down
46 changes: 30 additions & 16 deletions test/RimuIO.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,41 @@ end
end

@testset "save_state, load_state" begin
file = joinpath(tmpdir, "tmp.arrow")
file = joinpath(tmpdir, "tmp-dvec.arrow")
rm(file; force=true)

@testset "vectors" begin
ham = HubbardReal1D(BoseFS(1,1,1))
dvec = ham * DVec([BoseFS(1,1,1) => 1.0, BoseFS(2,1,0) => π])
save_state(file, dvec)
output, _ = load_state(file)
@test output == dvec
rm(file)
@testset "save DVec" begin
ham = HubbardReal1D(BoseFS(1,1,1))
dvec = ham * DVec([BoseFS(1,1,1) => 1.0, BoseFS(2,1,0) => π])
save_state(file, dvec)
output, _ = load_state(file)
@test output == dvec
rm(file)
end

pdvec = ham * PDVec([BoseFS(1,1,1) => 1.0, BoseFS(0,3,0) => ℯ])
save_state(file, pdvec)
output, _ = load_state(file)
@test output == pdvec
@testset "save PDVec" begin
pdvec = ham * PDVec([BoseFS(1,1,1) => 1.0, BoseFS(0,3,0) => ℯ])
save_state(file, pdvec)
output, _ = load_state(file)
@test output == pdvec

@test load_state(PDVec, file)[1] isa PDVec
@test load_state(PDVec, file)[1] == pdvec
@test load_state(DVec, file)[1] isa DVec
@test load_state(DVec, file)[1] == pdvec
rm(file)
@test load_state(PDVec, file)[1] isa PDVec
@test load_state(PDVec, file)[1] == pdvec
@test load_state(DVec, file)[1] isa DVec
@test load_state(DVec, file)[1] == pdvec
rm(file)
end

@testset "save empty vector" begin
dvec = DVec{Int,Int}()
save_state(file, dvec)
@test isempty(load_state(file)[1])
pdvec = PDVec{Int,Int}()
save_state(file, pdvec)
@test isempty(load_state(file)[1])
rm(file)
end
end

@testset "metadata" begin
Expand Down
16 changes: 3 additions & 13 deletions test/mpi_runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,23 +254,13 @@ end
@mpi_root rm(file; force=true)

@testset "vectors" begin
ham = HubbardReal1D(BoseFS(1,1,1))
dvec = ham * DVec([BoseFS(1,1,1) => 1.0, BoseFS(2,1,0) => π])
save_state(file, dvec)
output, _ = load_state(file)
@test output == dvec

@mpi_root rm(file)
ham = HubbardReal1D(BoseFS(1,1,1,1,1))

pdvec = ham * PDVec([BoseFS(1,1,1) => 1.0, BoseFS(0,3,0) => ℯ])
pdvec = ham * PDVec([BoseFS(1,1,1,1,1) => 1.0, BoseFS(0,3,0,1,1) => ℯ])
save_state(file, pdvec)
output, _ = load_state(file)
@test output == pdvec

@test load_state(PDVec, file)[1] isa PDVec
@test load_state(PDVec, file)[1] == pdvec
@test load_state(DVec, file)[1] isa DVec
@test load_state(DVec, file)[1] == pdvec
@test output isa PDVec

@mpi_root rm(file)
end
Expand Down

0 comments on commit c195317

Please sign in to comment.