Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement helper function to process artificial star tests #37

Merged
merged 14 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[extensions]
TypedTablesExt = "TypedTables"
DataFramesExt = "DataFrames"

[compat]
DataFrames = "1"
Distributions = "0.25"
Documenter = "1"
DynamicHMC = "3.4.2" # Changed stack_posterior_matrices order https://github.com/tpapp/DynamicHMC.jl/pull/175
Expand All @@ -36,6 +46,7 @@ LogDensityProblems = "1, 2"
LoopVectorization = "0.12"
MCMCChains = "6"
Optim = "1.7" # Inverse Hessian estimate from BFGS
Printf = "<0.0.1, 1"
QuadGK = "2"
Random = "<0.0.1, 1"
Roots = "2"
Expand All @@ -45,9 +56,11 @@ StableRNGs = "1"
StaticArrays = "1"
StatsBase = "0.32, 0.33, 0.34"
Test = "<0.0.1, 1"
TypedTables = "1"
julia = "1.7"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Expand All @@ -60,6 +73,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["Distributions", "Documenter", "DynamicHMC", "InitialMassFunctions", "LinearAlgebra", "MCMCChains", "QuadGK", "Random", "SafeTestsets", "StableRNGs", "StaticArrays", "Test"]
test = ["DataFrames", "Distributions", "Documenter", "DynamicHMC", "InitialMassFunctions", "LinearAlgebra", "MCMCChains", "QuadGK", "Random", "SafeTestsets", "StableRNGs", "StaticArrays", "Test", "TypedTables"]
1 change: 1 addition & 0 deletions docs/src/helpers.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ StarFormationHistories.mdf_amr
```@docs
StarFormationHistories.Martin2016_complete
StarFormationHistories.exp_photerr
StarFormationHistories.process_ASTs
```
59 changes: 59 additions & 0 deletions ext/DataFramesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module DataFramesExt

import StarFormationHistories: process_ASTs
using Printf: @sprintf
using DataFrames: DataFrame
using StatsBase: median, mean

# Method of `StarFormationHistories.process_ASTs` for `DataFrames.DataFrame` tables.
#
# Arguments
#   ASTs       : table of artificial stars, one row per star.
#   inmag      : name of the column holding the input magnitudes.
#   outmag     : name of the column holding the measured (output) magnitudes.
#   bins       : bin edges; at least two are required. Unsorted edges are used in
#                sorted order; the caller's vector is left untouched.
#   selectfunc : called on each row of a bin; must return `true` for stars that
#                count as successfully measured.
#   statistic  : reduction applied to `out .- in` (bias) and `abs.(out .- in)` (error).
#
# Returns `(bin_centers, completeness, bias, error)`, each a `Vector{Float64}` of
# length `length(bins)-1`.
function process_ASTs(ASTs::DataFrame, inmag::Symbol, outmag::Symbol,
                      bins::AbstractVector{<:Real}, selectfunc;
                      statistic=median)
    @assert length(bins) > 1
    # Sort out-of-place: `sort!` would silently reorder the caller's vector and
    # fails outright for immutable inputs such as a descending range. Binding to a
    # new name (rather than reassigning `bins`) keeps the threaded closure below
    # from capturing a reassigned variable.
    edges = issorted(bins) ? bins : sort(bins)

    completeness = Vector{Float64}(undef, length(edges)-1)
    bias = similar(completeness)
    error = similar(completeness)
    bin_centers = similar(completeness)

    input_mags = getproperty(ASTs, inmag)

    # Every iteration writes only to index `i` of the output vectors, so the
    # iterations are independent of one another.
    Threads.@threads for i in eachindex(completeness)
        lo, hi = edges[i], edges[i+1]
        # Geometric centre of the bin; used whenever no selected star is available
        # to define a mean input magnitude.
        midpoint = lo + (hi - lo)/2
        # Get the stars in the current bin (half-open interval: lo <= mag < hi)
        inbin = findall((input_mags .>= lo) .& (input_mags .< hi))
        tmp_asts = ASTs[inbin,:]
        nstars = size(tmp_asts,1)
        if nstars == 0
            @warn(@sprintf("No input magnitudes found in bin ranging from %.6f => %.6f \
                            in `ASTs.inmag`, please revise `bins` argument.", lo, hi))
            completeness[i] = NaN
            bias[i] = NaN
            error[i] = NaN
            bin_centers[i] = midpoint
            continue
        end
        # Let selectfunc determine which ASTs are properly detected
        good = [selectfunc(row) for row in eachrow(tmp_asts)]
        ngood = count(good)
        completeness[i] = ngood / nstars
        if ngood > 0
            inmags = getproperty(tmp_asts, inmag)[good]
            outmags = getproperty(tmp_asts, outmag)[good]
            diff = outmags .- inmags # This makes bias relative to input
            bias[i] = statistic(diff)
            error[i] = statistic(abs.(diff))
            bin_centers[i] = mean(inmags)
        else
            @warn(@sprintf("Completeness measured to be 0 in bin ranging from \
                            %.6f => %.6f. The error and bias values for this bin \
                            will be returned as NaN.", lo, hi))
            bias[i] = NaN
            error[i] = NaN
            bin_centers[i] = midpoint
        end
    end
    return bin_centers, completeness, bias, error
end


end
59 changes: 59 additions & 0 deletions ext/TypedTablesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module TypedTablesExt

import StarFormationHistories: process_ASTs
using Printf: @sprintf
using TypedTables: Table
using StatsBase: median, mean

# Method of `StarFormationHistories.process_ASTs` for `TypedTables.Table` tables.
#
# Arguments
#   ASTs       : table of artificial stars, one row per star.
#   inmag      : name of the column holding the input magnitudes.
#   outmag     : name of the column holding the measured (output) magnitudes.
#   bins       : bin edges; at least two are required. Unsorted edges are used in
#                sorted order; the caller's vector is left untouched.
#   selectfunc : broadcast over the rows of a bin; must return `true` for stars
#                that count as successfully measured.
#   statistic  : reduction applied to `out .- in` (bias) and `abs.(out .- in)` (error).
#
# Returns `(bin_centers, completeness, bias, error)`, each a `Vector{Float64}` of
# length `length(bins)-1`.
function process_ASTs(ASTs::Table, inmag::Symbol, outmag::Symbol,
                      bins::AbstractVector{<:Real}, selectfunc;
                      statistic=median)
    @assert length(bins) > 1
    # Sort out-of-place: `sort!` would silently reorder the caller's vector and
    # fails outright for immutable inputs such as a descending range. Binding to a
    # new name (rather than reassigning `bins`) keeps the threaded closure below
    # from capturing a reassigned variable.
    edges = issorted(bins) ? bins : sort(bins)

    completeness = Vector{Float64}(undef, length(edges)-1)
    bias = similar(completeness)
    error = similar(completeness)
    bin_centers = similar(completeness)

    input_mags = getproperty(ASTs, inmag)

    # Every iteration writes only to index `i` of the output vectors, so the
    # iterations are independent of one another.
    Threads.@threads for i in eachindex(completeness)
        lo, hi = edges[i], edges[i+1]
        # Geometric centre of the bin; used whenever no selected star is available
        # to define a mean input magnitude.
        midpoint = lo + (hi - lo)/2
        # Get the stars in the current bin (half-open interval: lo <= mag < hi)
        inbin = findall((input_mags .>= lo) .& (input_mags .< hi))
        tmp_asts = ASTs[inbin]
        nstars = length(tmp_asts)
        if nstars == 0
            @warn(@sprintf("No input magnitudes found in bin ranging from %.6f => %.6f \
                            in `ASTs.inmag`, please revise `bins` argument.", lo, hi))
            completeness[i] = NaN
            bias[i] = NaN
            error[i] = NaN
            bin_centers[i] = midpoint
            continue
        end
        # Let selectfunc determine which ASTs are properly detected
        good = selectfunc.(tmp_asts)
        ngood = count(good)
        completeness[i] = ngood / nstars
        if ngood > 0
            inmags = getproperty(tmp_asts, inmag)[good]
            outmags = getproperty(tmp_asts, outmag)[good]
            diff = outmags .- inmags # This makes bias relative to input
            bias[i] = statistic(diff)
            error[i] = statistic(abs.(diff))
            bin_centers[i] = mean(inmags)
        else
            @warn(@sprintf("Completeness measured to be 0 in bin ranging from \
                            %.6f => %.6f. The error and bias values for this bin \
                            will be returned as NaN.", lo, hi))
            bias[i] = NaN
            error[i] = NaN
            bin_centers[i] = midpoint
        end
    end
    return bin_centers, completeness, bias, error
end


end
25 changes: 14 additions & 11 deletions src/StarFormationHistories.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
module StarFormationHistories

import Distributions: Distribution, Sampleable, Univariate, Continuous, pdf, logpdf,
quantile, Multivariate, MvNormal, _rand!, sampler, Uniform # cdf
using Distributions: Distribution, Sampleable, Univariate, Continuous, pdf, logpdf,
quantile, Multivariate, MvNormal, sampler, Uniform # cdf
import Distributions: _rand! # Extending
import DynamicHMC # For random uncertainties in SFH fits
import Interpolations: interpolate, Gridded, Linear, deduplicate_knots! # extrapolate, Throw
using Interpolations: interpolate, Gridded, Linear, deduplicate_knots! # extrapolate, Throw
import LBFGSB # Used for one method in fitting.jl
import LineSearches # For configuration of Optim.jl
# Need mul! for composite!, ∇loglikelihood!;
import LinearAlgebra: diag, Hermitian, mul!
using LinearAlgebra: diag, Hermitian, mul!
import LogDensityProblems # For interfacing with DynamicHMC
import LoopVectorization: @turbo
using LoopVectorization: @turbo
import Optim
import QuadGK: quadgk # For general mean(imf::UnivariateDistribution{Continuous}; kws...)
import Random: AbstractRNG, default_rng, rand
import Roots: find_zero # For mass_limits in simulate.jl
import SpecialFunctions: erf
import StaticArrays: SVector, SMatrix, sacollect
import StatsBase: fit, Histogram, Weights, sample, mean
using Printf: @sprintf
using QuadGK: quadgk # For general mean(imf::UnivariateDistribution{Continuous}; kws...)
using Random: AbstractRNG, default_rng, rand
using Roots: find_zero # For mass_limits in simulate.jl
using SpecialFunctions: erf
using StaticArrays: SVector, SMatrix, sacollect
using StatsBase: fit, Histogram, Weights, sample, median
import StatsBase: mean # Extending
import KissMCMC
import MCMCChains

Expand Down
46 changes: 46 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,52 @@ Reported values for some HST data were `a=1.05, b=10.0, c=32.0, d=0.01`.
"""
exp_photerr(m, a, b, c, d) = a^(b * (m-c)) + d

"""
    process_ASTs(ASTs::Union{DataFrames.DataFrame,
                             TypedTables.Table},
                 inmag::Symbol,
                 outmag::Symbol,
                 bins::AbstractVector{<:Real},
                 selectfunc;
                 statistic=StatsBase.median)

Process a table of artificial stars to calculate photometric completeness, bias, and error across the provided `bins`. This function has no default implementation; its methods are defined in package extensions that are loaded only once either `DataFrames.jl` or `TypedTables.jl` has been loaded into your Julia session. Using it therefore requires Julia 1.9 or greater.

# Arguments
 - `ASTs` is the table of artificial stars to be analyzed.
 - `inmag` is the column name in symbol format (e.g., `:F606Wi`) that corresponds to the intrinsic (input) magnitudes of the artificial stars.
 - `outmag` is the column name in symbol format (e.g., `:F606Wo`) that corresponds to the measured (output) magnitudes of the artificial stars.
 - `bins` gives the bin edges to be used when computing the binned statistics. At least two edges are required. If the edges are not sorted they are used in sorted order. Bin `i` contains the stars whose input magnitude `m` satisfies `bins[i] <= m < bins[i+1]`, so a star lying exactly on the last edge is not assigned to any bin.
 - `selectfunc` is a function that takes a single row from `ASTs`, corresponding to a single artificial star, and returns a boolean that is `true` if the star is considered successfully measured.

# Keyword Arguments
 - `statistic` is the function that will be used to determine the bias and error, i.e., `bias = statistic(out .- in)` and `error = statistic(abs.(out .- in))`. By default we use `StatsBase.median`, but you could instead use a simple or sigma-clipped mean if so desired.

# Returns
This function returns a `result` of type `NTuple{4,Vector{Float64}}`. Each vector is of length `length(bins)-1`. `result` contains the following elements, each of which is computed over the provided `bins`; the bin centers, bias, and error consider only artificial stars for which `selectfunc` returned `true`:
 - `result[1]` contains the mean input magnitude of the selected stars in each bin. For a bin with no selected stars the midpoint of the bin edges is returned instead.
 - `result[2]` contains the completeness value measured for each bin, defined as the fraction of input stars in each bin for which `selectfunc` returned `true`. A bin containing no input stars has a completeness of `NaN`.
 - `result[3]` contains the photometric bias measured for each bin, defined as `statistic(out .- in)`, where `out` are the measured (output) magnitudes and `in` are the intrinsic (input) magnitudes. Bins with no selected stars have a bias of `NaN`.
 - `result[4]` contains the photometric error measured for each bin, defined as `statistic(abs.(out .- in))`, with `out` and `in` defined as above. Bins with no selected stars have an error of `NaN`.

A warning is logged for every bin that contains no input stars and for every bin whose completeness is zero.

# Examples
Let
 - `F606Wi` be a vector containing the input magnitudes of your artificial stars
 - `F606Wo` be a vector containing the measured magnitudes of the artificial stars, where a value of 99.999 indicates a non-detection.
 - `flag` be a vector of booleans that indicates whether the artificial star passed additional quality cuts (star-galaxy separation, etc.)

You could call this method as

```julia
import TypedTables: Table
process_ASTs(Table(input=F606Wi, output=F606Wo, good=flag),
             :input, :output, minimum(F606Wi):0.1:maximum(F606Wi),
             x -> (x.good==true) & (x.output != 99.999))
```

See also the tests in `test/utilities/process_ASTs_test.jl`.
"""
function process_ASTs end

# Numerical utilities

# function estimate_mode(data::AbstractVector{<:Real})
Expand Down
4 changes: 2 additions & 2 deletions test/fitting/basic_linear_combinations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Distributions: Poisson
import StableRNGs: StableRNG
using Distributions: Poisson
using StableRNGs: StableRNG
import StarFormationHistories as SFH
using Test

Expand Down
8 changes: 4 additions & 4 deletions test/fitting/linear_amr_test.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import Distributions: Poisson
using Distributions: Poisson
import DynamicHMC
import StableRNGs: StableRNG
using StableRNGs: StableRNG
import StarFormationHistories as SFH

import LinearAlgebra: Diagonal
import Random: rand!
using LinearAlgebra: Diagonal
using Random: rand!
using Test


Expand Down
6 changes: 3 additions & 3 deletions test/fitting/log_amr_test.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import Distributions: Poisson
using Distributions: Poisson
# import ForwardDiff
import StarFormationHistories as SFH
import StableRNGs: StableRNG
import Random: rand!
using StableRNGs: StableRNG
using Random: rand!
using Test

# Now try fixed_log_amr that uses an AMR that is logarithmic in [M/H]
Expand Down
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import StarFormationHistories as SFH
import InitialMassFunctions: Salpeter1955, Kroupa2001
import Distributions: Poisson, Uniform, pdf, median
using InitialMassFunctions: Salpeter1955, Kroupa2001
using Distributions: Poisson, Uniform, pdf, median
import Random
import StableRNGs: StableRNG
import StaticArrays: SVector
import QuadGK: quadgk
using StableRNGs: StableRNG
using StaticArrays: SVector
using QuadGK: quadgk
import MCMCChains
import DynamicHMC
# import Optim
Expand Down Expand Up @@ -681,6 +681,11 @@ const rtols = (1e-3, 1e-7) # Relative tolerance levels to use for the above floa


@testset "utilities" begin
# Uses extensions, requires Julia >= 1.9
if VERSION >= v"1.9"
@safetestset "process_ASTs" include("utilities/process_ASTs_test.jl")
end

for i in eachindex(float_types, float_type_labels)
label = float_type_labels[i]
@testset "$label" begin
Expand Down
Loading