Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Canned median quantile #435

Merged
merged 12 commits into from
Feb 8, 2025
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.7
Version: 0.1.8
Authors@R: c(
person("Daniel J.", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat
- Shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag`
- Quantiles produced by `grf` were sometimes out of order.
- dist_quantiles can have all `NA` values without causing unrelated errors
- adjust default quantiles throughout so that they match.
- force `layer_residual_quantiles()` to always include `0.5`.

# epipredict 0.1

Expand Down
11 changes: 6 additions & 5 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ arx_fcast_epi_workflow <- function(
} else {
quantile_levels <- sort(compare_quantile_args(
args_list$quantile_levels,
rlang::eval_tidy(trainer$eng_args$quantiles) %||% c(.1, .5, .9),
rlang::eval_tidy(trainer$eng_args$quantiles) %||%
c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
"grf"
))
trainer$eng_args$quantiles <- rlang::enquo(quantile_levels)
Expand Down Expand Up @@ -253,8 +254,8 @@ arx_fcast_epi_workflow <- function(
#' the last day of data. For example, if the last day of data was 3 days ago,
#' the ahead becomes `ahead+3`.
#' - `"extend_lags"`: increase the lags so they're relative to the actual
#' forecast date. For example, if the lags are `c(0,7,14)` and the last day of
#' data was 3 days ago, the lags become `c(3,10,17)`.
#' forecast date. For example, if the lags are `c(0, 7, 14)` and the last day of
#' data was 3 days ago, the lags become `c(3, 10, 17)`.
#' @param warn_latency by default, `step_adjust_latency` warns the user if the
#' latency is large. If this is `FALSE`, that warning is turned off.
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
Expand Down Expand Up @@ -295,7 +296,7 @@ arx_args_list <- function(
target_date = NULL,
adjust_latency = c("none", "extend_ahead", "extend_lags", "locf"),
warn_latency = TRUE,
quantile_levels = c(0.05, 0.95),
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
Expand Down Expand Up @@ -362,7 +363,7 @@ compare_quantile_args <- function(alist, tlist, train_method = c("qr", "grf")) {
default_alist <- eval(formals(arx_args_list)$quantile_levels)
default_tlist <- switch(train_method,
"qr" = eval(formals(quantile_reg)$quantile_levels),
"grf" = c(.1, .5, .9)
"grf" = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95)
)
if (setequal(alist, default_alist)) {
if (setequal(tlist, default_tlist)) {
Expand Down
23 changes: 11 additions & 12 deletions R/autoplot.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ ggplot2::autoplot
#' step_epi_naomit()
#'
#' f <- frosting() %>%
#' layer_residual_quantiles(
#' quantile_levels = c(.025, .1, .25, .75, .9, .975)
#' ) %>%
#' layer_residual_quantiles() %>%
#' layer_threshold(starts_with(".pred")) %>%
#' layer_add_target_date()
#'
Expand Down Expand Up @@ -85,7 +83,7 @@ NULL
#' @rdname autoplot-epipred
autoplot.epi_workflow <- function(
object, predictions = NULL,
.levels = c(.5, .8, .95), ...,
.levels = c(.5, .8, .9), ...,
.color_by = c("all_keys", "geo_value", "other_keys", ".response", "all", "none"),
.facet_by = c(".response", "other_keys", "all_keys", "geo_value", "all", "none"),
.base_color = "dodgerblue4",
Expand Down Expand Up @@ -183,7 +181,7 @@ autoplot.epi_workflow <- function(
}

if (".pred" %in% names(predictions)) {
ntarget_dates <- n_distinct(predictions$time_value)
ntarget_dates <- dplyr::n_distinct(predictions$time_value)
if (ntarget_dates > 1L) {
bp <- bp +
geom_line(
Expand Down Expand Up @@ -231,24 +229,25 @@ starts_with_impl <- function(x, vars) {

plot_bands <- function(
base_plot, predictions,
levels = c(.5, .8, .95),
levels = c(.5, .8, .9),
fill = "blue4",
alpha = 0.6,
linewidth = 0.05) {
innames <- names(predictions)
n <- length(levels)
alpha <- alpha / (n - 1)
l <- (1 - levels) / 2
l <- c(rev(l), 1 - l)
n_levels <- length(levels)
alpha <- alpha / (n_levels - 1)
# generate the corresponding level that is 1 - level
levels <- (1 - levels) / 2
levels <- c(rev(levels), 1 - levels)

ntarget_dates <- dplyr::n_distinct(predictions$time_value)

predictions <- predictions %>%
mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, l), l)) %>%
mutate(.pred_distn = dist_quantiles(quantile(.pred_distn, levels), levels)) %>%
pivot_quantiles_wider(.pred_distn)
qnames <- setdiff(names(predictions), innames)

for (i in 1:n) {
for (i in 1:n_levels) {
bottom <- qnames[i]
top <- rev(qnames)[i]
if (i == 1) {
Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#' @examples
#' f <- frosting() %>%
#' layer_predict() %>%
#' layer_residual_quantiles(quantile_levels = c(0.0275, 0.975), symmetrize = FALSE) %>%
#' layer_residual_quantiles(symmetrize = FALSE) %>%
#' layer_naomit(.pred)
#'
#' extract_argument(f, "layer_residual_quantiles", "symmetrize")
Expand Down
2 changes: 1 addition & 1 deletion R/flatline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ flatline_args_list <- function(
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
quantile_levels = c(0.05, 0.95),
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = character(0L),
Expand Down
2 changes: 1 addition & 1 deletion R/layer_quantile_distn.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
#' p
layer_quantile_distn <- function(frosting,
...,
quantile_levels = c(.25, .75),
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
truncate = c(-Inf, Inf),
name = ".pred_distn",
id = rand_id("quantile_distn")) {
Expand Down
8 changes: 5 additions & 3 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#' @param frosting a `frosting` postprocessor
#' @param ... Unused, include for consistency with other layers.
#' @param quantile_levels numeric vector of probabilities with values in (0,1)
#' referring to the desired quantile.
#' referring to the desired quantile. Note that 0.5 will always be included
#' even if left out by the user.
#' @param symmetrize logical. If `TRUE` then interval will be symmetric.
#' @param by_key A character vector of keys to group the residuals by before
#' calculating quantiles. The default, `c()` performs no grouping.
Expand All @@ -28,7 +29,7 @@
#' f <- frosting() %>%
#' layer_predict() %>%
#' layer_residual_quantiles(
#' quantile_levels = c(0.0275, 0.975),
#' quantile_levels = c(0.025, 0.975),
#' symmetrize = FALSE
#' ) %>%
#' layer_naomit(.pred)
Expand All @@ -48,7 +49,7 @@
#' p2 <- forecast(wf2)
layer_residual_quantiles <- function(
frosting, ...,
quantile_levels = c(0.05, 0.95),
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
symmetrize = TRUE,
by_key = character(0L),
name = ".pred_distn",
Expand All @@ -59,6 +60,7 @@ layer_residual_quantiles <- function(
arg_is_chr(by_key, allow_empty = TRUE)
arg_is_probabilities(quantile_levels)
arg_is_lgl(symmetrize)
quantile_levels <- sort(unique(c(0.5, quantile_levels)))
add_layer(
frosting,
layer_residual_quantiles_new(
Expand Down
2 changes: 1 addition & 1 deletion R/make_grf_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ make_grf_quantiles <- function() {
data = c(x = "X", y = "Y"),
func = c(pkg = "grf", fun = "quantile_forest"),
defaults = list(
quantiles = c(0.1, 0.5, 0.9),
quantiles = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
num.threads = 1L,
seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max))
)
Expand Down
6 changes: 4 additions & 2 deletions R/make_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' @param engine Character string naming the fitting function. Currently, only
#' "rq" and "grf" are supported.
#' @param quantile_levels A scalar or vector of values in (0, 1) to determine which
#' quantiles to estimate (default is 0.5).
#' quantiles to estimate (default is the set 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95).
#' @param method A fitting method used by [quantreg::rq()]. See the
#' documentation for a list of options.
#'
Expand All @@ -27,7 +27,9 @@
#' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq")
#' ff <- rq_spec %>% fit(y ~ ., data = tib)
#' predict(ff, new_data = tib)
quantile_reg <- function(mode = "regression", engine = "rq", quantile_levels = 0.5, method = "br") {
quantile_reg <- function(mode = "regression", engine = "rq",
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
method = "br") {
# Check for correct mode
if (mode != "regression") {
cli_abort("`mode` must be 'regression'")
Expand Down
9 changes: 2 additions & 7 deletions R/make_smooth_quantile_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@
#' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the
#' only supported engine is [smoothqr::smooth_qr()].
#'
#' @param mode A single character string for the type of model.
#' The only possible value for this model is "regression".
#' @param engine Character string naming the fitting function. Currently, only
#' "smooth_qr" is supported.
#' @param quantile_levels A scalar or vector of values in (0, 1) to determine which
#' quantiles to estimate (default is 0.5).
#' @inheritParams quantile_reg
#' @param outcome_locations Defaults to the vector `1:ncol(y)` but if the
#' responses are observed at a different spacing (or appear in a different
#' order), that information should be used here. This
Expand Down Expand Up @@ -76,7 +71,7 @@ smooth_quantile_reg <- function(
mode = "regression",
engine = "smoothqr",
outcome_locations = NULL,
quantile_levels = 0.5,
quantile_levels = c(0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95),
degree = 3L) {
# Check for correct mode
if (mode != "regression") cli_abort("`mode` must be 'regression'")
Expand Down
6 changes: 3 additions & 3 deletions man/arx_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/arx_class_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 2 additions & 4 deletions man/autoplot-epipred.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/extract_argument.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/flatline_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/grf_quantiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/layer_quantile_distn.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions man/layer_residual_quantiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/quantile_reg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/smooth_quantile_reg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading