Decision Trees Pt. 1

Written on December 5, 2017

So far, we’ve looked at parametric models. These are models built on specific assumptions (about the form of the relationship and the distribution of the errors, for example) that must hold for them to be statistically valid. There are other, non-parametric classes of bad-boy type models that don’t follow those rules and don’t really care. One of the more popular types, which we’ll go over today, is the basic decision tree. This lesson covers the basics of a simple decision tree model, but we’ll expand on it much more in the near future.

The Basics of a Decision Tree

A decision tree is a pretty simple model that can be applied to both regression and classification problems. The general idea is that we repeatedly split our data on the values of predictor variables, choosing each split so that it best separates our response, y. At each step, the chosen split is the one that leads to the largest reduction in the residual sum of squares (RSS) for regression, or the purest nodes for classification. (Purity measures how mixed the classes are within each bucket of predictions; a bucket that contains mostly one class is more pure than one that has a mix of multiple classes.)
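To make the RSS criterion concrete, here’s a minimal sketch of how a regression tree could choose a single split on one numeric predictor: try each candidate threshold (the midpoints between sorted unique values) and keep the one with the lowest combined RSS in the two resulting nodes. The helper best_rss_split() and the toy data are made up for illustration; this is not rpart’s internal code.

```r
# Residual sum of squares of a node: squared distance of each y from the node mean
rss <- function(y) sum((y - mean(y))^2)

# Hypothetical helper: best single split of y on predictor x by total child RSS
best_rss_split <- function(x, y) {
  xs <- sort(unique(x))
  # candidate thresholds: midpoints between consecutive unique x values
  cands <- (head(xs, -1) + tail(xs, -1)) / 2
  # total RSS of the two child nodes for each candidate threshold
  total <- sapply(cands, function(thr) rss(y[x < thr]) + rss(y[x >= thr]))
  list(threshold = cands[which.min(total)],
       rss_before = rss(y),
       rss_after = min(total))
}

# Toy data (made up): y jumps once x passes 5
x <- c(1, 2, 3, 4, 6, 7, 8, 9)
y <- c(10, 11, 10, 12, 20, 21, 19, 22)
best_rss_split(x, y)
```

Here the best threshold is 5, which cleanly separates the low and high groups of y and cuts the RSS from well over 100 down to 7.75.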

Let’s say we have a binary classification problem: we’re trying to predict whether a team wins or loses a game based on, oh let’s say, field goal percentage. The best split of our data to predict y might be 55%, so teams that shot 55% or higher would go down one branch, where we would predict a win, and teams that shot less than 55% would go down the other branch, where we would predict a loss. If we stopped there, we could look at each terminal node, or leaf, and see how accurate our predictions were.
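Here’s a minimal sketch of that one-split rule in R, using ten made-up games (the field goal percentages and results are hypothetical, not Pacers data). Each game lands in one of two leaves, and we check how often each leaf’s prediction was right.

```r
# Ten hypothetical games: field goal percentage and result (made-up numbers)
fg <- c(0.58, 0.49, 0.56, 0.44, 0.52, 0.61, 0.47, 0.55, 0.50, 0.42)
result <- c("W", "L", "W", "L", "W", "W", "L", "L", "L", "L")

# One split at 55%: shooting 55% or better predicts a win, anything lower a loss
leaf <- ifelse(fg >= 0.55, "FG >= 55%", "FG < 55%")
pred <- ifelse(fg >= 0.55, "W", "L")

# Share of correct predictions within each leaf
leaf_accuracy <- tapply(pred == result, leaf, mean)
leaf_accuracy
```

With this toy data, the 55%-or-better leaf is right on 3 of its 4 games and the other leaf on 5 of its 6.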

See where the tree terminology comes from? We can expand this simple one-split model to multiple splits with more than two leaves. We might not want to say that all teams that shot under 55% lose; we could make another split after the first one that separates those teams based on whether they shot above or below 45%. We can also throw in more variables and split on those.
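A second split can be sketched the same way. Using the same made-up games, the two thresholds carve the data into three leaves, and each leaf predicts its most common class (a majority vote). The 45% cutoff and the data are illustrative only.

```r
# The same ten hypothetical games (made-up numbers)
fg <- c(0.58, 0.49, 0.56, 0.44, 0.52, 0.61, 0.47, 0.55, 0.50, 0.42)
result <- c("W", "L", "W", "L", "W", "W", "L", "L", "L", "L")

# Two splits (55%, then 45% within the lower branch) give three leaves;
# right = FALSE makes each interval include its lower bound, e.g. [0.45, 0.55)
leaf <- cut(fg, breaks = c(-Inf, 0.45, 0.55, Inf), right = FALSE,
            labels = c("< 45%", "45% to 55%", ">= 55%"))

# Each leaf predicts its most common class (majority vote)
leaf_pred <- tapply(result, leaf, function(r) names(which.max(table(r))))
leaf_pred
```

With this toy data both lower leaves still predict a loss, but the "< 45%" leaf is now perfectly pure, which is the kind of improvement the splitting criterion rewards.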

The same general idea works for regression problems. The prediction for each observation is the mean response of the training observations in the terminal node it ends up in.
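As a minimal sketch of regression-tree predictions, assume a single split at 50% field goal shooting and a hypothetical point-margin response (all numbers made up). Each leaf’s prediction is just the mean margin of the training games in it.

```r
# Hypothetical training data (made-up numbers): field goal percentage and point margin
fg <- c(0.42, 0.46, 0.49, 0.53, 0.56, 0.60)
margin <- c(-12, -5, -3, 4, 9, 15)

# A single split at 50% defines two leaves
leaf_of <- function(x) ifelse(x < 0.50, "low", "high")

# Each leaf's prediction is the mean response of the training rows in it
leaf_means <- tapply(margin, leaf_of(fg), mean)
leaf_means

# New games get the mean of whichever leaf they fall into
new_fg <- c(0.45, 0.58)
predicted_margin <- as.numeric(leaf_means[leaf_of(new_fg)])
predicted_margin
```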

Now, before we see an actual example, I want to warn you about this model’s tendency to overfit. Our decision tree could keep making minuscule splits to the data until we have a crazy number of branches, each fitting just a handful of observations. We’ll need to prune our tree back, removing branches and using some form of cross validation to find a number of splits that doesn’t overfit.

The Data

We’ll be predicting wins by the Pacers in the 2016-17 season, similar to how we did in the logistic regression post, which will be a good way to compare these different methods. We can scrape the data we’ll be using from basketball-reference. We’ll also split the data into a training and testing set using createDataPartition() from caret. This lets us train the model on 70% of our data and use the other 30% to see how the model handles new data. The set.seed() function makes sure that you get the same split that I do.

library(rvest)
library(caret)
library(dplyr)

pacers<-
  "http://www.basketball-reference.com/teams/IND/2017/gamelog/" %>%
  read_html() %>%
  #selecting the basic game log table by its id
  html_nodes('#tgl_basic') %>%
  html_table() %>%
  #removing the scraped dataframe from a list
  .[[1]] %>%
  #selecting specific columns
  .[,2:23] %>%
  #changing column names to names in the first row
  `colnames<-` (make.names(.[1,], unique=TRUE)) %>%
  #removing excess headers in dataframe
  filter(Date!="Date" & Date!="") %>%
  rename(H.A=X, Opp.Score=Opp.1) %>%
  #creating a home and away column
  mutate(H.A=ifelse(H.A=="@", "A", "H")) %>%
  #changing variables either to factors or numbers
  mutate_at(vars(Tm:TOV), as.numeric) %>%
  mutate_at(vars(W.L, H.A), factor)

set.seed(1234)
# times = 1 returns a single 70/30 partition (a one-column matrix of row indices)
split<- createDataPartition(y = pacers$W.L, times = 1, p = .7, list = FALSE)
training<- pacers[split,]
testing<- pacers[-split,]
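createDataPartition() has a times argument that sets how many separate partitions it returns, one column of row indices per partition. A quick check on a made-up response (42 "W" and 40 "L" labels standing in for W.L, not the scraped data) shows what each setting gives, and what happens when a data frame’s rows are indexed with the resulting matrix.

```r
library(caret)

# Made-up response standing in for W.L: 42 "W" labels and 40 "L" labels
y <- factor(rep(c("W", "L"), times = c(42, 40)))
games <- data.frame(game = seq_along(y), W.L = y)

set.seed(1234)
one_split <- createDataPartition(y, times = 1, p = 0.7, list = FALSE)
two_splits <- createDataPartition(y, times = 2, p = 0.7, list = FALSE)

ncol(one_split)   # a single column of row indices
ncol(two_splits)  # two separate 70% resamples

# Indexing rows with the one-column matrix gives roughly 70% of the games...
nrow(games[one_split, ])
# ...while a two-column matrix is read as one long index vector, stacking both resamples
nrow(games[two_splits, ])

# Stratified sampling keeps the W/L mix close to that of the full data
prop.table(table(games$W.L[one_split]))
```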

Creating a Tree

Now, we’ll be using the rpart package to develop this decision tree, so install it and load it into your R session. There are multiple packages for this, though, so if you don’t like rpart, explore some of the other options.

We’ll develop our decision tree model with the rpart() function. We’ll begin by specifying just the model formula, the method, and the data. The only new argument is method, which specifies whether we want a classification or a regression tree. We’re doing classification in this case, so we’ll specify method="class". The default splitting criterion is the Gini index, which is one minus the sum of the squared class proportions in a node; it’s 0 for a node that contains only one class.
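The Gini index is easy to compute by hand. Below is a minimal sketch; gini() and split_gini() are hypothetical helpers, and scoring a split by the size-weighted Gini of its two child nodes is a common way to compare candidate splits (lower means purer).

```r
# Gini index of a node from the class labels of the observations in it:
# one minus the sum of the squared class proportions
gini <- function(labels) {
  p <- prop.table(table(labels))
  1 - sum(p^2)
}

gini(c("W", "W", "L", "L"))  # evenly mixed node: 0.5
gini(c("W", "W", "W", "W"))  # pure node: 0
gini(c("W", "W", "W", "L"))  # mostly wins: 0.375

# Size-weighted Gini of the two child nodes created by a split
split_gini <- function(left, right) {
  n <- length(left) + length(right)
  length(left) / n * gini(left) + length(right) / n * gini(right)
}

split_gini(c("W", "W", "W"), c("L", "L", "W"))  # 0.5 * 0 + 0.5 * 4/9
```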

library(rpart)

set.seed(1234) # rpart's internal cross validation assigns observations to folds at random
simple.dt<- rpart(W.L ~ FG. + TRB + H.A + TOV, method="class", data=training)

We can plot our tree by calling plot() on it (and then text() to add the split labels), but this plot can be a bit cluttered and hard to understand. We can instead use rpart.plot() from the rpart.plot package to produce a nicer, easier to understand plot.

library(rpart.plot)

rpart.plot(simple.dt)

We can follow the flow of this decision tree relatively easily. Each node shows the predicted result, the proportion of wins among the observations in the node, and the percentage of all observations that fall into the node. So the first node, which captures all the observations, shows win as the predicted result, 52% of the observations in it are wins, and 100% of the data is included in it.

The first split asks whether the Pacers shot less than 51%. If they didn’t, we predict they win. If they did, we move to the next split, which asks whether they had 10 or more turnovers. We move through the tree and make predictions this way. We can also see our prediction success at each of the terminal nodes (for example, on the training data we were 83% accurate by predicting wins when the team shot less than 51% and had fewer than 10 turnovers).

Decision trees run the risk of overfitting the data as they keep splitting into smaller and more specific nodes. To prevent this, we need to look at the results and prune back our tree’s branches. We can look at model results that help us gauge overfitting by calling the printcp() function. CP is the complexity parameter; it acts as a penalty that controls tree size. Smaller CP values mean more complex trees (i.e. more splits and terminal nodes).

printcp(simple.dt)

## 
## Classification tree:
## rpart(formula = W.L ~ FG. + TRB + H.A + TOV, data = training, 
##     method = "class")
## 
## Variables actually used in tree construction:
## [1] FG. TOV TRB
## 
## Root node error: 56/116 = 0.48276
## 
## n= 116 
## 
##      CP nsplit rel error  xerror     xstd
## 1 0.250      0      1.00 1.14286 0.095648
## 2 0.125      2      0.50 0.71429 0.091415
## 3 0.010      4      0.25 0.37500 0.074055

To save computing time, rpart doesn’t attempt any split that fails to improve the overall fit by a factor of at least CP (0.01 by default). This, along with several other parameters (such as the minimum number of observations needed to attempt a split and the number of cross validation folds), can be set with the rpart.control() function. If you wanted a full tree, you would simply set cp=0.
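Here’s a minimal sketch of growing a full tree with rpart.control(). It uses the kyphosis data that ships with rpart as a stand-in for the Pacers data, so it runs without scraping; cp, minsplit, and xval are real rpart.control() arguments, and the values chosen here are only for illustration.

```r
library(rpart)

# kyphosis ships with rpart; it stands in for the Pacers data here
set.seed(1234)
# Default settings: cp = 0.01, minsplit = 20, 10-fold cross validation (xval = 10)
default_fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                     method = "class")

set.seed(1234)
# cp = 0 drops the improvement threshold; minsplit = 2 lets even tiny nodes split
full_fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                  method = "class",
                  control = rpart.control(cp = 0, minsplit = 2, xval = 10))

# Number of nodes (rows of the frame component) in each tree
c(default = nrow(default_fit$frame), full = nrow(full_fit$frame))
```

The full tree ends up much larger than the default one, which is exactly the overfitting risk that pruning is meant to address.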

We can see the root node error is about 48%; this means that predicting every observation as a win would give us about 52% accuracy. Each CP value has a number of splits, as well as rel error, xerror, and xstd values, associated with it. rel error is the training error rate at each tree size. xerror is the error rate estimated by the cross validation that rpart runs internally (10-fold cross validation by default). These error numbers are relative to the root node error, so when nsplit is 0, our rel error is 100% of the root node error. When we perform 2 splits, our rel error is 0.50, or half of the root node error (i.e. around 24%).
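Converting the relative errors back to absolute error rates is just a multiplication by the root node error. The sketch below types in the values from the printcp() output above; for a fitted tree, the same numbers live in the cptable component (e.g. simple.dt$cptable).

```r
# Values copied from the printcp() output above
root_node_error <- 56 / 116
cp_table <- data.frame(nsplit = c(0, 2, 4),
                       rel_error = c(1.00, 0.50, 0.25),
                       xerror = c(1.14286, 0.71429, 0.37500))

# rel error and xerror are relative to the root node error,
# so multiplying gives absolute misclassification rates
cp_table$train_error <- cp_table$rel_error * root_node_error  # 2 splits: about 0.24
cp_table$cv_error <- cp_table$xerror * root_node_error
cp_table
```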

There are two general rules of thumb when trying to prune our tree:

  • Choose the tree (CP value) with the lowest cross validation error (xerror)

  • Choose the simplest tree whose xerror is within one xstd (the standard error of the cross validation error) of the lowest xerror
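Both rules can be applied directly to a cptable. choose_cp() below is a hypothetical helper, not part of rpart; it relies on cptable rows being ordered from the simplest tree to the most complex, and it’s run on the table printed above, typed in by hand.

```r
# Hypothetical helper: pick a CP value from an rpart cptable using either rule
choose_cp <- function(cptable, one_se = TRUE) {
  best <- which.min(cptable[, "xerror"])
  if (!one_se) return(cptable[best, "CP"])
  # threshold: lowest xerror plus its xstd
  limit <- cptable[best, "xerror"] + cptable[best, "xstd"]
  # rows run from fewest to most splits, so the first row under the limit
  # is the simplest qualifying tree
  first_ok <- which(cptable[, "xerror"] <= limit)[1]
  cptable[first_ok, "CP"]
}

# The cptable printed above, typed in by hand
cptab <- cbind(CP = c(0.250, 0.125, 0.010),
               nsplit = c(0, 2, 4),
               "rel error" = c(1.00, 0.50, 0.25),
               xerror = c(1.14286, 0.71429, 0.37500),
               xstd = c(0.095648, 0.091415, 0.074055))
choose_cp(cptab, one_se = FALSE)  # lowest xerror rule
choose_cp(cptab)                  # one standard error rule

# Hypothetical variation: if the 2-split tree's xerror were 0.40,
# the one standard error rule would prefer that simpler tree
cptab_alt <- cptab
cptab_alt[2, "xerror"] <- 0.40
choose_cp(cptab_alt)
```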

I would usually go with the second option to help keep our tree from overfitting. In this case, the model with 4 splits has the lowest xerror, and no simpler model is within one xstd of it (.375 + .074 = .449, while the next simplest tree’s xerror is .714), so we’ll continue without doing any pruning. If we wanted to prune back the tree, we could use the prune() function (which dispatches to prune.rpart() for rpart trees) to cut the tree at a specific CP value.
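Here’s a minimal pruning sketch, again on rpart’s built-in kyphosis data as a stand-in: grow a deliberately large tree, look up the CP value with the lowest xerror in its cptable, and pass it to prune().

```r
library(rpart)

# Grow a deliberately large tree on kyphosis (a stand-in for the Pacers data)
set.seed(1234)
big_tree <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
                  method = "class", control = rpart.control(cp = 0, minsplit = 2))

# Lowest cross validation error rule: the CP value in the row with the smallest xerror
best_cp <- big_tree$cptable[which.min(big_tree$cptable[, "xerror"]), "CP"]

# Snip off every split whose complexity is at or below best_cp
pruned_tree <- prune(big_tree, cp = best_cp)

printcp(pruned_tree)
```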

Now we can easily predict on our test data set using the predict() function. Specifying type="class" gets us back class predictions; if we don’t specify it, predictions come back as a matrix of predicted probabilities, with one column per outcome.
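The difference between the two prediction types is easy to see on a small example. This sketch fits a classification tree to rpart’s built-in kyphosis data (a stand-in for the Pacers data) and predicts on its first five rows.

```r
library(rpart)

# rpart's built-in kyphosis data stands in for the Pacers data
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class")
new_cases <- kyphosis[1:5, ]

# Default for a classification tree: a matrix of class probabilities, one column per class
probs <- predict(fit, newdata = new_cases)
# type = "class": a factor with the predicted class for each row
classes <- predict(fit, newdata = new_cases, type = "class")

probs
classes
```

Each row of probs sums to 1, and the class prediction is the class with the highest probability in that row.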

test.predictions<- predict(simple.dt, newdata=testing, type="class")

confusionMatrix(test.predictions, testing$W.L, positive = "W")

## Confusion Matrix and Statistics
## 
##           Reference
## Prediction L W
##          L 3 1
##          W 2 6
##                                           
##                Accuracy : 0.75            
##                  95% CI : (0.4281, 0.9451)
##     No Information Rate : 0.5833          
##     P-Value [Acc > NIR] : 0.1916          
##                                           
##                   Kappa : 0.4706          
##  Mcnemar's Test P-Value : 1.0000          
##                                           
##             Sensitivity : 0.8571          
##             Specificity : 0.6000          
##          Pos Pred Value : 0.7500          
##          Neg Pred Value : 0.7500          
##              Prevalence : 0.5833          
##          Detection Rate : 0.5000          
##    Detection Prevalence : 0.6667          
##       Balanced Accuracy : 0.7286          
##                                           
##        'Positive' Class : W               
## 

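We can see that our overall test set accuracy is 75%. The model predicts wins better than losses (a sensitivity of 86% versus a specificity of 60%), but with only 12 games in the test set, that gap may not mean much. The small test set also shows up in the wide 95% confidence interval for accuracy (43% to 95%) and in the p-value of 0.19 for accuracy being greater than the no information rate of 58%.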
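The headline statistics can be recomputed by hand from the four counts in the confusion matrix. The sketch below types in those counts; the one-sided binomial test compares the number of correct predictions against the no information rate.

```r
# Counts copied from the confusion matrix above (rows = prediction, columns = reference)
cm <- matrix(c(3, 2, 1, 6), nrow = 2,
             dimnames = list(Prediction = c("L", "W"), Reference = c("L", "W")))

accuracy    <- sum(diag(cm)) / sum(cm)         # 9 / 12
sensitivity <- cm["W", "W"] / sum(cm[, "W"])   # wins caught: 6 / 7
specificity <- cm["L", "L"] / sum(cm[, "L"])   # losses caught: 3 / 5
nir         <- max(colSums(cm)) / sum(cm)      # always guessing the more common class: 7 / 12

# One-sided binomial test of accuracy against the no information rate
p_value <- binom.test(sum(diag(cm)), sum(cm), p = nir,
                      alternative = "greater")$p.value

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, nir = nir, p_value = p_value), 4)
```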

The main pro of decision trees is that they are easy to visualize and understand, and they can be applied to both regression and classification problems. They are, however, weak learners. A single tree can give good results, but a lot of the time it doesn’t improve much on a simple baseline, and small changes in the training data can produce a very different tree. We can improve them, though, by combining several of them together. The next tutorial will go over a combining method called bagging and will take a look at the randomForest package.