Skip to content

Commit

Permalink
Merge pull request #37 from cgarling/ast_util_ext
Browse files Browse the repository at this point in the history
Implement helper function `process_ASTs` to process artificial star tests. I wanted to support `DataFrames.DataFrame` and `TypedTables.Table` but couldn't figure out a single implementation that would work on both, so I wrote an implementation for both types in separate package extensions. I don't really like this solution as there is a lot of code duplication but I cannot currently figure out how to iterate a `DataFrame` and a `TypedTable` in the same way. 

Due to this design, there is *no default implementation* in the base package. One of the above packages must be loaded before `process_ASTs` can be called. 

This solution uses package extensions which are only supported on Julia versions 1.9 or higher. I am not backporting these to earlier versions, though it could be done with https://github.com/cjdoris/PackageExtensionCompat.jl. 

Ideally in the long term we could transition this to a single implementation that would work with both types, without having to depend directly on either in the base package.
  • Loading branch information
cgarling authored Jun 14, 2024
2 parents a93aecf + 2dd8f45 commit 7f89d8a
Show file tree
Hide file tree
Showing 11 changed files with 339 additions and 26 deletions.
16 changes: 15 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,24 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[extensions]
TypedTablesExt = "TypedTables"
DataFramesExt = "DataFrames"

[compat]
DataFrames = "1"
Distributions = "0.25"
Documenter = "1"
DynamicHMC = "3.4.2" # Changed stack_posterior_matrices order https://github.com/tpapp/DynamicHMC.jl/pull/175
Expand All @@ -36,6 +46,7 @@ LogDensityProblems = "1, 2"
LoopVectorization = "0.12"
MCMCChains = "6"
Optim = "1.7" # Inverse Hessian estimate from BFGS
Printf = "<0.0.1, 1"
QuadGK = "2"
Random = "<0.0.1, 1"
Roots = "2"
Expand All @@ -45,9 +56,11 @@ StableRNGs = "1"
StaticArrays = "1"
StatsBase = "0.32, 0.33, 0.34"
Test = "<0.0.1, 1"
TypedTables = "1"
julia = "1.7"

[extras]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Expand All @@ -60,6 +73,7 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"

[targets]
test = ["Distributions", "Documenter", "DynamicHMC", "InitialMassFunctions", "LinearAlgebra", "MCMCChains", "QuadGK", "Random", "SafeTestsets", "StableRNGs", "StaticArrays", "Test"]
test = ["DataFrames", "Distributions", "Documenter", "DynamicHMC", "InitialMassFunctions", "LinearAlgebra", "MCMCChains", "QuadGK", "Random", "SafeTestsets", "StableRNGs", "StaticArrays", "Test", "TypedTables"]
1 change: 1 addition & 0 deletions docs/src/helpers.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ StarFormationHistories.mdf_amr
```@docs
StarFormationHistories.Martin2016_complete
StarFormationHistories.exp_photerr
StarFormationHistories.process_ASTs
```
59 changes: 59 additions & 0 deletions ext/DataFramesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module DataFramesExt

import StarFormationHistories: process_ASTs
using Printf: @sprintf
using DataFrames: DataFrame
using StatsBase: median, mean

function process_ASTs(ASTs::DataFrame, inmag::Symbol, outmag::Symbol,
bins::AbstractVector{<:Real}, selectfunc;
statistic=median)
@assert length(bins) > 1
!issorted(bins) && sort!(bins)

completeness = Vector{Float64}(undef, length(bins)-1)
bias = similar(completeness)
error = similar(completeness)
bin_centers = similar(completeness)

input_mags = getproperty(ASTs, inmag)

Threads.@threads for i in eachindex(completeness)
# Get the stars in the current bin
inbin = findall((input_mags .>= bins[i]) .&
(input_mags .< bins[i+1]))
tmp_asts = ASTs[inbin,:]
if size(tmp_asts,1) == 0
@warn(@sprintf("No input magnitudes found in bin ranging from %.6f => %.6f \
in `ASTs.inmag`, please revise `bins` argument.", bins[i],
bins[i+1]))
completeness[i] = NaN
bias[i] = NaN
error[i] = NaN
bin_centers[i] = bins[i] + (bins[i+1] - bins[i])/2
continue
end
# Let selectfunc determine which ASTs are properly detected
good = [selectfunc(row) for row in eachrow(tmp_asts)]
completeness[i] = count(good) / size(tmp_asts,1)
if count(good) > 0
inmags = getproperty(tmp_asts, inmag)[good]
outmags = getproperty(tmp_asts, outmag)[good]
diff = outmags .- inmags # This makes bias relative to input
bias[i] = statistic(diff)
error[i] = statistic(abs.(diff))
bin_centers[i] = mean(inmags)
else
@warn(@sprintf("Completeness measured to be 0 in bin ranging from \
%.6f => %.6f. The error and bias values for this bin \
will be returned as NaN.", bins[i], bins[i+1]))
bias[i] = NaN
error[i] = NaN
bin_centers[i] = bins[i] + (bins[i+1] - bins[i])/2
end
end
return bin_centers, completeness, bias, error
end


end
59 changes: 59 additions & 0 deletions ext/TypedTablesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
module TypedTablesExt

import StarFormationHistories: process_ASTs
using Printf: @sprintf
using TypedTables: Table
using StatsBase: median, mean

function process_ASTs(ASTs::Table, inmag::Symbol, outmag::Symbol,
bins::AbstractVector{<:Real}, selectfunc;
statistic=median)
@assert length(bins) > 1
!issorted(bins) && sort!(bins)

completeness = Vector{Float64}(undef, length(bins)-1)
bias = similar(completeness)
error = similar(completeness)
bin_centers = similar(completeness)

input_mags = getproperty(ASTs, inmag)

Threads.@threads for i in eachindex(completeness)
# Get the stars in the current bin
inbin = findall((input_mags .>= bins[i]) .&
(input_mags .< bins[i+1]))
tmp_asts = ASTs[inbin]
if length(tmp_asts) == 0
@warn(@sprintf("No input magnitudes found in bin ranging from %.6f => %.6f \
in `ASTs.inmag`, please revise `bins` argument.", bins[i],
bins[i+1]))
completeness[i] = NaN
bias[i] = NaN
error[i] = NaN
bin_centers[i] = bins[i] + (bins[i+1] - bins[i])/2
continue
end
# Let selectfunc determine which ASTs are properly detected
good = selectfunc.(tmp_asts)
completeness[i] = count(good) / length(tmp_asts)
if count(good) > 0
inmags = getproperty(tmp_asts, inmag)[good]
outmags = getproperty(tmp_asts, outmag)[good]
diff = outmags .- inmags # This makes bias relative to input
bias[i] = statistic(diff)
error[i] = statistic(abs.(diff))
bin_centers[i] = mean(inmags)
else
@warn(@sprintf("Completeness measured to be 0 in bin ranging from \
%.6f => %.6f. The error and bias values for this bin \
will be returned as NaN.", bins[i], bins[i+1]))
bias[i] = NaN
error[i] = NaN
bin_centers[i] = bins[i] + (bins[i+1] - bins[i])/2
end
end
return bin_centers, completeness, bias, error
end


end
25 changes: 14 additions & 11 deletions src/StarFormationHistories.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
module StarFormationHistories

import Distributions: Distribution, Sampleable, Univariate, Continuous, pdf, logpdf,
quantile, Multivariate, MvNormal, _rand!, sampler, Uniform # cdf
using Distributions: Distribution, Sampleable, Univariate, Continuous, pdf, logpdf,
quantile, Multivariate, MvNormal, sampler, Uniform # cdf
import Distributions: _rand! # Extending
import DynamicHMC # For random uncertainties in SFH fits
import Interpolations: interpolate, Gridded, Linear, deduplicate_knots! # extrapolate, Throw
using Interpolations: interpolate, Gridded, Linear, deduplicate_knots! # extrapolate, Throw
import LBFGSB # Used for one method in fitting.jl
import LineSearches # For configuration of Optim.jl
# Need mul! for composite!, ∇loglikelihood!;
import LinearAlgebra: diag, Hermitian, mul!
using LinearAlgebra: diag, Hermitian, mul!
import LogDensityProblems # For interfacing with DynamicHMC
import LoopVectorization: @turbo
using LoopVectorization: @turbo
import Optim
import QuadGK: quadgk # For general mean(imf::UnivariateDistribution{Continuous}; kws...)
import Random: AbstractRNG, default_rng, rand
import Roots: find_zero # For mass_limits in simulate.jl
import SpecialFunctions: erf
import StaticArrays: SVector, SMatrix, sacollect
import StatsBase: fit, Histogram, Weights, sample, mean
using Printf: @sprintf
using QuadGK: quadgk # For general mean(imf::UnivariateDistribution{Continuous}; kws...)
using Random: AbstractRNG, default_rng, rand
using Roots: find_zero # For mass_limits in simulate.jl
using SpecialFunctions: erf
using StaticArrays: SVector, SMatrix, sacollect
using StatsBase: fit, Histogram, Weights, sample, median
import StatsBase: mean # Extending
import KissMCMC
import MCMCChains

Expand Down
46 changes: 46 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,52 @@ Reported values for some HST data were `a=1.05, b=10.0, c=32.0, d=0.01`.
"""
exp_photerr(m, a, b, c, d) = a^(b * (m-c)) + d

"""
process_ASTs(ASTs::Union{DataFrames.DataFrame,
TypedTables.Table},
inmag::Symbol,
outmag::Symbol,
bins::AbstractVector{<:Real},
selectfunc;
statistic=StatsBase.median)
Processes a table of artificial stars to calculate photometric completeness, bias, and error across the provided `bins`. This method has no default implementation and is implemented in package extensions that rely on either `DataFrames.jl` or `TypedTables.jl` being loaded into your Julia session to load the relevant method. This method therefore requires Julia 1.9 or greater to use.
# Arguments
- `ASTs` is the table of artificial stars to be analyzed.
- `inmag` is the column name in symbol format (e.g., :F606Wi) that corresponds to the intrinsic (input) magnitudes of the artificial stars.
- `outmag` is the column name in symbol format (e.g., :F606Wo) that corresponds to the measured (output) magnitude of the artificial stars.
- `bins` give the bin edges to be used when computing the binned statistics.
- `selectfunc` is a method that takes a single row from `ASTs`, corresponding to a single artificial star, and returns a boolean that is `true` if the star is considered successfully measured.
# Keyword Arguments
- `statistic` is the method that will be used to determine the bias and error, i.e., `bias = statistic(out .- in)` and `error = statistic(abs.(out .- in))`. By default we use `StatsBase.median`, but you could instead use a simple or sigma-clipped mean if so desired.
# Returns
This method returns a `result` of type `NTuple{4,Vector{Float64}}`. Each vector is of length `length(bins)-1`. `result` contains the following elements, each of which are computed over the provided `bins` considering only artificial stars for which `selectfunc` returned `true`:
- `result[1]` contains the mean input magnitude of the stars in each bin.
- `result[2]` contains the completeness value measured for each bin, defined as the fraction of input stars in each bin for which `selectfunc` returned `true`.
- `result[3]` contains the photometric bias measured for each bin, defined as `statistic(out .- in)`, where `out` are the measured (output) magnitudes and `in` are the intrinsic (input) magnitudes.
- `result[4]` contains the photometric error measured for each bin, defined as `statistic(abs.(out .- in))`, with `out` and `in` defined as above.
# Examples
Let
- `F606Wi` be a vector containing the input magnitudes of your artificial stars
- `F606Wo` be a vector containing the measured magnitudes of the artificial stars, where a value of 99.999 indicates a non-detection.
- `flag` be a vector of booleans that indicates whether the artificial star passed additional quality cuts (star-galaxy separation, etc.)
You could call this method as
```julia
import TypedTables: Table
process_ASTs(Table(input=F606Wi, output=F606Wo, good=flag),
:input, :output, minimum(F606Wi):0.1:maximum(F606Wi),
x -> (x.good==true) & (x.output != 99.999))
```
See also the tests in `test/utilities/process_ASTs_test.jl`.
"""
function process_ASTs end

# Numerical utilities

# function estimate_mode(data::AbstractVector{<:Real})
Expand Down
4 changes: 2 additions & 2 deletions test/fitting/basic_linear_combinations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Distributions: Poisson
import StableRNGs: StableRNG
using Distributions: Poisson
using StableRNGs: StableRNG
import StarFormationHistories as SFH
using Test

Expand Down
8 changes: 4 additions & 4 deletions test/fitting/linear_amr_test.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import Distributions: Poisson
using Distributions: Poisson
import DynamicHMC
import StableRNGs: StableRNG
using StableRNGs: StableRNG
import StarFormationHistories as SFH

import LinearAlgebra: Diagonal
import Random: rand!
using LinearAlgebra: Diagonal
using Random: rand!
using Test


Expand Down
6 changes: 3 additions & 3 deletions test/fitting/log_amr_test.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import Distributions: Poisson
using Distributions: Poisson
# import ForwardDiff
import StarFormationHistories as SFH
import StableRNGs: StableRNG
import Random: rand!
using StableRNGs: StableRNG
using Random: rand!
using Test

# Now try fixed_log_amr that uses an AMR that is logarithmic in [M/H]
Expand Down
15 changes: 10 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import StarFormationHistories as SFH
import InitialMassFunctions: Salpeter1955, Kroupa2001
import Distributions: Poisson, Uniform, pdf, median
using InitialMassFunctions: Salpeter1955, Kroupa2001
using Distributions: Poisson, Uniform, pdf, median
import Random
import StableRNGs: StableRNG
import StaticArrays: SVector
import QuadGK: quadgk
using StableRNGs: StableRNG
using StaticArrays: SVector
using QuadGK: quadgk
import MCMCChains
import DynamicHMC
# import Optim
Expand Down Expand Up @@ -681,6 +681,11 @@ const rtols = (1e-3, 1e-7) # Relative tolerance levels to use for the above floa


@testset "utilities" begin
# Uses extensions, requires Julia >= 1.9
if VERSION >= v"1.9"
@safetestset "process_ASTs" include("utilities/process_ASTs_test.jl")
end

for i in eachindex(float_types, float_type_labels)
label = float_type_labels[i]
@testset "$label" begin
Expand Down
Loading

0 comments on commit 7f89d8a

Please sign in to comment.