diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R index 66bd7205570b..80fa57b0498f 100644 --- a/R-package/R/xgb.DMatrix.R +++ b/R-package/R/xgb.DMatrix.R @@ -353,6 +353,10 @@ xgb.QuantileDMatrix <- function( ) data_iterator <- .single.data.iterator(iterator_env) + env_keep_alive <- new.env() + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL + # Note: the ProxyDMatrix has its finalizer assigned in the R externalptr # object, but that finalizer will only be called once the object is # garbage-collected, which doesn't happen immediately after it goes out @@ -363,9 +367,11 @@ xgb.QuantileDMatrix <- function( .Call(XGDMatrixFree_R, proxy_handle) }) iterator_next <- function() { - return(xgb.ProxyDMatrix(proxy_handle, data_iterator)) + return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive)) } iterator_reset <- function() { + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL return(data_iterator$f_reset(iterator_env)) } calling_env <- environment() @@ -553,7 +559,9 @@ xgb.DataBatch <- function( } # This is only for internal usage, class is not exposed to the user. -xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) { +xgb.ProxyDMatrix <- function(proxy_handle, data_iterator, env_keep_alive) { + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL lst <- data_iterator$f_next(data_iterator$env) if (is.null(lst)) { return(0L) @@ -566,13 +574,21 @@ xgb.ProxyDMatrix <- function(proxy_handle, data_iterator) { stop("Either one of 'group' or 'qid' should be NULL") } if (is.data.frame(lst$data)) { - tmp <- .process.df.for.dmatrix(lst$data, lst$feature_types) + data <- lst$data + lst <- within(lst, rm("data")) + tmp <- .process.df.for.dmatrix(data, lst$feature_types) lst$feature_types <- tmp$feature_types + env_keep_alive$keepalive1 <- lst + env_keep_alive$keepalive2 <- tmp + data <- NULL .Call(XGProxyDMatrixSetDataColumnar_R, proxy_handle, tmp$lst) } else if (is.matrix(lst$data)) { + env_keep_alive$keepalive1 <- lst .Call(XGProxyDMatrixSetDataDense_R, proxy_handle, lst$data) } else if (inherits(lst$data, "dgRMatrix")) { tmp <- list(p = lst$data@p, j = lst$data@j, x = lst$data@x, ncol = ncol(lst$data)) + env_keep_alive$keepalive1 <- lst + env_keep_alive$keepalive2 <- tmp .Call(XGProxyDMatrixSetDataCSR_R, proxy_handle, tmp) } else { stop("'data' has unsupported type.") @@ -707,14 +723,25 @@ xgb.ExtMemDMatrix <- function( cache_prefix <- path.expand(cache_prefix) nthread <- as.integer(NVL(nthread, -1L)) + # The purpose of this environment is to keep data alive (protected from the + # garbage collector) after setting the data in the proxy dmatrix. The data + # held here (under names 'keepalive1' and 'keepalive2') should be unset + # (leaving it unprotected for garbage collection) before the start of each + # data iteration batch and during each iterator reset. + env_keep_alive <- new.env() + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL + proxy_handle <- .make.proxy.handle() on.exit({ .Call(XGDMatrixFree_R, proxy_handle) }) iterator_next <- function() { - return(xgb.ProxyDMatrix(proxy_handle, data_iterator)) + return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive)) } iterator_reset <- function() { + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL return(data_iterator$f_reset(data_iterator$env)) } calling_env <- environment() @@ -774,14 +801,19 @@ xgb.QuantileDMatrix.from_iterator <- function( # nolint nthread <- as.integer(NVL(nthread, -1L)) + env_keep_alive <- new.env() + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL proxy_handle <- .make.proxy.handle() on.exit({ .Call(XGDMatrixFree_R, proxy_handle) }) iterator_next <- function() { - return(xgb.ProxyDMatrix(proxy_handle, data_iterator)) + return(xgb.ProxyDMatrix(proxy_handle, data_iterator, env_keep_alive)) } iterator_reset <- function() { + env_keep_alive$keepalive1 <- NULL + env_keep_alive$keepalive2 <- NULL return(data_iterator$f_reset(data_iterator$env)) } calling_env <- environment() diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index adb9649bf33d..0e7234a18708 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -687,7 +687,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataDense_R(SEXP handle, SEXP R_mat) { { std::string array_str = MakeArrayInterfaceFromRMat(R_mat); res_code = XGProxyDMatrixSetDataDense(proxy_dmat, array_str.c_str()); - R_SetExternalPtrProtected(handle, R_mat); } CHECK_CALL(res_code); R_API_END(); @@ -708,7 +707,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataCSR_R(SEXP handle, SEXP lst) { array_str_indices.c_str(), array_str_data.c_str(), ncol); - R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); @@ -722,7 +720,6 @@ XGB_DLL SEXP XGProxyDMatrixSetDataColumnar_R(SEXP handle, SEXP lst) { { std::string sinterface = MakeArrayInterfaceFromRDataFrame(lst); res_code = XGProxyDMatrixSetDataColumnar(proxy_dmat, sinterface.c_str()); - R_SetExternalPtrProtected(handle, lst); } CHECK_CALL(res_code); R_API_END(); @@ -736,20 +733,17 @@ struct _RDataIterator { SEXP f_reset; SEXP calling_env; SEXP continuation_token; - SEXP proxy_dmat; _RDataIterator( - SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token, SEXP proxy_dmat) : + SEXP f_next, SEXP f_reset, SEXP calling_env, SEXP continuation_token) : f_next(f_next), f_reset(f_reset), calling_env(calling_env), - continuation_token(continuation_token), proxy_dmat(proxy_dmat) {} + continuation_token(continuation_token) {} void reset() { - R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SafeExecFun(this->f_reset, this->calling_env, this->continuation_token); } int next() { - R_SetExternalPtrProtected(this->proxy_dmat, R_NilValue); SEXP R_res = Rf_protect( SafeExecFun(this->f_next, this->calling_env, this->continuation_token)); int res = Rf_asInteger(R_res); @@ -777,7 +771,7 @@ SEXP XGDMatrixCreateFromCallbackGeneric_R( int res_code; try { - _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token, proxy_dmat); + _RDataIterator data_iterator(f_next, f_reset, calling_env, continuation_token); std::string str_cache_prefix; xgboost::Json jconfig{xgboost::Object{}};