Skip to content

Commit

Permalink
Add dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
bbimber committed Feb 5, 2024
1 parent cb0f80b commit d5775a7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Imports:
BiocParallel,
ggcorrplot,
magrittr,
Matrix
mlr3verse
Suggests:
devtools,
testthat (>= 2.1.0),
Expand Down
2 changes: 1 addition & 1 deletion R/CellTypist.R
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ RunCellTypist <- function(seuratObj, modelName = "Immune_All_Low.pkl", pThreshol

# Cell typist expects a single column:
tbl <- utils::read.table(geneFile, sep = '\t')
write.table(tbl$V1, file = geneFile, row.names = FALSE, col.names = FALSE)
utils::write.table(tbl$V1, file = geneFile, row.names = FALSE, col.names = FALSE)

# Ensure models present:
if (updateModels) {
Expand Down
42 changes: 21 additions & 21 deletions R/Classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,34 +101,34 @@ TrainModel <- function(training_matrix, celltype, hyperparameter_tuning = F, lea
learner <- mlr3::lrn("classif.ranger", importance = "permutation", predict_type = "prob")

#Define Ranger Hyperparameter Space (RandomBotv2)
tune_ps <- ps(
num.trees = p_int(lower = 10, upper = 2000),
sample.fraction = p_dbl(lower = 0.1, upper = 1),
respect.unordered.factors = p_fct(levels = c("ignore", "order", "partition")),
min.node.size = p_int(lower = 1, upper = 100),
splitrule = p_fct(levels = c("gini", "extratrees")),
num.random.splits = p_int(lower = 1, upper = 100, depends = splitrule == "extratrees")
tune_ps <- mlr3verse::ps(
num.trees = mlr3verse::p_int(lower = 10, upper = 2000),
sample.fraction = mlr3verse::p_dbl(lower = 0.1, upper = 1),
respect.unordered.factors = mlr3verse::p_fct(levels = c("ignore", "order", "partition")),
min.node.size = mlr3verse::p_int(lower = 1, upper = 100),
splitrule = mlr3verse::p_fct(levels = c("gini", "extratrees")),
num.random.splits = mlr3verse::p_int(lower = 1, upper = 100, depends = splitrule == "extratrees")
)
} else if (learner == "classif.xgboost"){
#Update task
task <- mlr3::TaskClassif$new(classification.data, id = "CellTypeBinaryClassifier", target = "celltype_binary")
#Define learner
learner <- mlr3::lrn("classif.xgboost", predict_type = "prob")
#Define XGBoost model's Hyperparameter Space (RandomBotv2)
tune_ps <- ps(
booster = p_fct(levels = c("gblinear", "gbtree", "dart")),
nrounds = p_int(lower = 2, upper = 8, trafo = function(x) as.integer(round(exp(x)))),
eta = p_dbl(lower = -4, upper = 0, trafo = function(x) 10^x),
gamma = p_dbl(lower = -5, upper = 1, trafo = function(x) 10^x),
lambda = p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
alpha = p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
subsample = p_dbl(lower = 0.1, upper = 1),
max_depth = p_int(lower = 1, upper = 15),
min_child_weight = p_dbl(lower = -1, upper = 0, trafo = function(x) 10^x),
colsample_bytree = p_dbl(lower = 0.1, upper = 1),
colsample_bylevel = p_dbl(lower = 0.1, upper = 1),
rate_drop = p_int(lower = 0, upper = 1, depends = booster == 'dart'),
skip_drop = p_int(lower = 0, upper = 1, depends = booster == 'dart')
tune_ps <- mlr3verse::ps(
booster = mlr3verse::p_fct(levels = c("gblinear", "gbtree", "dart")),
nrounds = mlr3verse::p_int(lower = 2, upper = 8, trafo = function(x) as.integer(round(exp(x)))),
eta = mlr3verse::p_dbl(lower = -4, upper = 0, trafo = function(x) 10^x),
gamma = mlr3verse::p_dbl(lower = -5, upper = 1, trafo = function(x) 10^x),
lambda = mlr3verse::p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
alpha = mlr3verse::p_dbl(lower = -4, upper = 3, trafo = function(x) 10^x),
subsample = mlr3verse::p_dbl(lower = 0.1, upper = 1),
max_depth = mlr3verse::p_int(lower = 1, upper = 15),
min_child_weight = mlr3verse::p_dbl(lower = -1, upper = 0, trafo = function(x) 10^x),
colsample_bytree = mlr3verse::p_dbl(lower = 0.1, upper = 1),
colsample_bylevel = mlr3verse::p_dbl(lower = 0.1, upper = 1),
rate_drop = mlr3verse::p_int(lower = 0, upper = 1, depends = booster == 'dart'),
skip_drop = mlr3verse::p_int(lower = 0, upper = 1, depends = booster == 'dart')
)
}
}
Expand Down

0 comments on commit d5775a7

Please sign in to comment.