Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better handling for svyglm models and fix for #39 #41

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@
^vignettes/.+\.pdf$
^vignettes/.+\.sty$
^vignettes/.+\.tex$
^.*\.Rproj$
^\.Rproj\.user$
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ Type: Package
Title: Tidy, Type-Safe 'prediction()' Methods
Description: A one-function package containing 'prediction()', a type-safe alternative to 'predict()' that always returns a data frame. The 'summary()' method provides a data frame with average predictions, possibly over counterfactual versions of the data (a la the 'margins' command in 'Stata'). Marginal effect estimation is provided by the related package, 'margins' <https://cran.r-project.org/package=margins>. The package currently supports common model types (e.g., "lm", "glm") from the 'stats' package, as well as numerous other model classes from other add-on packages. See the README or main package documentation page for a complete listing.
License: MIT + file LICENSE
Version: 0.3.14
Date: 2019-06-16
Version: 0.3.15
Date: 2019-08-08
Authors@R: c(person("Thomas J.", "Leeper",
role = c("aut", "cre"),
email = "[email protected]",
Expand All @@ -13,7 +13,10 @@ Authors@R: c(person("Thomas J.", "Leeper",
email = "[email protected]"),
person("Vincent", "Arel-Bundock", role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0003-2042-7063"))
comment = c(ORCID = "0000-0003-2042-7063")),
person("Tomasz", "\u017b\u00F3\u0142tak", role = "ctb",
email = "[email protected]",
comment = c(ORCID = "0000-0003-1354-4472"))
)
URL: https://github.com/leeper/prediction
BugReports: https://github.com/leeper/prediction/issues
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ S3method(find_data,hxlr)
S3method(find_data,lm)
S3method(find_data,mca)
S3method(find_data,merMod)
S3method(find_data,survey.design)
S3method(find_data,svyglm)
S3method(find_data,svyrep.design)
S3method(find_data,train)
S3method(find_data,vgam)
S3method(find_data,vglm)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# prediction 0.3.15

* `prediction.svyglm()` now accepts a survey design object as the `data` argument.
* `prediction.svyglm()` now handles `data` containing `NA`s: the affected rows get `NA` for `fitted` and `se.fitted`.
* `build_datalist()` preserves the levels of factors that are mentioned in the `at` argument.

# prediction 0.3.13

* Fixed a bug in `prediction_glm` with the `data` argument (Issue #32).
Expand Down
31 changes: 17 additions & 14 deletions R/build_datalist.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,25 @@
#' @seealso \code{\link{find_data}}, \code{\link{mean_or_mode}}, \code{\link{seq_range}}
#' @importFrom data.table rbindlist
#' @export
build_datalist <-
build_datalist <-
function(data,
at = NULL,
at = NULL,
as.data.frame = FALSE,
...){

# check for `at` specification and `as.data.frame` arguments
if (!is.null(at) && length(at) > 0) {
# check `at` specification against data
check_at(data, at)

# setup list of data.frames based on at
data_out <- set_data_to_at(data, at = at)
at_specification <- cbind(index = seq_len(nrow(data_out[["at"]])), data_out[["at"]])
data_out <- data_out[["data"]]

if (isTRUE(as.data.frame)) {
data_out <- data.table::rbindlist(data_out)
}

} else if (isTRUE(as.data.frame)) {
# if `at` empty and `as.data.frame = TRUE`, simply return original data
data_out <- data
Expand All @@ -54,10 +53,10 @@ function(data,
check_at <- function(data, at) {
    # Run every validation of the `at` specification against `data`, in order:
    # names first, then factor levels, then numeric values.

    # the name check only needs the column names of `data`
    data_names <- names(data)
    check_at_names(data_names, at)

    # factor levels requested in `at` are checked against `data`
    check_factor_levels(data, at)

    # finally check that numeric values in `at` are interpolations;
    # the value of this last check is what the function returns
    check_values(data, at)
}
Expand All @@ -71,16 +70,16 @@ check_factor_levels <- function(data, at) {
levels(factor(v))
} else {
NULL
}
}
})
levels <- levels[!sapply(levels, is.null)]
at <- at[names(at) %in% names(levels)]
for (i in seq_along(at)) {
atvals <- as.character(at[[i]])
x <- atvals %in% levels[[names(at)[i]]]
if (!all(x)) {
stop(paste0("Illegal factor levels for variable '", names(at)[i], "': ",
paste0(shQuote(atvals[!x]), collapse = ", ")),
stop(paste0("Illegal factor levels for variable '", names(at)[i], "': ",
paste0(shQuote(atvals[!x]), collapse = ", ")),
call. = FALSE)
}
}
Expand All @@ -90,7 +89,7 @@ check_factor_levels <- function(data, at) {
check_values <- function(data, at) {
# drop variables not in `at`
dat <- data[, names(at), drop = FALSE]

# drop non-numeric variables from `dat` and `at`
not_numeric <- !sapply(dat, class) %in% c("character", "factor", "ordered", "logical")
at <- at[names(at) %in% names(dat)[not_numeric]]
Expand All @@ -100,7 +99,7 @@ check_values <- function(data, at) {
# calculate variable ranges
limits <- do.call(rbind, lapply(dat, range, na.rm = TRUE))
rownames(limits) <- names(dat)

# check ranges
for (i in seq_along(at)) {
out <- (at[[i]] < limits[names(at)[i],1]) | (at[[i]] > limits[names(at)[i],2])
Expand Down Expand Up @@ -136,9 +135,13 @@ set_data_to_at <- function(data, at = NULL) {
} else {
expanded <- expand.grid(at, KEEP.OUT.ATTRS = FALSE)
}
e <- split(expanded, unique(expanded))
for (i in intersect(names(data)[sapply(data, is.factor)], names(expanded))) {
expanded[, i] <- factor(expanded[[i]], levels(data[[i]]))
}
e <- split(expanded, unique(expanded), drop = TRUE)
data_out <- lapply(e, function(atvals) {
dat <- data

dat <- `[<-`(dat, , names(atvals), value = atvals)
structure(dat, at = as.list(atvals))
})
Expand Down
27 changes: 26 additions & 1 deletion R/find_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' require("datasets")
#' x <- lm(mpg ~ cyl * hp + wt, data = head(mtcars))
#' find_data(x)
#'
#'
#' @seealso \code{\link{prediction}}, \code{\link{build_datalist}}, \code{\link{mean_or_mode}}, \code{\link{seq_range}}
#' @export
find_data <- function(model, ...) {
Expand Down Expand Up @@ -107,9 +107,34 @@ find_data.merMod <- function(model, env = parent.frame(), ...) {
#' @export
find_data.svyglm <- function(model, ...) {
    # Recover the data stored on a fitted model object, restricted to the rows
    # selected by the `subset` argument of the model call (if present) and
    # with the rows recorded in the model's `na.action` element removed.
    #
    # model: object with a "data" element and, optionally, "call" and
    #        "na.action" elements.
    # ...:   not used by this method.
    # Returns `model[["data"]]` after the row filtering described above.
    data <- model[["data"]]

    # handle subset: evaluate the unevaluated `subset` expression with the
    # columns of `data` in scope; if that fails, keep all rows and warn
    subset_expr <- model[["call"]][["subset"]]
    if (!is.null(subset_expr)) {
        subs <- try(eval(subset_expr, data), silent = TRUE)
        if (inherits(subs, "try-error")) {
            subs <- TRUE
            warning("'find_data()' cannot locate variable(s) used in 'subset'")
        }
        data <- data[subs, , drop = FALSE]
    }

    # handle na.action: drop the rows it lists. The length check matters
    # because indexing with a negated zero-length vector selects *no* rows,
    # which would silently discard the whole data set.
    omitted <- model[["na.action"]]
    if (length(omitted) > 0L) {
        data <- data[-omitted, , drop = FALSE]
    }
    data
}

#' @rdname find_data
#' @export
find_data.survey.design <- function(model, ...) {
    # return the "variables" element of the design object (exact name match)
    design_data <- model[["variables"]]
    design_data
}

#' @rdname find_data
#' @export
find_data.svyrep.design <- function(model, ...) {
    # return the "variables" element of the replicate-weights design object
    # (exact name match, no partial matching)
    design_data <- model[["variables"]]
    design_data
}

#' @rdname find_data
#' @export
find_data.train <- function(model, ...) {
Expand Down
31 changes: 19 additions & 12 deletions R/prediction_svyglm.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
#' @rdname prediction
#' @export
prediction.svyglm <-
function(model,
data = find_data(model, parent.frame()),
at = NULL,
type = c("response", "link"),
prediction.svyglm <-
function(model,
data = find_data(model, parent.frame()),
at = NULL,
type = c("response", "link"),
calculate_se = TRUE,
...) {

type <- match.arg(type)

# extract predicted values
data <- data
if (missing(data) || is.null(data)) {
pred <- predict(model, type = type, se.fit = TRUE, ...)
pred <- data.frame(fitted = unclass(pred),
pred <- data.frame(fitted = unclass(pred),
se.fitted = sqrt(unname(attributes(pred)[["var"]])))
} else {
if (inherits(data, c("survey.design", "svyrep.design"))) {
data <- find_data(data)
}
# setup data
if (is.null(at)) {
out <- data
Expand All @@ -26,14 +29,18 @@ function(model,
}
# calculate predictions
tmp <- predict(model, newdata = out, type = type, se.fit = TRUE, ...)
pred <- make_data_frame(out, fitted = unclass(tmp), se.fitted = sqrt(unname(attributes(tmp)[["var"]])))
se.fitted <- fitted <- rep(NA_real_, nrow(out))
noNAs <- rownames(out) %in% names(tmp)
se.fitted[noNAs] <- sqrt(unname(attributes(tmp)[["var"]]))
fitted[noNAs] <- unclass(tmp)
pred <- make_data_frame(out, fitted = fitted, se.fitted = se.fitted)
}

# variance(s) of average predictions
vc <- NA_real_

# output
structure(pred,
structure(pred,
class = c("prediction", "data.frame"),
at = if (is.null(at)) at else at_specification,
type = type,
Expand Down
6 changes: 6 additions & 0 deletions man/find_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 12 additions & 6 deletions tests/testthat/tests-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ if (require("sampleSelection", quietly = TRUE)) {
test_that("Test prediction() for 'selection'", {
data("Mroz87", package = "sampleSelection")
Mroz87$kids <- (Mroz87$kids5 + Mroz87$kids618 > 0)
m <- sampleSelection::heckit(lfp ~ age + I( age^2 ) + faminc + kids + educ,
m <- sampleSelection::heckit(lfp ~ age + I( age^2 ) + faminc + kids + educ,
wage ~ exper + I( exper^2 ) + educ + city, Mroz87)
p <- prediction(m)
expect_true(inherits(p, "prediction"), label = "'prediction' class is correct")
Expand Down Expand Up @@ -570,16 +570,22 @@ if (require("survey", quietly = TRUE)) {
p <- prediction(m)
expect_true(inherits(p, "prediction"), label = "'prediction' class is correct")
expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned")
dstrat2 <- subset(dstrat, yr.rnd == "No")
dstrat2$variables$enroll[10] = NA
p2 <- prediction(m, dstrat2)
expect_true(inherits(p2, "prediction"), label = "'prediction' class is correct")
expect_true(all(c("fitted", "se.fitted") %in% names(p2)), label = "'fitted' and 'se.fitted' columns returned")
expect_true(is.na(p2$fitted[10]) & is.na(p2$se.fitted[10]), label = "NAs in data handled correctly")
})
}

if (require("survival", quietly = TRUE)) {
test_that("Test prediction() for 'coxph'", {
test1 <- list(time=c(4,3,1,1,2,2,3),
status=c(1,1,1,0,1,1,0),
x=c(0,2,1,1,1,0,0),
sex=c(0,0,0,0,1,1,1))
m <- survival::coxph(survival::Surv(time, status) ~ x + survival::strata(sex), test1)
test1 <- list(time=c(4,3,1,1,2,2,3),
status=c(1,1,1,0,1,1,0),
x=c(0,2,1,1,1,0,0),
sex=c(0,0,0,0,1,1,1))
m <- survival::coxph(survival::Surv(time, status) ~ x + survival::strata(sex), test1)
p <- prediction(m)
expect_true(inherits(p, "prediction"), label = "'prediction' class is correct")
expect_true(all(c("fitted", "se.fitted") %in% names(p)), label = "'fitted' and 'se.fitted' columns returned")
Expand Down