Skip to content

Commit

Permalink
Update ctree: The rest (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Nov 11, 2020
1 parent d508b87 commit f1a693e
Show file tree
Hide file tree
Showing 35 changed files with 1,351 additions and 88 deletions.
5 changes: 4 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ linters: with_defaults(
seq_linter = NULL
)
exclusions: list(
"inst/scripts/example.R",
"inst/scripts/compare_shap_python.R",
"inst/scripts/create_lm_model_object.R",
"inst/scripts/create_xgboost_model_object.R",
"inst/scripts/example_ctree_model.R",
"inst/scripts/example_custom_model.R",
"inst/scripts/readme_example.R",
"inst/scripts/shap_python_script.py",
"R/RcppExports.R",
"R/zzz.R"
)
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(explain,combined)
S3method(explain,copula)
S3method(explain,ctree)
S3method(explain,ctree_comb_mincrit)
S3method(explain,empirical)
S3method(explain,gaussian)
S3method(features,gam)
Expand All @@ -23,11 +25,13 @@ S3method(predict_model,lm)
S3method(predict_model,ranger)
S3method(predict_model,xgb.Booster)
S3method(prepare_data,copula)
S3method(prepare_data,ctree)
S3method(prepare_data,empirical)
S3method(prepare_data,gaussian)
export(aicc_full_single_cpp)
export(apply_dummies)
export(correction_matrix_cpp)
export(create_ctree)
export(explain)
export(feature_combinations)
export(feature_matrix_cpp)
Expand Down
144 changes: 134 additions & 10 deletions R/explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#'
#' @param approach Character vector of length \code{1} or \code{n_features}.
#' \code{n_features} equals the total number of features in the model. All elements should
#' either be \code{"gaussian"}, \code{"copula"} or \code{"empirical"}. See details for more information.
#' either be \code{"gaussian"}, \code{"copula"}, \code{"empirical"}, or \code{"ctree"}. See details for more
#' information.
#'
#' @param prediction_zero Numeric. The prediction value for unseen data, typically equal to the mean of
#' the response.
Expand Down Expand Up @@ -87,16 +88,20 @@
#' # Gaussian copula approach
#' explain3 <- explain(x_test, explainer, approach = "copula", prediction_zero = p, n_samples = 1e2)
#'
#' # ctree approach
#' explain4 <- explain(x_test, explainer, approach = "ctree", prediction_zero = p)
#'
#' # Combined approach
#' approach <- c("gaussian", "gaussian", "empirical", "empirical")
#' explain4 <- explain(x_test, explainer, approach = approach, prediction_zero = p, n_samples = 1e2)
#' explain5 <- explain(x_test, explainer, approach = approach, prediction_zero = p, n_samples = 1e2)
#'
#' # Print the Shapley values
#' print(explain1$dt)
#'
#' # Plot the results
#' plot(explain1)
explain <- function(x, explainer, approach, prediction_zero, ...) {
extras <- list(...)

# Check input for x
if (!is.matrix(x) & !is.data.frame(x)) {
Expand All @@ -105,14 +110,14 @@ explain <- function(x, explainer, approach, prediction_zero, ...) {

# Check input for approach
if (!(is.vector(approach) &&
is.atomic(approach) &&
(length(approach) == 1 | length(approach) == length(explainer$feature_labels)) &&
all(is.element(approach, c("empirical", "gaussian", "copula"))))
is.atomic(approach) &&
(length(approach) == 1 | length(approach) == length(explainer$feature_labels)) &&
all(is.element(approach, c("empirical", "gaussian", "copula", "ctree"))))
) {
stop(
paste(
"It seems that you passed a non-valid value for approach.",
"It should be either 'empirical', 'gaussian', 'copula' or",
"It should be either 'empirical', 'gaussian', 'copula', 'ctree' or",
"a vector of length=ncol(x) with only the above characters."
)
)
Expand All @@ -131,6 +136,8 @@ explain <- function(x, explainer, approach, prediction_zero, ...) {

if (length(approach) > 1) {
class(x) <- "combined"
} else if (length(extras$mincriterion) > 1) {
class(x) <- "ctree_comb_mincrit"
} else {
class(x) <- approach
}
Expand Down Expand Up @@ -277,10 +284,60 @@ explain.copula <- function(x, explainer, approach, prediction_zero, ...) {
return(r)
}


#' @param mincriterion Numeric value or vector where length of vector is the number of features in model.
#' Value is equal to 1 - alpha where alpha is the nominal level of the conditional
#' independence tests.
#' If it is a vector, this indicates which mincriterion to use
#' when conditioning on various numbers of features.
#'
#' @param minsplit Numeric value. Equal to the value that the sum of the left and right daughter nodes need to exceed.
#'
#' @param minbucket Numeric value. Equal to the minimum sum of weights in a terminal node.
#'
#' @param sample Boolean. If TRUE, then the method always samples \code{n_samples} from the leaf (with replacement).
#' If FALSE and the number of obs in the leaf is less than \code{n_samples}, the method will take all observations
#' in the leaf. If FALSE and the number of obs in the leaf is more than \code{n_samples}, the method will sample
#' \code{n_samples} (with replacement). This means that there will always be sampling in the leaf unless
#' \code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.
#
#' @rdname explain
#' @name explain
#'
#' @export
explain.combined <- function(x, explainer, approach, prediction_zero, mu = NULL, cov_mat = NULL, ...) {
explain.ctree <- function(x, explainer, approach, prediction_zero,
mincriterion = 0.95, minsplit = 20,
minbucket = 7, sample = TRUE, ...) {
# Checks input argument
if (!is.matrix(x) & !is.data.frame(x)) {
stop("x should be a matrix or a dataframe.")
}

# Add arguments to explainer object
explainer$x_test <- explainer_x_test_dt(x, explainer$feature_labels)
explainer$approach <- approach
explainer$mincriterion <- mincriterion
explainer$minsplit <- minsplit
explainer$minbucket <- minbucket
explainer$sample <- sample

# Generate data
dt <- prepare_data(explainer, ...)

if (!is.null(explainer$return)) return(dt)

# Predict
r <- prediction(dt, prediction_zero, explainer)

return(r)
}

#' @rdname explain
#' @name explain
#'
#' @export
explain.combined <- function(x, explainer, approach, prediction_zero,
mu = NULL, cov_mat = NULL, ...) {
# Get indices of combinations
l <- get_list_approaches(explainer$X$n_features, approach)
explainer$return <- TRUE
Expand All @@ -290,13 +347,11 @@ explain.combined <- function(x, explainer, approach, prediction_zero, mu = NULL,
for (i in seq_along(l)) {
dt_l[[i]] <- explain(x, explainer, approach = names(l)[i], prediction_zero, index_features = l[[i]], ...)
}

dt <- data.table::rbindlist(dt_l, use.names = TRUE)

r <- prediction(dt, prediction_zero, explainer)

return(r)

}

#' Helper function used in \code{\link{explain.combined}}
Expand All @@ -314,7 +369,6 @@ explain.combined <- function(x, explainer, approach, prediction_zero, mu = NULL,
#' @return List
#'
get_list_approaches <- function(n_features, approach) {

l <- list()
approach[length(approach)] <- approach[length(approach) - 1]

Expand All @@ -334,7 +388,77 @@ get_list_approaches <- function(n_features, approach) {
if (length(x) > 0) {
if (approach[1] == "copula") x <- c(0, x)
l$copula <- which(n_features %in% x)
}

x <- which(approach == "ctree")
if (length(x) > 0) {
if (approach[1] == "ctree") x <- c(0, x)
l$ctree <- which(n_features %in% x)
}
return(l)
}

#' @keywords internal
explainer_x_test <- function(x_test, feature_labels) {

# Remove variables that were not used for training
x <- data.table::as.data.table(x_test)
cnms_remove <- setdiff(colnames(x), feature_labels)
if (length(cnms_remove) > 0) x[, (cnms_remove) := NULL]
data.table::setcolorder(x, feature_labels)

return(as.matrix(x))
}

#' @keywords internal
explainer_x_test_dt <- function(x_test, feature_labels) {

# Remove variables that were not used for training
# Same as explainer_x_test() but doesn't convert to a matrix
# Useful for ctree method which sometimes takes categorical features
x <- data.table::as.data.table(x_test)
cnms_remove <- setdiff(colnames(x), feature_labels)
if (length(cnms_remove) > 0) x[, (cnms_remove) := NULL]
data.table::setcolorder(x, feature_labels)

return(x)
}


#' @rdname explain
#' @name explain
#'
#' @export
explain.ctree_comb_mincrit <- function(x, explainer, approach,
prediction_zero, mincriterion, ...) {

# Get indices of combinations
l <- get_list_ctree_mincrit(explainer$X$n_features, mincriterion)
explainer$return <- TRUE # this is important so that you don't use prediction() twice
explainer$x_test <- as.matrix(x)

dt_l <- list()
for (i in seq_along(l)) {
dt_l[[i]] <- explain(x, explainer, approach, prediction_zero,
index_features = l[[i]],
mincriterion = as.numeric(names(l[i])), ...)
}

dt <- data.table::rbindlist(dt_l, use.names = TRUE)

r <- prediction(dt, prediction_zero, explainer)
return(r)
}

#' @keywords internal
get_list_ctree_mincrit <- function(n_features, mincriterion) {
l <- list()

for (k in 1:length(unique(mincriterion))) {
x <- which(mincriterion == unique(mincriterion)[k])
nn <- as.character(unique(mincriterion)[k])
if (length(l) == 0) x <- c(0, x)
l[[nn]] <- which(n_features %in% x)
}
return(l)
}
Expand Down
7 changes: 3 additions & 4 deletions R/features.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ feature_not_exact <- function(m, n_combinations = 200, weight_zero_m = 10^6) {
#' @keywords internal
helper_feature <- function(m, feature_sample) {

sample_frequence <- is_duplicate <- NULL # due to NSE notes in R CMD check
sample_frequence <- is_duplicate <- NULL # due to NSE notes in R CMD check

x <- feature_matrix_cpp(feature_sample, m)
dt <- data.table::data.table(x)
Expand Down Expand Up @@ -208,7 +208,7 @@ helper_feature <- function(m, feature_sample) {
make_dummies <- function(data) {

contrasts <- features <- factor_features <- NULL # due to NSE notes in R CMD check
if(is.null(colnames(data))){
if (is.null(colnames(data))) {
stop("data must have column names.")
}
data <- data.table::as.data.table(as.data.frame(data, stringsAsFactors = FALSE))
Expand Down Expand Up @@ -279,7 +279,7 @@ apply_dummies <- function(obj, newdata) {
if (is.null(newdata)) {
stop("newdata needs to be included.")
}
if(is.null(colnames(newdata))){
if (is.null(colnames(newdata))) {
stop("newdata must have column names.")
}
newdata <- data.table::as.data.table(as.data.frame(newdata, stringsAsFactors = FALSE))
Expand Down Expand Up @@ -309,4 +309,3 @@ apply_dummies <- function(obj, newdata) {
contrasts.arg = obj$contrasts_list)
return(x)
}

6 changes: 2 additions & 4 deletions R/models.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ predict_model.xgb.Booster <- function(x, newdata) {
# Test model type
model_type <- model_type(x)

if (model_type %in% c("cat_regression","cat_classification") ) {
if (model_type %in% c("cat_regression", "cat_classification" ) ) {
newdata_dummy <- apply_dummies(obj = x$dummylist, newdata = newdata)
predict(x, as.matrix(newdata_dummy))
} else {
Expand Down Expand Up @@ -246,7 +246,7 @@ model_type.gam <- function(x) {
model_type.xgb.Booster <- function(x) {

if (!is.null(x$params$objective) &&
(x$params$objective == "multi:softmax" | x$params$objective == "multi:softprob")
(x$params$objective == "multi:softmax" | x$params$objective == "multi:softprob")
) {
stop(
paste0(
Expand Down Expand Up @@ -373,7 +373,6 @@ features.ranger <- function(x, cnms, feature_labels = NULL) {
if (!all(nms %in% cnms) | is.null(nms)) error_feature_labels()

return(nms)

}

#' @rdname features
Expand All @@ -392,7 +391,6 @@ features.gam <- function(x, cnms, feature_labels = NULL) {
#' @rdname features
#' @export
features.xgb.Booster <- function(x, cnms, feature_labels = NULL) {

if (!is.null(feature_labels)) message_features_labels()

nms <- x$feature_names
Expand Down
Loading

0 comments on commit f1a693e

Please sign in to comment.