Skip to content

Commit

Permalink
Merge pull request #380 from OHDSI/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jreps authored Mar 14, 2023
2 parents ee03f69 + 52cee33 commit 66d3cf1
Show file tree
Hide file tree
Showing 12 changed files with 610 additions and 10 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: PatientLevelPrediction
Type: Package
Title: Developing patient level prediction using data in the OMOP Common Data
Model
Version: 6.2.1
Version: 6.3.1
Date: 2023-02-28
Authors@R: c(
person("Jenna", "Reps", email = "[email protected]", role = c("aut", "cre")),
Expand Down Expand Up @@ -68,7 +68,8 @@ Suggests:
survminer,
testthat,
withr,
xgboost (> 1.3.2.1)
xgboost (> 1.3.2.1),
lightgbm
Remotes:
ohdsi/BigKnn,
ohdsi/Eunomia,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ export(setGradientBoostingMachine)
export(setIterativeHardThresholding)
export(setKNN)
export(setLassoLogisticRegression)
export(setLightGBM)
export(setMLP)
export(setNaiveBayes)
export(setPythonEnvironment)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
PatientLevelPrediction 6.3.1
======================
- fixed bug with multiple covariate settings in diagnose plp
- added min cell count when exporting database results to csv files
- light GBM added (thanks Jin Choi and Chungsoo Kim)
- fixed minor bugs when uploading results to database

PatientLevelPrediction 6.2.1
======================
- added ensure_installed("ResultModelManager") to getDataMigrator()
Expand Down
8 changes: 4 additions & 4 deletions R/AdditionalCovariates.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ getCohortCovariateData <- function(
"}} as covariate_value",
"from @cohort_temp_table a inner join @covariate_cohort_schema.@covariate_cohort_table b",
" on a.subject_id = b.subject_id and ",
" b.cohort_start_date <= dateadd(day, @endDay, a.cohort_start_date) and ",
" b.cohort_start_date <= dateadd(day, @endDays, a.cohort_start_date) and ",
" b.cohort_end_date >= dateadd(day, @startDay, a.cohort_start_date) ",
"{@ageInteraction | @lnAgeInteraction}?{inner join @cdm_database_schema.person p on p.person_id=a.subject_id}",
"where b.cohort_definition_id = @covariate_cohort_id
Expand All @@ -74,7 +74,7 @@ getCohortCovariateData <- function(
row_id_field = rowIdField,
startDay = covariateSettings$startDay,
covariate_id = covariateSettings$covariateId,
endDay = covariateSettings$endDay,
endDays = covariateSettings$endDays,
countval = covariateSettings$count,
ageInteraction = covariateSettings$ageInteraction,
lnAgeInteraction = covariateSettings$lnAgeInteraction,
Expand Down Expand Up @@ -102,7 +102,7 @@ getCohortCovariateData <- function(
concept_set = paste('Cohort_covariate during day',
covariateSettings$startDay,
'through',
covariateSettings$endDay,
covariateSettings$endDays,
'days relative to index:',
ifelse(covariateSettings$count, 'Number of', ''),
covariateSettings$covariateName,
Expand Down Expand Up @@ -193,7 +193,7 @@ createCohortCovariateSettings <- function(
cohortTable = cohortTable,
cohortId = cohortId,
startDay = startDay,
endDay = endDay,
endDays = endDay,
count = count,
ageInteraction = ageInteraction,
lnAgeInteraction = lnAgeInteraction,
Expand Down
25 changes: 21 additions & 4 deletions R/DiagnosePlp.R
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,23 @@ probastParticipants <- function(
}


# Return the largest endDays value found across one or more covariate
# settings objects. Settings without an endDays entry contribute nothing;
# if none define endDays, 0 is returned as a conservative default.
getMaxEndDaysFromCovariates <- function(covariateSettings) {

  # A single settings object is normalised into a one-element list so the
  # extraction below can treat both inputs uniformly.
  if (inherits(covariateSettings, "covariateSettings")) {
    covariateSettings <- list(covariateSettings)
  }

  # unlist() silently drops NULLs, so settings lacking endDays vanish here.
  endDayValues <- unlist(
    lapply(covariateSettings, function(settings) settings$endDays)
  )

  if (length(endDayValues) == 0) {
    return(0)
  }
  max(endDayValues)
}



probastPredictors <- function(
plpData,
outcomeId,
Expand All @@ -549,7 +566,7 @@ probastPredictors <- function(
# covariate + outcome correlation; km of outcome (close to index or not)?
probastId <- '2.2'
if(populationSettings$startAnchor == 'cohort start'){
if(populationSettings$riskWindowStart > plpData$metaData$covariateSettings$endDays){
if(populationSettings$riskWindowStart > getMaxEndDaysFromCovariates(plpData$metaData$covariateSettings)){
diagnosticAggregate <- rbind(
diagnosticAggregate,
c(probastId, 'Pass')
Expand Down Expand Up @@ -632,9 +649,9 @@ probastPredictors <- function(
# 2.3.1
# cov end_date <=0
probastId <- '2.3'
if(plpData$metaData$covariateSettings$endDays <= 0){
if(getMaxEndDaysFromCovariates(plpData$metaData$covariateSettings) <= 0){

if(plpData$metaData$covariateSettings$endDays < 0){
if(getMaxEndDaysFromCovariates(plpData$metaData$covariateSettings) < 0){
diagnosticAggregate <- rbind(
diagnosticAggregate,
c(probastId, 'Pass')
Expand Down Expand Up @@ -692,7 +709,7 @@ probastOutcome <- function(
# 3.6 - check tar after covariate end_days
probastId <- '3.6'
if(populationSettings$startAnchor == 'cohort start'){
if(populationSettings$riskWindowStart > plpData$metaData$covariateSettings$endDays){
if(populationSettings$riskWindowStart > getMaxEndDaysFromCovariates(plpData$metaData$covariateSettings)){
diagnosticAggregate <- rbind(
diagnosticAggregate,
c(probastId, 'Pass')
Expand Down
236 changes: 236 additions & 0 deletions R/LightGBM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# @file LightGBM.R
# Copyright 2023 Observational Health Data Sciences and Informatics
#
# This file is part of PatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Create setting for gradient boosting machine model using lightGBM (https://github.com/microsoft/LightGBM/tree/master/R-package).
#'
#' @param nthread The number of computer threads to use (how many cores do you have?)
#' @param earlyStopRound If the performance does not increase over earlyStopRound number of trees then training stops (this prevents overfitting)
#' @param numIterations Number of boosting iterations.
#' @param numLeaves This hyperparameter sets the maximum number of leaves. Increasing this parameter can lead to higher model complexity and potential overfitting.
#' @param maxDepth This hyperparameter sets the maximum depth. Increasing this parameter can also lead to higher model complexity and potential overfitting.
#' @param minDataInLeaf This hyperparameter sets the minimum number of data points that must be present in a leaf node. Increasing this parameter can help to reduce overfitting
#' @param learningRate This hyperparameter controls the step size at each iteration of the gradient descent algorithm. Lower values can lead to slower convergence but may result in better performance.
#' @param lambdaL1 This hyperparameter controls L1 regularization, which can help to reduce overfitting by encouraging sparse models.
#' @param lambdaL2 This hyperparameter controls L2 regularization, which can also help to reduce overfitting by discouraging large weights in the model.
#' @param scalePosWeight Controls weight of positive class in loss - useful for imbalanced classes
#' @param isUnbalance This parameter cannot be used at the same time with scalePosWeight, choose only one of them. While enabling this should increase the overall performance metric of your model, it will also result in poor estimates of the individual class probabilities.
#' @param seed An option to add a seed when training the final model
#'
#' @examples
#' model.lightgbm <- setLightGBM(
#'   numLeaves = c(20, 31, 50), maxDepth = c(-1, 5, 10),
#'   minDataInLeaf = c(10, 20, 30), learningRate = c(0.05, 0.1, 0.3)
#' )
#'
#' @export
setLightGBM <- function(nthread = 20,
                        earlyStopRound = 25,
                        numIterations = c(100),
                        numLeaves = c(31),
                        maxDepth = c(5, 10),
                        minDataInLeaf = c(20),
                        learningRate = c(0.05, 0.1, 0.3),
                        lambdaL1 = c(0),
                        lambdaL2 = c(0),
                        scalePosWeight = 1,
                        isUnbalance = FALSE,
                        seed = sample(10000000, 1)) {
  ensure_installed("lightgbm")
  checkIsClass(seed, c("numeric", "integer"))

  # Validate each hyperparameter vector; the grid search (listCartesian)
  # below would otherwise pass invalid values straight to lightgbm.
  if (length(nthread) > 1) {
    stop("nthread must be length 1")
  }
  if (!inherits(x = seed, what = c("numeric", "integer"))) {
    stop("Invalid seed")
  }
  if (sum(numIterations < 1) > 0) {
    stop("numIterations must be greater than 0")
  }
  if (sum(numLeaves < 2) > 0) {
    stop("numLeaves must be greater than 1")
  }
  # 131072 = 2^17, lightgbm's documented upper limit for num_leaves
  if (sum(numLeaves > 131072) > 0) {
    stop("numLeaves must be less than or equal to 131072")
  }
  if (sum(learningRate <= 0) > 0) {
    stop("learningRate must be greater than 0")
  }
  if (sum(lambdaL1 < 0) > 0) {
    stop("lambdaL1 must be 0 or greater")
  }
  if (sum(lambdaL2 < 0) > 0) {
    stop("lambdaL2 must be 0 or greater")
  }
  if (sum(scalePosWeight < 0) > 0) {
    stop("scalePosWeight must be 0 or greater")
  }
  # lightgbm forbids combining is_unbalance with a non-default
  # scale_pos_weight; catch it early with a clear message.
  if (isTRUE(isUnbalance) && sum(scalePosWeight != 1) > 0) {
    stop("isUnbalance cannot be used at the same time with scale_pos_weight != 1, choose only one of them")
  }

  # One entry per hyperparameter; the cartesian product of these vectors
  # defines the hyperparameter search grid.
  paramGrid <- list(
    earlyStopRound = earlyStopRound,
    numIterations = numIterations,
    numLeaves = numLeaves,
    maxDepth = maxDepth,
    minDataInLeaf = minDataInLeaf,
    learningRate = learningRate,
    lambdaL1 = lambdaL1,
    lambdaL2 = lambdaL2,
    isUnbalance = isUnbalance,
    scalePosWeight = scalePosWeight
  )

  param <- listCartesian(paramGrid)

  # Settings consumed by fitRclassifier: which R functions implement
  # training, prediction and variable importance for this model type.
  attr(param, "settings") <- list(
    modelType = "LightGBM",
    seed = seed[[1]],
    modelName = "LightGBM",
    threads = nthread[1],
    varImpRFunction = "varImpLightGBM",
    trainRFunction = "fitLightGBM",
    predictRFunction = "predictLightGBM"
  )

  attr(param, "saveType") <- "lightgbm"

  result <- list(
    fitFunction = "fitRclassifier",
    param = param
  )

  class(result) <- "modelSettings"

  return(result)
}



# Extract gain-based variable importance from a fitted lightgbm booster and
# map the matrix column ids back to covariateIds using covariateMap.
#
# model        - a fitted lightgbm booster accepted by lgb.importance
# covariateMap - data.frame with at least columnId and covariateId columns
# Returns a data.frame with covariateId, covariateValue (gain) and included.
varImpLightGBM <- function(model,
                           covariateMap) {
  varImp <- lightgbm::lgb.importance(model, percentage = TRUE) %>%
    dplyr::select("Feature", "Gain")

  # Feature names carry the column id after the final underscore; strip
  # everything before it to recover the id.
  varImp <- data.frame(
    covariateId = gsub(".*_", "", varImp$Feature),
    covariateValue = varImp$Gain,
    included = 1
  )

  # NOTE(review): covariateId here is character (output of gsub) while
  # covariateMap$columnId is presumably numeric; merge coerces to a common
  # type before matching — confirm ids always round-trip correctly.
  varImp <- merge(covariateMap, varImp, by.x = "columnId", by.y = "covariateId")
  varImp <- varImp %>%
    dplyr::select("covariateId", "covariateValue", "included")

  return(varImp)
}

# Predict outcome risk with a fitted lightgbm model.
#
# plpModel - either a plpModel object (model stored in $model) or a bare
#            lightgbm booster
# data     - either raw plpData (converted via toSparseM) or an
#            already-built feature matrix
# cohort   - labels/population data.frame; replaced by the toSparseM labels
#            when data is plpData
# Returns the cohort data.frame with a 'value' prediction column, rowId
# restored from originalRowId, and a metaData attribute with the modelType.
predictLightGBM <- function(plpModel,
                            data,
                            cohort) {
  # Raw plpData must first be converted to a sparse matrix keyed by the
  # model's covariate-to-column mapping.
  if (inherits(data, "plpData")) {
    sparse <- toSparseM(
      plpData = data,
      cohort = cohort,
      map = plpModel$covariateImportance %>%
        dplyr::select("columnId", "covariateId")
    )
    newData <- sparse$dataMatrix
    cohort <- sparse$labels
  } else {
    newData <- data
  }

  # Accept either the wrapped plpModel or the bare booster.
  model <- if (inherits(plpModel, "plpModel")) plpModel$model else plpModel

  prediction <- cohort
  scores <- data.frame(value = stats::predict(model, newData))
  prediction$value <- scores$value

  # Restore the caller-facing row ids that toSparseM renamed.
  prediction <- prediction %>%
    dplyr::select(-"rowId") %>%
    dplyr::rename(rowId = "originalRowId")

  attr(prediction, "metaData") <- list(modelType = attr(plpModel, "modelType"))

  return(prediction)
}

# Train a lightgbm binary classifier for one hyperparameter combination.
#
# dataMatrix      - feature matrix (rows = subjects, columns = covariates)
# labels          - data.frame with an outcomeCount column aligned to rows
# hyperParameters - one entry of the grid built by setLightGBM
# settings        - the "settings" attribute from setLightGBM (seed, threads)
# Returns the fitted lightgbm booster.
fitLightGBM <- function(dataMatrix,
                        labels,
                        hyperParameters,
                        settings) {
  # Seed before the holdout split so the train/validation partition is
  # reproducible, not just the booster training.
  set.seed(settings$seed)

  if (!is.null(hyperParameters$earlyStopRound)) {
    # Hold out 10% of rows so lgb.train can monitor validation AUC and
    # stop when it plateaus.
    trainInd <- sample(nrow(dataMatrix), nrow(dataMatrix) * 0.9)
    train <- lightgbm::lgb.Dataset(
      data = dataMatrix[trainInd, , drop = FALSE],
      label = labels$outcomeCount[trainInd]
    )
    test <- lightgbm::lgb.Dataset(
      data = dataMatrix[-trainInd, , drop = FALSE],
      label = labels$outcomeCount[-trainInd]
    )
    watchlist <- list(train = train, test = test)
  } else {
    # Fixed: the original call had a trailing comma after free_raw_data,
    # which R treats as an empty argument and errors on at call time.
    train <- lightgbm::lgb.Dataset(
      data = dataMatrix,
      label = labels$outcomeCount,
      free_raw_data = FALSE
    )
    watchlist <- list()
  }

  model <- lightgbm::lgb.train(
    data = train,
    params = list(
      objective = "binary",
      boost = "gbdt",
      metric = "auc",
      num_iterations = hyperParameters$numIterations,
      num_leaves = hyperParameters$numLeaves,
      max_depth = hyperParameters$maxDepth,
      learning_rate = hyperParameters$learningRate,
      # lightgbm pre-filters features by min_data_in_leaf unless disabled;
      # disabling keeps the column set stable across hyperparameter values.
      feature_pre_filter = FALSE,
      min_data_in_leaf = hyperParameters$minDataInLeaf,
      scale_pos_weight = hyperParameters$scalePosWeight,
      lambda_l1 = hyperParameters$lambdaL1,
      lambda_l2 = hyperParameters$lambdaL2,
      seed = settings$seed,
      is_unbalance = hyperParameters$isUnbalance,
      max_bin = 255,
      num_threads = settings$threads
    ),
    verbose = 1,
    early_stopping_rounds = hyperParameters$earlyStopRound,
    valids = watchlist
    # categorical_feature = 'auto' # future work
  )

  return(model)
}
Loading

0 comments on commit 66d3cf1

Please sign in to comment.