Skip to content

Commit

Permalink
Allow passing time_step to StatsTools functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mtsch committed Jan 6, 2025
1 parent 5f48b35 commit 9b4dd0f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/StatsTools/growth_witness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function growth_witness(
shift=:shift, norm=:norm, time_step=nothing, kwargs...
)
df = DataFrame(sim)
time_step = determine_constant_time_step(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step

shift_vec = getproperty(df, Symbol(shift))
norm_vec = getproperty(df, Symbol(norm))
Expand Down
16 changes: 10 additions & 6 deletions src/StatsTools/reweighting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Given a `DataFrame` `df`, determine the time step that was used to compute it. Throw an
error if time step is not constant.
"""
function determine_constant_time_step(df)
function determine_constant_time_step(df, default)
# Using get for backwards compatibility with old data frames
if get(metadata(df), "time_step_strategy", "ConstantTimeStep()") == "ConstantTimeStep()"
if haskey(metadata(df), "time_step")
Expand All @@ -20,7 +20,7 @@ function determine_constant_time_step(df)
elseif hasproperty(df, "dτ_1")
return df.dτ_1[end]
else
throw(ArgumentError("Time step not found in `df`"))
throw(ArgumentError("key `\"time_step\"` not found in `df` metadata"))
end
else
throw(ArgumentError("Time step not constant"))
Expand Down Expand Up @@ -243,13 +243,14 @@ function growth_estimator_analysis(
shift_name=:shift,
norm_name=:norm,
warn=true,
time_step=nothing,
kwargs...
)
df = DataFrame(sim)
shift_v = Vector(getproperty(df, Symbol(shift_name))) # casting to `Vector` to make SIMD loops efficient
norm_v = Vector(getproperty(df, Symbol(norm_name)))
num_reps = length(filter(startswith("norm"), names(df)))
time_step = determine_constant_time_step(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
se = blocking_analysis(shift_v; skip)
E_r = se.mean
correlation_estimate = 2^(se.k - 1)
Expand Down Expand Up @@ -412,14 +413,15 @@ function mixed_estimator_analysis(
hproj_name=:hproj,
vproj_name=:vproj,
warn=true,
time_step=nothing,
kwargs...
)
shift_v = Vector(getproperty(df, Symbol(shift_name))) # casting to `Vector` to make SIMD loops efficient
hproj_v = Vector(getproperty(df, Symbol(hproj_name)))
vproj_v = Vector(getproperty(df, Symbol(vproj_name)))
num_reps = length(filter(startswith("norm"), names(df)))

time_step = determine_constant_time_step(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
se = blocking_analysis(shift_v; skip)
E_r = se.mean
correlation_estimate = 2^(se.k - 1)
Expand Down Expand Up @@ -555,11 +557,12 @@ function rayleigh_replica_estimator(
h=0,
skip=0,
Anorm=1,
time_step=nothing,
kwargs...
)
df = DataFrame(sim)
num_reps = length(filter(startswith("norm"), names(df)))
time_step = determine_constant_time_step(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
T = eltype(df[!, Symbol(shift_name, "_1")])
shift_v = Vector{T}[]
for a in 1:num_reps
Expand Down Expand Up @@ -625,11 +628,12 @@ function rayleigh_replica_estimator_analysis(
vec_name="dot",
Anorm=1,
warn=true,
time_step=nothing,
kwargs...
)
df = DataFrame(sim)
num_reps = length(filter(startswith("norm"), names(df)))
time_step = determine_constant_time_step(df)
time_step = isnothing(time_step) ? determine_constant_time_step(df) : time_step
# estimate the correlation time by blocking the shift data
T = eltype(df[!, Symbol(shift_name, "_1")])
shift_v = Vector{T}[]
Expand Down

0 comments on commit 9b4dd0f

Please sign in to comment.