From 51cc30666f362105a314289ce7496248f277ea7a Mon Sep 17 00:00:00 2001 From: ChiLiubio Date: Fri, 31 Jan 2025 21:11:47 +0800 Subject: [PATCH] check data spliting for cal_feature_sel in trans_classifier --- R/trans_classifier.R | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/R/trans_classifier.R b/R/trans_classifier.R index d667394..1c428ef 100644 --- a/R/trans_classifier.R +++ b/R/trans_classifier.R @@ -185,6 +185,8 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier", boruta.repetitions = 4, ... ){ + self <- private$check_training_data(self) + data_input <- self$data_train data_x <- data_input[, -1] data_y <- data_input[, 1] @@ -215,11 +217,14 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier", data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)] self$data_train <- data_output - data_input <- self$data_test - data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)] - self$data_test <- data_output - - message("Selected features are reassigned to object$data_train and object$data_test ...") + if(is.null(self$data_test)){ + message("Selected features are reassigned to object$data_train ...") + }else{ + data_input <- self$data_test + data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)] + self$data_test <- data_output + message("Selected features are reassigned to object$data_train and object$data_test ...") + } invisible(self) }, #' @description @@ -279,15 +284,9 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier", ntree = 500, ... ){ + self <- private$check_training_data(self) train_data <- self$data_train - if(is.null(train_data)){ - message("No training data is found! The reason is function cal_split is not performed! Use all the samples for the training ...") - train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE) - if(self$type == "Classification"){ - train_data$Response %<>% as.factor - } - self$data_train <- train_data - } + trControl <- self$trainControl if(is.null(trControl)){ trControl <- caret::trainControl() @@ -775,6 +774,20 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier", p } ), + private = list( + check_training_data = function(self){ + train_data <- self$data_train + if(is.null(train_data)){ + message("No training data is found! The reason is that the cal_split function is not performed! Use all the samples for the training ...") + train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE) + if(self$type == "Classification"){ + train_data$Response %<>% as.factor + } + self$data_train <- train_data + } + self + } + ), lock_class = FALSE, lock_objects = FALSE )