-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_xgb.R
46 lines (33 loc) · 1.05 KB
/
train_xgb.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#' Fit an xgboost regression model from a formula
#'
#' @param form A two-sided formula such as `y ~ x1 + x2` or `y ~ .`. The
#'   left-hand side names the target column. The right-hand side term labels
#'   are used verbatim as feature column names, so interaction or transformed
#'   terms (`a:b`, `log(x)`) will not resolve to columns.
#' @param .data A data frame holding the target and feature columns. Feature
#'   columns should be numeric: they are converted with `as.matrix()`, and a
#'   character/factor column would produce a non-numeric matrix.
#' @param ... Extra booster parameters (e.g. `eta = 0.1`, `max_depth = 4`).
#'   They are merged into the default parameter list and override it by name.
#' @param nrounds Number of boosting rounds (default 100).
#' @return A list with `mod` (the booster returned by `xgboost::xgb.train()`),
#'   `target` (the formula's left-hand side from `formula.tools::lhs()`) and
#'   `features` (character vector of feature column names, in training order).
train_xgb <- function(form, .data, ..., nrounds = 100) {
  stopifnot(inherits(form, "formula"), length(form) == 3L)

  # Target and features. Supplying `data` lets terms() expand `y ~ .`
  # into every column except the target.
  target <- formula.tools::lhs(form)
  features <- labels(stats::terms(form, data = .data))

  # Separate y (a plain vector, as xgboost expects for labels) from X.
  y <- dplyr::pull(.data, !!target)
  X <- dplyr::select(.data, dplyr::all_of(features))

  # Create the DMatrix needed by xgboost.
  train_matrix <- xgboost::xgb.DMatrix(data = as.matrix(X), label = y)

  # Default parameters; named entries passed through `...` override them.
  # "reg:squarederror" is the non-deprecated name for "reg:linear".
  params <- utils::modifyList(
    list(
      booster = "gbtree",
      objective = "reg:squarederror",
      eval_metric = "mae"
    ),
    list(...)
  )

  mod <- xgboost::xgb.train(
    params = params,
    data = train_matrix,
    nrounds = nrounds
  )

  list(mod = mod, target = target, features = features)
}
#' Predict with a model returned by `train_xgb()`
#'
#' @param model A list as returned by `train_xgb()`. Only `model$mod` (the
#'   booster) and `model$features` (feature column names) are used.
#' @param newdata A data frame containing at least the columns named in
#'   `model$features`. Extra columns are ignored, and the features are
#'   re-selected in training order.
#' @return Whatever `predict()` returns for the booster. For the default
#'   regression objective this is presumably a numeric vector with one value
#'   per row of `newdata` (TODO confirm for other objectives).
predict_xgb <- function(model, newdata) {
  stopifnot(
    is.list(model),
    !is.null(model$mod),
    is.character(model$features)
  )

  # Select the training features in training order so matrix columns line
  # up with the ones the booster was fitted on.
  X_new <- dplyr::select(newdata, dplyr::all_of(model$features))

  # Create the DMatrix required by xgboost.
  new_matrix <- xgboost::xgb.DMatrix(data = as.matrix(X_new))

  predict(model$mod, new_matrix)
}