---
title: "Machine Learning Case Study"
author: "Danial Saef"
date: "03/05/2022"
output:
  html_notebook:
    number_sections: TRUE
bibliography: ml_case_study.bib
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
```
***
The goal of this notebook is to train a classifier that predicts the "target" variable.
For this, the three essential steps are:
1. exploratory data analysis
2. model training
3. discussion.
A short outlook on the structure:
The goal of the exploratory data analysis is to pre-identify key metrics that help us predict the target. We will start with some basics on data structure and column values, then move on to the distribution of the numeric columns, first on a higher level and then on a more granular one. Afterwards we will investigate the time component of the data. These steps are necessary to decide which model approach to use. The appendix contains an additional exploration of the time series component.
***
# Setting it all up {.tabset}
## Initialization
Before we dive into the analysis, let's load the necessary packages, fix some settings, define the functions we'll use, and load the data.
```{r initialize, message = FALSE, warning = FALSE, results = "hide"}
#### install and load packages ####
libraries = c("knitr","lubridate", "rms", "markdown","mime","rmarkdown","tinytex","data.table","lattice","latticeExtra","Hmisc","DT","scales","ggplot2","forecast","rpart","rpart.plot","randomForest", "recipes", "caret", "mlbench", "themis", "ROSE", "mltools", "MLeval")
lapply(libraries, function(x) if (!(x %in% installed.packages())) {install.packages(x, dependencies = TRUE, repos = "http://cran.us.r-project.org")} )
invisible(lapply(libraries, library, quietly = TRUE, character.only = TRUE))
## ##
#### settings ####
Sys.setenv(LANG = "en") # set environment language to English
Sys.setlocale("LC_TIME", "en_US.UTF-8") # set timestamp language to English
Sys.setlocale("LC_TIME", "English") # set timestamp language to English
## ##
#### load data ####
data <- fread("data.csv")
## ##
```
## Basic functions
``` {r functions}
#### functions ####
plot_theme <- theme(panel.border = element_blank(),
axis.text = element_text(size = 20, face = "bold"),
axis.title = element_text(size = 24, face = "bold"),
strip.text = element_text(size = 20, face = "bold"),
plot.title = element_text(size = 24, face = "bold", hjust = .5),
panel.background = element_blank(), # bg of the panel
plot.background = element_blank(), # bg of the plot
panel.grid.major = element_blank(), # get rid of major grid
panel.grid.minor = element_blank(), # get rid of minor grid
legend.position = "none")
plot_theme_legend <- theme(panel.border = element_blank(),
axis.text = element_text(size = 20, face = "bold"),
axis.title = element_text(size = 24, face = "bold"),
strip.text = element_text(size = 20, face = "bold"),
plot.title = element_text(size = 24, face = "bold", hjust = .5),
panel.background = element_blank(), # bg of the panel
plot.background = element_blank(), # bg of the plot
panel.grid.major = element_blank(), # get rid of major grid
panel.grid.minor = element_blank(), # get rid of minor grid
legend.background = element_rect(fill = "transparent"), # get rid of legend bg
legend.box.background = element_rect(fill = "transparent")) # get rid of legend panel bg
## to get histograms for different variables and columns let's wrap it up in a small function for the sake of readability ## used in {r histogram plots} ##
get_hist <- function(DATA, CATEGORY = "all", HIST_COL = NA){
## separate if all categories are chosen or only a specific one (this is mainly for aesthetic reasons)
if (CATEGORY == "all"){
# subset the data to only include values with the specified TARGET and get the column of which we want to plot the histogram
dist_data <- DATA[, c(..HIST_COL, "target")]
dist_data[, target := as.factor(target)] # do this to make sure the target is a factor
names(dist_data) <- c("hist_col", "Target") # to pass it to ggplot2 (avoids writing functions for passing the externally provided variable name)
# plot the histogram
plot_histo <- ggplot(dist_data, aes(x=hist_col, group = Target, fill = Target)) +
geom_histogram(alpha=0.7, position="identity", color = "black") + # if yend = Poisson
plot_theme_legend +
labs(colour="Legend", x = "Value", y = "Count", title = paste("Category '", CATEGORY, "' | column '", HIST_COL, "'", sep = ""))
} else if (CATEGORY %in% c("a","b","c")) {
# subset the data to only include values with the specified TARGET and get the column of which we want to plot the histogram
dist_data <- DATA[categorical0 %in% CATEGORY, c(..HIST_COL, "target")]
dist_data[, target := as.factor(target)] # do this to make sure the target is a factor
names(dist_data) <- c("hist_col", "Target") # to pass it to ggplot2 (avoids writing functions for passing the externally provided variable name)
# plot the histogram
plot_histo <- ggplot(dist_data, aes(x=hist_col, group = Target, color = Target, fill = Target)) +
geom_histogram(alpha=0.7, position="identity", color = "black") + # if yend = Poisson
plot_theme_legend +
labs(colour="Legend", x = "Value", y = "Count", title = paste("Category '", CATEGORY, "' | column '", HIST_COL, "'", sep = ""))
# just an error handler
} else {
return("Error: invalid category, maybe a typo?")
}
return(plot_histo)
}
# For the sake of readability, let's write a short function to count the number of observations with target = 1 per time variable and print the output # used in {r create features} #
get_N_per_VAR <- function(VAR){
N_per_var <- data[,.N,by=get(VAR)] # count number of values for the passed variable VAR
N_pos_per_var <- data[target==1,.N,by=get(VAR)];setkey(N_pos_per_var,get)
N_per_var[, pct_target := round((N_pos_per_var[,N] / N ), digits = 2)]
names(N_per_var) <- c(VAR,"N", "pct_target") # make sure names are correct
setkeyv(N_per_var,VAR) # sort by the passed variable VAR
return(N_per_var)
}
create_train_test <- function(data, size = 0.8, train = TRUE) {
  n_row <- nrow(data)
  total_row <- floor(size * n_row) # index of the last training row
  train_sample <- 1:total_row
  # note: this is a chronological split (the first `size` share of rows), not a random one
  if (train == TRUE) {
    return(data[train_sample, ])
  } else {
    return(data[-train_sample, ])
  }
}
```
***
# Exploring the data {.tabset}
## Printing the data - or understanding what we're dealing with
As a first step, let's just print out the data table to understand what we're dealing with.
```{r explore}
nrow(data) # show number of rows
datatable(head(data, nrow(data) * 0.01)) # print the first 1% of rows
```
Remember that we want to predict the variable "target". This seems to be a time series problem, but the observations are unordered. In total we have 10,000 observations. Good to know, but it probably makes sense to order the data by date and time.
## The data structure hints at a time series problem
Also, we would probably like to plot the data, but at this point we cannot yet judge what would be ideal to plot. Therefore, let's have a look at its structure. This might look a bit overwhelming at first, but we want to understand which columns the data has, of which type they are, and what kind of information they contain.
```{r str}
setkey(data, date, time) # order by date & time
str(data) # print out structure of the data
```
This outputs the type of each column and prints its first values. We can also see that the ordering by date & time is correct, from earliest to latest. Moreover, the time variable should be converted to a timestamp format if we want to plot the data.
```{r describe}
html(describe(data)) # print out a description of the columns of the data
```
The above outputs a description of each column along with some summary statistics.
Combined with the structure of the data there are a few things we can already observe:
1. The target variable has only two values, either 0 or 1
+ seems to be a binary classification task in a time series
+ that makes it a bit tricky, we have to evaluate whether or not the time component actually plays a role
2. There are 3 different categories in "categorical0" a,b,c
+ they are almost perfectly balanced in terms of number of observations (N)
3. Column numeric0 could be some kind of index or counter?
+ it goes from 0 to 9,999, but does not have 10,000 distinct values
+ this is an interesting observation, so maybe we should plot that over time to understand how it is distributed
4. Column numeric1 could be some kind of intensity, as it goes from 0 to 9?
+ it has many missing values, which are tricky to deal with
To get closer to identifying key features for predicting the target variable, we could plot the distributions of numeric0 and numeric1, separated by target values. It probably makes sense to also plot the data separated by categories a, b, c.
## Histograms tell us more about the differences between target states
The goal of plotting now is to understand the hierarchy of the data, because that will help us come up with a good but also explainable approach. We could probably already set up a black box and make predictions at this point, but that wouldn't help our business stakeholders understand what's going on.
Now that everything is set up to loop over the different combinations, we can print the histograms. Note that in choosing how to split the plots, we have effectively built a small decision tree ourselves: we separate the data by category and by numeric column. In that sense, a tree-based approach seems like a natural candidate.
The reason why we're doing this separation as a first step is that there are some dangers in the data:
1. Many missing values in "numeric1",
2. A time component of which we don't know if and how it's relevant.
To have more control over what we're going to model later, we will thus get a quick overview and see if we can already find some patterns.
### The plots {.tabset}
```{r histogram plots, results = "asis", message = FALSE, warning = FALSE}
## specify values of interest so that we can loop over them ##
unique_category <- c(data[,sort(unique(categorical0))]) # unique values of "categorical0" column
hist_cols <- c("numeric0", "numeric1") # the numeric columns
plots <- lapply(hist_cols, function(i) {
get_hist(data, CATEGORY = "all", HIST_COL = i)
})
for (i in 1:length(plots)) {
cat("#### Histogram",i,"\n")
print(plots[[i]])
cat('\n\n')
}
```
### The explanation
The plots show the distribution of values by 'target' state and per numerical column.
* Our takeaways here:
+ when the state is 1, 'numeric0' values are mostly small, and vice versa.
+ 'numeric1' seems less sensitive to the state.
* The problem:
+ while a large 'numeric0' value means we are most likely in state 0, the opposite is not necessarily true
+ both numeric columns are heavily imbalanced, consisting largely of state 0 observations
## Class imbalance: categories matter!
To illustrate the difference between categories, we will look at the number of values with state = 1 per category. Recall that every category has ~3300 observations.
```{r summary table}
datatable(data[target==1,.N, by = "categorical0"])
```
Clearly, an observation of category a is much more likely to be in state 1 than one of category b or c.
But are their characteristics significantly different?
## Histograms by category {.tabset}
```{r histogram by cat, results = "asis", message = FALSE, warning = FALSE}
hists <- lapply(hist_cols, function(i) {
lapply(unique_category, function(cat) {
get_hist(data, CATEGORY = cat, HIST_COL = i)
})
})
for (i in 1:length(hists)) {
for (j in 1:length(hists[[i]])) {
cat("### Histogram of column",hist_cols[i], "|", unique_category[j], "\n")
print(hists[[i]][[j]])
cat('\n\n')
}
}
```
## Histograms by category: takeaways
The plots show the distribution of values by 'target' state and per numerical column, separated by categories.
* The categories seem fairly similar in this respect, but in 'numeric0', state 1 appears more often in a than in b or c
* In 'numeric1', the distribution doesn't seem to differ w.r.t. target state among categories. However, given the large number of missing values, we should be careful not to infer too much from this column!
The large class imbalance is a problem we will have to address. We could e.g. address it by:
1. Undersampling (oversampling) the majority (minority) class 'target' state 0 (state 1)
2. Passing prior weights to tree based approaches
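As a minimal sketch of option 1, the ROSE package (already in our library list) offers random over- and undersampling; the call below is illustrative and not evaluated here:
```{r imbalance sketch, eval = FALSE}
# sketch only: rebalance the classes by random under-/oversampling #
n_minority <- data[target == 1, .N]
n_majority <- data[target == 0, .N]
# undersample the majority class down to a 50/50 split #
balanced_under <- ovun.sample(target ~ ., data = data, method = "under",
                              N = 2 * n_minority, seed = 1)$data
# oversample the minority class up to a 50/50 split #
balanced_over <- ovun.sample(target ~ ., data = data, method = "over",
                             N = 2 * n_majority, seed = 1)$data
table(balanced_under$target); table(balanced_over$target)
```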
## The time component: Does it make sense to use it?
We will now turn to investigating the time component to see how relevant it is. For this, we will have to create time series objects. First, let's check how many gaps the data has.
```{r time series}
# create a series from start date to end date to capture all possible dates #
start <- data[,min(date)]
end <- data[,max(date)]
start_to_end <- data.table("date" = seq(start,end,by = "1 day"))
# match it with our data and preallocate a column for missing dates #
start_to_end[data, target := i.target, on = "date"][, date_missing := 0]
start_to_end[is.na(target), date_missing := 1] # set the missing date column to '1' where the target value is NA
# plot all values where target = NA (we have a missing observation) #
datatable(start_to_end[,.N,by = "date_missing"])
```
The above table counts missing dates (date_missing = 1) versus present ones (date_missing = 0). More often than not, there is a gap in the observations. There are two ways to approach this problem:
1. impute missing observations
2. extract other features from the date / time component
In our case, it probably makes more sense to go with the second approach, as our goal is to predict the state of any row given the information we have about its features, and we hope to find some time-related patterns among them.
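For completeness, here is a sketch of what option 1 could look like, using `na.interp` from the forecast package (already loaded). For a binary target the interpolation yields fractional values, which is one more reason to prefer the second approach:
```{r impute sketch, eval = FALSE}
# sketch only: interpolate the target over the full daily grid built above #
filled_target <- na.interp(ts(start_to_end[, target]))
summary(as.numeric(filled_target))
```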
## Time related patterns: we can't use the full granularity
Luckily, we have a relatively long time series. It is common practice to consider a few seasonalities when dealing with daily data:
1. Yearly
2. Quarterly
3. Monthly
4. Weekly
5. Weekday
6. Hourly
All of these could contain useful information. In a regular time series, we could use statistical techniques to identify these seasonalities. In this irregular time series, we will use a different metric: the share of state 1 observations per period, which serves as an indicator of seasonal differences.
``` {r create features}
# create some columns for better understanding the time component (the names should be self-explanatory)
data[, date_time := lubridate::as_datetime(paste(date, time, sep = " "))]
data[, `:=` (year = year(date_time),
quarter = quarter(date_time),
month = month(date_time),
week = week(date_time),
weekday = weekdays(date_time),
hour = hour(date_time))]
# the weekday column is a bit tricky so we do this to make sure the days are ordered from Monday - Sunday
data[, weekday := ordered(weekday, levels=c("Monday", "Tuesday", "Wednesday", "Thursday",
"Friday", "Saturday", "Sunday"))]
# define all columns that we want to print out a table for #
data_cols <- c("year", "quarter", "month", "week", "weekday", "hour")
list_N_per_VAR <- lapply(data_cols, get_N_per_VAR)
```
### Observations per variable{.tabset}
``` {r print N per VAR, results = "asis", message = FALSE, warning = FALSE}
var_list <- list_N_per_VAR
for (i in 1:length(var_list)) {
cat("#### N per ", data_cols[i],"\n")
print(knitr::kable(var_list[[i]]))
cat('\n\n')
}
```
### Observations per variable: takeaways
The above shows the number of observations per feature and the percentage of observations where the target state was 1.
This kind of analysis is too extensive for a discussion with the business side. The main takeaways here are:
1. In earlier years, the variation seems higher than in more recent ones
2. Quarters don't seem to matter too much (little variation in pct_target)
3. There are differences in months, e.g. 3 vs 6
4. Weeks seem to matter, some weeks as low as 0.08 and others as high as 0.19
5. Weekdays don't seem to matter
6. Hour seems to matter
Some concerns related to the time component:
* We don't know the time zone, so what if the categories or numeric1 values are identifiers for different countries in different time zones?
+ This would make the hourly variable complicated to interpret.
* Maybe the state has something to do with holidays?
+ They change every year and are location specific.
+ Additional information about the geographic location would therefore be valuable.
Finally, while there are some differences, it is unclear whether the magnitude is large enough to effectively make a difference when modeling.
## Summing it up: moving to tree models
We want to predict the target state of each observation based on date, time, category and two numerical columns.
* We identified that the date & time component matters, but the large number of missing dates indicates that a time series approach may not be the best choice; we therefore treat it as a set of features instead,
+ If anything, we expect that year, month, week, and hour will matter the most
* The categories matter a lot, most target = 1 values were observed in category a,
+ We will have to address class imbalance
* The numeric columns are also differently distributed w.r.t. target state, however we should expect more explainability from column 'numeric0'.
A simple and often useful approach is to classify observations based on a decision tree. It is a great starting point, because it's a simple, intuitive and transparent method.
***
# Model fitting {.tabset}
## Decision Tree
With the insights gained, fitting a Decision Tree is now straightforward. We need to perform some data wrangling, account for the large difference in target vs non-target values, and split the data into a train and a test set before we can fit a tree. We impute missing values with the mean of the data. One should generally be careful when imputing, as this technique is only consistent when the missing values are not informative [@josse_consistency_2020].
```{r fit tree, message = FALSE, warning = FALSE}
## For fitting the tree remove date / time columns as they yield no information ##
tree_data <- data[,!c("date", "time", "date_time")]
# Also, we impute missing values by the mean of the data #
mean_numeric_1 <- tree_data[,mean(numeric1, na.rm = TRUE)]
tree_data[is.na(numeric1), numeric1 := mean_numeric_1]
## we have already seen that we have way more 0 than 1 target values ##
# Therefore, we count the number of values by state #
DT_balance <- tree_data[,.N,by = "target"]
# Next, we will calculate the ratio of target vs non-target values, which can then be passed as priors to the tree #
N_target_0 <- DT_balance[target == 0,N] / nrow(tree_data)
N_target_1 <- 1 - N_target_0
weights <- c(N_target_1, N_target_0)
## Before we fit the model, we need to do a split between train and test data ##
## When there is no time component involved, it makes sense to take random samples ##
## However, in our case, we will split at 80% of the data and predict the last 20% ##
tree_data_train <- create_train_test(tree_data, 0.8, train = TRUE)
tree_data_test <- create_train_test(tree_data, 0.8, train = FALSE)
## Finally, we can fit the tree ##
tree_fit <- rpart(target~.,
data = tree_data_train,
parms = list(prior = weights), # priors due to class imbalance
method = "class") # method = "class" because target is either 0 or 1
tree_plot <- rpart.plot(tree_fit,
type = 4,
extra = 106)
```
The above plot shows the Decision Tree fitted on our training data. Given our exploratory analysis, it holds few surprises. The tree separates first by category and then by 'numeric0' value; it did not consider the other variables, such as 'numeric1' or the time features, important enough.
Intuitively, the tree can be interpreted as follows:
* If an observation is of category a, it's most likely to be in state = 1
* If an observation is of category b,c, and the 'numeric0' value is <= 985, it's most likely in state = 1
* Else, an observation is most likely in state = 0
Based on these rules we will evaluate the predictive power of this model. Since our focus is on a balance between predictive accuracy and simplicity, the few rules should not be an issue for us.
To measure the accuracy, we will count how many
* True Positives (TP, correctly predicted as 0)
* True Negatives (TN, correctly predicted as 1)
* False Positives (FP, wrongly predicted as 0)
* False Negatives (FN, wrongly predicted as 1)
we have; note that, following caret's default, state 0 (the first factor level) is treated as the 'positive' class. This can be depicted in a so-called confusion matrix.
Based on the ratio of $\frac{TP + TN}{TP + TN + FP + FN}$ we can then calculate the accuracy.
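As a quick illustration with made-up counts (not our model's output):
```{r accuracy sketch, eval = FALSE}
# hypothetical confusion matrix cells #
TP <- 1500; TN <- 300; FP <- 120; FN <- 80
accuracy <- (TP + TN) / (TP + TN + FP + FN) # = 1800 / 2000 = 0.9
```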
```{r predict from tree, message = FALSE, warning = FALSE}
predict_unseen <- predict(tree_fit, tree_data_test, type = 'class')
confusionMatrix(predict_unseen, tree_data_test[, as.factor(target)]) # caret expects confusionMatrix(data = predictions, reference = truth)
```
The output shows the confusion matrix and the accuracy calculation. With the predictions in the rows and the true labels in the columns, the table reads as follows:
* True Positives are in the upper left corner
* False Positives are in the upper right corner
* False Negatives are in the lower left corner
* True Negatives are in the lower right corner
This again holds few surprises:
We learned from the distributions that state 0 observations mostly have large 'numeric0' values and mostly fall into categories b & c. They are therefore easy to classify: less than 1% of them are misclassified.
We also saw that it is much harder to identify state 1 observations: their values are mainly small, but not always, just as state 0 values are mainly large, but not always. The tree decided that an observation in category a is most likely in state 1. However, only about 1/3 of the observations in category a are actually in state 1. Consequently, only roughly 1/3 of these values are correctly classified.
## Random Forest
While we have an overall satisfactory accuracy due to the large number of correct classifications in the majority group state 0, we can try to improve accuracy and reduce the number of False Negatives by increasing model complexity. A typical extension of Decision Tree models is the Random Forest algorithm. While it is more of a blackbox approach than a Decision Tree, we hope that it offers some additional insights through the feature importance and that the additional complexity increases our model accuracy.
The basic idea of a Random Forest is to grow many Decision Trees on bootstrap samples drawn from the original data, using only a random subset of the available features at each split. A committee of many such trees then takes a majority vote on the class of each observation. This reduces the variance of the predictions: combining many trees averages out their errors, and the feature subsetting decorrelates them.
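A minimal sketch of these two sources of randomness with the randomForest package (assuming a prepared data set like the `rf_data` built below; the `ntree` and `mtry` values are illustrative):
```{r rf sketch, eval = FALSE}
# each of the ntree trees is grown on a bootstrap sample of the data; #
# at each split, only mtry randomly chosen features are considered #
rf_sketch <- randomForest(target ~ ., data = rf_data, ntree = 500, mtry = 3)
rf_sketch
```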
``` {r random forest, message = FALSE, warning = FALSE}
# set random seed for reproducing results #
set.seed(1234)
rf_data <- copy(tree_data)
# recode the target as letters: caret's classProbs option requires factor levels that are valid R variable names, so 0 / 1 will not work
rf_data[,target := as.character(target)]
rf_data[target %in% "0",target := "A"]
rf_data[target %in% "1",target := "B"]
rf_data[,target := as.factor((target))]
# create folds #
cv_folds <- createFolds(rf_data[,target], k = 5, returnTrain = TRUE)
# create tune control with upsampling for handling class imbalance and 5-fold cross validation during training #
tuneGrid <- expand.grid(.mtry = c(1 : 10))
ctrl <- trainControl(method = "cv",
number = 5,
search = 'grid',
classProbs = TRUE,
savePredictions = "final",
index = cv_folds,
summaryFunction = twoClassSummary,
sampling = "up")
# specify tuning parameters #
ntrees <- c(100,500,1000)
nodesize <- c(1,5,10)
params <- expand.grid(ntrees = ntrees,
nodesize = nodesize)
# train the model in a grid search #
# this may take a while; in a real world scenario #
# we would ideally move this into a cloud environment #
# additionally, this could be sped up using parallelization #
store_maxnode <- vector("list", nrow(params))
for(i in 1:nrow(params)){
nodesize <- params[i,2]
ntree <- params[i,1]
set.seed(123)
rf_model <- train(target~.,
data = rf_data,
method = "rf",
importance=TRUE,
metric = "ROC",
tuneGrid = tuneGrid,
trControl = ctrl,
ntree = ntree,
nodesize = nodesize)
store_maxnode[[i]] <- rf_model
}
# get unique names for experiments #
names(store_maxnode) <- paste("ntrees:", params$ntrees,
"nodesize:", params$nodesize)
# combine results and print output #
results_mtry <- resamples(store_maxnode)
summary(results_mtry)
```
This code chunk shows the tuning of the RF model and some performance metrics. We addressed class imbalance by upsampling the minority class. Other approaches would be e.g. downsampling the majority class or using algorithms such as "SMOTE" or "ROSE". In general, there is no consensus on which method is best practice; which approach should be preferred varies from use case to use case. We use 5-fold cross validation to evaluate the models: the data is repeatedly divided into folds, where part of the observations is withheld during training and then used as test data. Note that in this case, we do not additionally make a train-test split. @diebold_comparing_2015 argues that (pseudo-)out-of-sample approaches are consistent only if the withheld data is asymptotically irrelevant, i.e. in small data cases, full-sample fitting is preferable.
Instead of using accuracy as a performance metric, we investigate the area under the Receiver Operating Characteristic (ROC) curve, which is often a better way to evaluate the predictive power of a model. It gives low scores both to random classifiers and to classifiers that always predict one class. Additionally, we measure Sensitivity (True Positive Rate) and Specificity (True Negative Rate), which indicate the ability of a classifier to detect positive (negative) examples:
* $Sensitivity = \frac{TP}{AP},$
* $Specificity = \frac{TN}{AN},$
where AP = All Positives and AN = All Negatives.
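Reusing the hypothetical counts from the accuracy illustration above:
```{r sens spec sketch, eval = FALSE}
TP <- 1500; TN <- 300; FP <- 120; FN <- 80
sensitivity <- TP / (TP + FN) # TP / AP = 1500 / 1580, ca. 0.95
specificity <- TN / (TN + FP) # TN / AN = 300 / 420, ca. 0.71
```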
The results show that the model with `ntrees=500` and `nodesize=5` maintains a balance between sensitivity and specificity while having a good ROC value. In general, the ROC values of the models do not differ much, but sensitivity and specificity vary considerably. As expected, these values improve with a higher number of trees. The biggest model with `ntrees=1000`, however, does not perform significantly better than the one with `ntrees=500`. Note that Random Forests are relatively robust to overfitting: accuracy converges after a certain number of trees, and additional trees provide little to no additional predictive power.
``` {r random forest varimp, message = FALSE, warning = FALSE}
# get variable importance #
plot(varImp(store_maxnode$`ntrees: 500 nodesize: 5`))
```
The feature importance plot confirms that the category and 'numeric0' columns are our main features of interest; the other variables play only a minor role. However, the meaning of this plot should not be overestimated, as little is known about the theoretical properties of such importance measures [@scornet_consistency_2015].
``` {r random forest eval, message = FALSE, warning = FALSE}
# get performance metrics #
fit_eval <- evalm(store_maxnode$`ntrees: 500 nodesize: 5`, silent = TRUE, showplots = FALSE)
confusionMatrix(store_maxnode$`ntrees: 500 nodesize: 5`$pred$pred,store_maxnode$`ntrees: 500 nodesize: 5`$pred$obs)
## get roc curve plotted in ggplot2
fit_eval$roc
```
The above plot shows the ROC curve, as well as the area under it (AUC). It visualizes the trade-off in any classifier between True Positives and False Positives; a perfect classifier would sit in the upper left corner. The AUC score ranges from 0 to 1, where 1 again denotes a perfect classifier. The obtained score of 0.87 indicates that the classifier's predictive power is satisfactory. Finally, as work on the consistency and general theoretical properties of the RF estimator is still at an early stage [@scornet_consistency_2015], any performance metric can only be an indicator of a good model fit; the goal of using the ROC curve is not to obtain a perfect model, but one that serves our purpose well enough. We left the probability threshold at the default of 0.5. This value can be varied to tune the algorithm towards either sensitivity or specificity, which also helps to tackle class imbalance.
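A minimal sketch of such a threshold adjustment, based on the cross-validated predictions that caret saved for the chosen model (the 0.3 cut-off is illustrative):
```{r threshold sketch, eval = FALSE}
rf_best <- store_maxnode$`ntrees: 500 nodesize: 5`
# classProbs = TRUE stored the class probabilities next to the predictions #
pred_new <- factor(ifelse(rf_best$pred$B > 0.3, "B", "A"), levels = c("A", "B"))
confusionMatrix(pred_new, rf_best$pred$obs)
```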
While we do not see a great improvement in terms of accuracy compared with the baseline Decision Tree, we have become better in predicting values in the minority class. The Random Forest model clearly performs better, but does not give us the desired outcome just yet.
## XGBoost
Finally, we will use XGBoost for making predictions. The concept of boosting is at first glance very similar to that of Random Forests (or Bagging in general): a committee of "weak" classifiers (barely better than chance) is combined to vote on the predicted class of each observation. In contrast to Random Forests, however, this is an iterative procedure: while iterating, the algorithm emphasizes misclassified observations in order to learn difficult patterns. Boosting has proven to be one of the most powerful classifiers of the last decade [@hastie_boosting_2009]. XGBoost is one of the most recent implementations, and we will use it in this example.
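To make the reweighting idea concrete, here is a toy sketch of a single AdaBoost-style update; XGBoost itself uses gradient-based updates, so this is conceptual only:
```{r boosting sketch, eval = FALSE}
set.seed(1)
y <- sample(c(0, 1), 20, replace = TRUE)    # true labels
pred <- sample(c(0, 1), 20, replace = TRUE) # one weak learner's predictions
w <- rep(1 / 20, 20)                        # start with uniform weights
err <- sum(w * (pred != y)) / sum(w)        # weighted error rate
alpha <- log((1 - err) / err)               # the learner's vote weight
w <- w * exp(alpha * (pred != y))           # upweight misclassified points
w <- w / sum(w)                             # renormalize for the next round
```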
```{r predict from xgboost, message = FALSE, warning = FALSE}
# make the train data set with one hot encoding (xgboost only accepts numeric values) #
train <- tree_data[, !"target"]
cols <- c("categorical0", "year", "quarter", "month", "week", "weekday", "hour")
train[, (cols) := lapply(.SD, factor), .SDcols = cols]
train <- one_hot(train)
colnames_train <- names(train)
train <- matrix(as.numeric(unlist(train)), nrow = nrow(train))
colnames(train) <- colnames_train
# control parameters #
ctrl_xgb <- trainControl(method = "cv",
number = 5,
search = 'grid',
classProbs = TRUE,
savePredictions = "final",
index = cv_folds,
summaryFunction = twoClassSummary)
# calculate weights for observations #
xgb_weights <- ifelse(rf_data[,target] == "B",
table(rf_data[,target])[1]/nrow(rf_data),
table(rf_data[,target])[2]/nrow(rf_data))
# get tuning grid #
tuneGrid_xgb <- expand.grid(.nrounds = c(100, 250, 500),
.max_depth = c(1,3,6),
.eta = c(0.01,0.025,0.1,0.3),
.gamma = c(3),
.colsample_bytree = c(0.6,0.8,1),
.subsample = c(0.75),
.min_child_weight = c(1))
# fit the model #
# this again may take a while #
xgb_model <- train(x = train,
y = rf_data[,target],
method = "xgbTree",
trControl = ctrl_xgb,
tuneGrid = tuneGrid_xgb,
weights = xgb_weights,
verbose = TRUE,
metric = "ROC",
verbosity = 0,
allowParallel = TRUE)
xgb_model
```
Looking at the results, we see that the AUC value is comparable to that of the Random Forest model. We picked some hyperparameters to tune that work well in practice and help us avoid overfitting:
* Number of features used (i.e. columns used): colsample_bytree.
+ Lower ratios avoid over-fitting.
* Ratio of training instances used (i.e. rows used): subsample.
+ Lower ratios avoid over-fitting.
* Maximum depth of a tree: max_depth.
+ Lower values avoid over-fitting.
* Minimum loss reduction required to make another split: gamma.
+ Larger values avoid over-fitting.
* Learning rate (i.e. how much we update our prediction with each successive tree): eta.
+ Lower values avoid over-fitting.
We identify the optimal parameter combination through a grid search, as we did for the Random Forest algorithm. The most influential parameter in XGBoost is the learning rate. It is usually chosen in the region of 0.1-0.3, but smaller or larger values can be suitable depending on the use case.
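If more tuning budget were available, a natural next step would be a finer follow-up grid around the best combination found so far; a sketch (the values below are illustrative, not results):
```{r xgb fine grid sketch, eval = FALSE}
tuneGrid_fine <- expand.grid(.nrounds = c(400, 500, 600),
                             .max_depth = c(2, 3, 4),
                             .eta = c(0.05, 0.1, 0.15),
                             .gamma = 3,
                             .colsample_bytree = c(0.7, 0.8),
                             .subsample = 0.75,
                             .min_child_weight = 1)
# pass tuneGrid_fine to train() exactly like tuneGrid_xgb above #
```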
```{r eval xgboost}
# get performance metrics #
xgb_eval <- evalm(xgb_model, silent = TRUE, showplots = FALSE)
confusionMatrix(xgb_model$pred$pred,xgb_model$pred$obs)
## get roc curve plotted in ggplot2
xgb_eval$roc
```
The confusion matrix shows that the boosting algorithm is capable of making powerful predictions. We managed to strongly increase Specificity while losing some Sensitivity. This means that the classifier has become much better at predicting values from the minority class, at the cost of losing some predictive power in the majority class. Looking at the ROC curve, we find that little has changed. The results indicate that XGBoost performs best among the tested algorithms in predicting minority class values, but is not necessarily the best classifier if we are interested in predicting the majority class.
***
# Discussion
The goal of this analysis was to predict whether a value is target / non-target based on the numeric values, category, and date & time. While this seemed like a time series problem at first, any time series method would be overly complicated compared to extracting time-based features.
The exploratory analysis revealed that the category and the column 'numeric0' were expected to be our main predictors. We expected limited information from the time-based variables and the column 'numeric1', due to limited variation between categories (time-based variables & 'numeric1') and the large number of missing values ('numeric1'). Additionally, the analysis revealed a major class imbalance: only ca. 13% of observations were of target state = 1. This is a problem for classification, as we would ideally want the classes to be balanced.
Tree based models are an intuitive choice in many business problems, as they are explainable and simple. Therefore, we fitted a Decision Tree, which classified observations based on category and 'numeric0' value.
While it was moderately accurate, it was very biased towards the majority state. To address this issue, we increased model complexity, and tested whether the Random Forest algorithm could fix the issue. The feature importance plot confirmed that mainly the categories and 'numeric0' are relevant for classifying the target value.
The area under the ROC curve suggests that we have a satisfactory model fit, however we had to find a balance between sensitivity and specificity. While better than a Decision Tree, the predictive power for the minority class was still barely higher than a random guess.
XGBoost yielded the best performance in predicting the minority class, though not necessarily the majority class. While it did not increase the Accuracy or the AUC by much, it had a significantly higher Specificity while maintaining a satisfactory Sensitivity.
The tested models can only serve as a baseline for more sophisticated approaches, and their predictive power can still be improved through further hyperparameter tuning, adjusting the probability threshold for classification, comparing methods for missing data imputation, and comparing different sampling approaches to address the class imbalance. Furthermore, parameter / model changes usually represent a trade-off between false positives and false negatives. In any real world use case it should be decided which of these errors is more costly; a suitable performance metric can then be chosen to compare models.
***
# Appendix: treating the numeric variables as time series
Time series experts might ask why we are not focusing on the time component. The following plots illustrate why it is of limited use here.
```{r plot numeric0}
## Plot column numeric0 when in state 0
ggplot(data[target == 0], aes(x = date_time, y = numeric0)) +
geom_point(size = 0.5, color = "darkblue") +
geom_line(size = 0.1, color = "darkblue") +
labs(x = "Date", y = "Value", title = "Evolution 'numeric0' | state = 0") +
plot_theme
## Plot column numeric0 when in state 1
ggplot(data[target == 1], aes(x = date_time, y = numeric0)) +
geom_point(size = 0.5, color = "darkblue") +
geom_line(size = 0.1, color = "darkblue") +
labs(x = "Date", y = "Value", title = "Evolution 'numeric0' | state = 1") +
plot_theme
acf(data[target == 1, numeric0], main = "Autocorrelation function | 'numeric0' | state = 1")
```
The above plots show the evolution of the column 'numeric0' when in state 0 and 1, plus the autocorrelation function of values in state 1.
These are the major observations:
1. Clearly, the two states follow different laws
2. The distributions we discussed earlier reflect the tendency towards high / low values in each state
3. While values in state 0 seem to follow a white-noise-like pattern, state 1 involves some time dependency
4. The autocorrelation in the numeric column hints that there could be autocorrelation in the states too (see the sketch below)
As interesting as these points are, they do not explain why the target switches from 0 to 1.
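A quick way to probe point 4 directly is to look at the autocorrelation of the target sequence itself; a sketch:
```{r target acf sketch, eval = FALSE}
# serial dependence in the 0/1 target sequence (data is ordered by date & time) #
acf(data[, target], main = "Autocorrelation function | target state")
```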
***
# References