Skip to content

Commit

Permalink
check data spliting for cal_feature_sel in trans_classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiLiubio committed Jan 31, 2025
1 parent 001a15c commit 51cc306
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions R/trans_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier",
boruta.repetitions = 4,
...
){
self <- private$check_training_data(self)

data_input <- self$data_train
data_x <- data_input[, -1]
data_y <- data_input[, 1]
Expand Down Expand Up @@ -215,11 +217,14 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier",
data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)]
self$data_train <- data_output

data_input <- self$data_test
data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)]
self$data_test <- data_output

message("Selected features are reassigned to object$data_train and object$data_test ...")
if(is.null(self$data_test)){
message("Selected features are reassigned to object$data_train ...")
}else{
data_input <- self$data_test
data_output <- data_input[, c(colnames(data_input)[1], boruta.list.top)]
self$data_test <- data_output
message("Selected features are reassigned to object$data_train and object$data_test ...")
}
invisible(self)
},
#' @description
Expand Down Expand Up @@ -279,15 +284,9 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier",
ntree = 500,
...
){
self <- private$check_training_data(self)
train_data <- self$data_train
if(is.null(train_data)){
message("No training data is found! The reason is function cal_split is not performed! Use all the samples for the training ...")
train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE)
if(self$type == "Classification"){
train_data$Response %<>% as.factor
}
self$data_train <- train_data
}

trControl <- self$trainControl
if(is.null(trControl)){
trControl <- caret::trainControl()
Expand Down Expand Up @@ -775,6 +774,20 @@ trans_classifier <- R6::R6Class(classname = "trans_classifier",
p
}
),
private = list(
check_training_data = function(self){
train_data <- self$data_train
if(is.null(train_data)){
message("No training data is found! The reason is that the cal_split function is not performed! Use all the samples for the training ...")
train_data <- data.frame(Response = self$data_response, self$data_feature, check.names = FALSE)
if(self$type == "Classification"){
train_data$Response %<>% as.factor
}
self$data_train <- train_data
}
self
}
),
lock_class = FALSE,
lock_objects = FALSE
)
Expand Down

0 comments on commit 51cc306

Please sign in to comment.