Skip to content

Commit

Permalink
move gc data protection to R side
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Dec 15, 2024
1 parent 5502558 commit ce5ce97
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 14 deletions.
42 changes: 37 additions & 5 deletions R-package/R/xgb.DMatrix.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,10 @@ xgb.QuantileDMatrix <- function(
)
data_iterator <- .single.data.iterator(iterator_env)

env_keep_alive <- new.env()
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL

# Note: the ProxyDMatrix has its finalizer assigned in the R externalptr
# object, but that finalizer will only be called once the object is
# garbage-collected, which doesn't happen immediately after it goes out
Expand All @@ -363,9 +367,11 @@ xgb.QuantileDMatrix <- function(
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL
return(data_iterator$f_reset(iterator_env))
}
calling_env <- environment()
Expand Down Expand Up @@ -553,7 +559,9 @@ xgb.DataBatch <- function(
}

# This is only for internal usage, class is not exposed to the user.
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) {
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL
lst <- data_iterator$f_next(data_iterator$env)
if (is.null(lst)) {
return(0L)
Expand All @@ -566,13 +574,21 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) {
stop("Either one of 'group' or 'qid' should be NULL")
}
if (is.data.frame(lst$data)) {
tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types)
data <- lst$data
lst <- within(lst, rm("data"))
tmp <- .process.df.for.dmatrix(data, lst$feature_types)
lst$feature_types <- tmp$feature_types
env_keep_alive$keepalive1 <- lst
env_keep_alive$keepalive2 <- tmp
data <- NULL
.Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst)
} else if (is.matrix(lst$data)) {
env_keep_alive$keepalive1 <- lst
.Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data)
} else if (inherits(lst$data, "dgRMatrix")) {
tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data))
env_keep_alive$keepalive1 <- lst
env_keep_alive$keepalive2 <- tmp
.Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp)
} else {
stop("'data' has unsupported type.")
Expand Down Expand Up @@ -707,14 +723,25 @@ xgb.ExtMemDMatrix <- function(
cache_prefix <- path.expand(cache_prefix)
nthread <- as.integer(NVL(nthread, -1L))

# The purpose of this environment is to keep data alive (protected from the
# garbage collector) after setting the data in the proxy dmatrix. The data
# held here (under names 'keepalive1' and 'keepalive2') should be unset
# (leaving it unprotected for garbage collection) before the start of each
# data iteration batch and during each iterator reset.
env_keep_alive <- new.env()
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL

proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down Expand Up @@ -774,14 +801,19 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint

nthread <- as.integer(NVL(nthread, -1L))

env_keep_alive <- new.env()
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL
proxy_handle <- .make.proxy.handle()
on.exit({
.Call(XGDMatrixFree_R, proxy_handle)
})
iterator_next <- function() {
return(xgb.ProxyDMatrix(proxy_handle, data_iterator))
return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive))
}
iterator_reset <- function() {
env_keep_alive$keepalive1 <- NULL
env_keep_alive$keepalive2 <- NULL
return(data_iterator$f_reset(data_iterator$env))
}
calling_env <- environment()
Expand Down
12 changes: 3 additions & 9 deletions R-package/src/xgboost_R.cc
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) {
{
std::string array_str = MakeArrayInterfaceFromRMat(R_mat);
res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str());
R_SetExternalPtrProtected(handle, R_mat);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -708,7 +707,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) {
array_str_indices.c_str(),
array_str_data.c_str(),
ncol);
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -722,7 +720,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) {
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst);
res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str());
R_SetExternalPtrProtected(handle, lst);
}
CHECK_CALL(res_code);
R_API_END();
Expand All @@ -736,20 +733,17 @@ struct _RDataIterator {
SEXP f_reset;
SEXP calling_env;
SEXP continuation_token;
SEXP proxy_dmat;

_RDataIterator(
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) :
SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) :
f_next(f_next), f_reset(f_reset), calling_env(calling_env),
continuation_token(continuation_token), proxy_dmat(proxy_dmat) {}
continuation_token(continuation_token) {}

void reset() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SafeExecFun(this->f_reset, this->calling_env, this->continuation_token);
}

int next() {
R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue);
SEXP R_res = Rf_protect(
SafeExecFun(this->f_next, this->calling_env, this->continuation_token));
int res = Rf_asInteger(R_res);
Expand Down Expand Up @@ -777,7 +771,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R(

int res_code;
try {
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat);
_RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token);

std::string str_cache_prefix;
xgboost::Json jconfig{xgboost::Object{}};
Expand Down

0 comments on commit ce5ce97

Please sign in to comment.