This example applies R's decision tree tools to the iris data and does some simple visualization.
iris = read.table("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
sep = ",", header = FALSE)
names(iris) = c("sepal.length", "sepal.width", "petal.length", "petal.width",
"iris.type")
attach(iris)
The library we will use is the tree library.
library(tree)
## Warning: package 'tree' was built under R version 2.15.1
stree = tree(iris.type ~ ., data = iris)
stree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
## 2) petal.length < 2.45 50 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 3) petal.length > 2.45 100 100 Iris-versicolor ( 0.00 0.50 0.50 )
## 6) petal.width < 1.75 54 30 Iris-versicolor ( 0.00 0.91 0.09 )
## 12) petal.length < 4.95 48 10 Iris-versicolor ( 0.00 0.98 0.02 )
## 24) sepal.length < 5.15 5 5 Iris-versicolor ( 0.00 0.80 0.20 ) *
## 25) sepal.length > 5.15 43 0 Iris-versicolor ( 0.00 1.00 0.00 ) *
## 13) petal.length > 4.95 6 8 Iris-virginica ( 0.00 0.33 0.67 ) *
## 7) petal.width > 1.75 46 10 Iris-virginica ( 0.00 0.02 0.98 )
## 14) petal.length < 4.95 6 5 Iris-virginica ( 0.00 0.17 0.83 ) *
## 15) petal.length > 4.95 40 0 Iris-virginica ( 0.00 0.00 1.00 ) *
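The deviance column in the printout above can be checked by hand. The sketch below, in plain Python, assumes the usual multinomial deviance for a classification node, -2 * sum_k n_k * log(n_k / n); the function name node_deviance is made up for this illustration and is not part of the tree package.

```python
import math

def node_deviance(class_counts):
    """Multinomial deviance of one node: 2 * sum_k n_k * log(n / n_k).

    Classes with zero count contribute nothing, so a pure node has deviance 0.
    """
    n = sum(class_counts)
    return 2.0 * sum(c * math.log(n / c) for c in class_counts if c > 0)

# Root node: 150 flowers, 50 of each species.
root = node_deviance([50, 50, 50])
# Node 3: the 100 flowers with petal.length > 2.45, split 50/50.
node3 = node_deviance([0, 50, 50])
# Node 2: 50 setosa only, so the node is pure.
node2 = node_deviance([50, 0, 0])

print(round(root, 1), round(node3, 1), round(node2, 1))  # 329.6 138.6 0.0
```

Under this assumption the root comes out near 329.6 and node 3 near 138.6, which is consistent with the 300 and 100 shown above if the printout is displayed at reduced precision.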
We can also plot the tree. The second command adds the appropriate text to the visualization.
plot(stree)
text(stree)
The tree function splits on deviance (entropy) by default. We can ask for the Gini index instead:
stree = tree(iris.type ~ ., data = iris, split = "gini")
stree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
## 2) petal.length < 1.35 11 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 3) petal.length > 1.35 139 300 Iris-versicolor ( 0.28 0.36 0.36 )
## 6) sepal.width < 2.35 7 6 Iris-versicolor ( 0.00 0.86 0.14 ) *
## 7) sepal.width > 2.35 132 300 Iris-virginica ( 0.30 0.33 0.37 )
## 14) sepal.width < 2.55 11 10 Iris-versicolor ( 0.00 0.64 0.36 )
## 28) petal.length < 4.25 6 0 Iris-versicolor ( 0.00 1.00 0.00 ) *
## 29) petal.length > 4.25 5 5 Iris-virginica ( 0.00 0.20 0.80 ) *
## 15) sepal.width > 2.55 121 300 Iris-virginica ( 0.32 0.31 0.37 )
## 30) petal.length < 1.45 12 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 31) petal.length > 1.45 109 200 Iris-virginica ( 0.25 0.34 0.41 )
## 62) sepal.width < 2.65 5 7 Iris-versicolor ( 0.00 0.60 0.40 ) *
## 63) sepal.width > 2.65 104 200 Iris-virginica ( 0.26 0.33 0.41 )
## 126) sepal.width < 2.75 9 10 Iris-versicolor ( 0.00 0.56 0.44 ) *
## 127) sepal.width > 2.75 95 200 Iris-virginica ( 0.28 0.31 0.41 )
## 254) sepal.width < 2.85 14 20 Iris-virginica ( 0.00 0.43 0.57 )
## 508) petal.length < 4.85 7 6 Iris-versicolor ( 0.00 0.86 0.14 ) *
## 509) petal.length > 4.85 7 0 Iris-virginica ( 0.00 0.00 1.00 ) *
## 255) sepal.width > 2.85 81 200 Iris-virginica ( 0.33 0.28 0.38 )
## 510) petal.length < 1.55 14 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 511) petal.length > 1.55 67 100 Iris-virginica ( 0.19 0.34 0.46 )
## 1022) petal.length < 5.05 38 60 Iris-versicolor ( 0.34 0.61 0.05 )
## 2044) petal.length < 2.75 13 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 2045) petal.length > 2.75 25 10 Iris-versicolor ( 0.00 0.92 0.08 )
## 4090) petal.length < 4.75 20 0 Iris-versicolor ( 0.00 1.00 0.00 ) *
## 4091) petal.length > 4.75 5 7 Iris-versicolor ( 0.00 0.60 0.40 ) *
## 1023) petal.length > 5.05 29 0 Iris-virginica ( 0.00 0.00 1.00 ) *
plot(stree)
text(stree)
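For reference, here is a small sketch of the textbook Gini index and the impurity reduction of a split, in plain Python. The helper names gini and split_gain are assumptions for this illustration, and the exact criterion that tree applies with split = "gini" may differ from this textbook form. The class counts are read off the two printouts above.

```python
def gini(class_counts):
    """Textbook Gini index of a node: 1 - sum_k p_k^2."""
    n = sum(class_counts)
    return 1.0 - sum((c / n) ** 2 for c in class_counts)

def split_gain(parent, left, right):
    """Reduction in size-weighted Gini index achieved by a binary split."""
    n = sum(parent)
    weighted = sum(left) / n * gini(left) + sum(right) / n * gini(right)
    return gini(parent) - weighted

parent = [50, 50, 50]
# First split of the default tree: petal.length < 2.45
gain_a = split_gain(parent, [50, 0, 0], [0, 50, 50])
# First split of the split = "gini" tree: petal.length < 1.35
gain_b = split_gain(parent, [11, 0, 0], [39, 50, 50])

print(round(gini(parent), 4), round(gain_a, 4), round(gain_b, 4))
```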
We can also easily fit a decision tree using only two of the variables. This will allow us to visualize the regions in 2 dimensions fairly easily.
stree = tree(iris.type ~ petal.width + petal.length, data = iris)
stree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
## 2) petal.width < 0.8 50 0 Iris-setosa ( 1.00 0.00 0.00 ) *
## 3) petal.width > 0.8 100 100 Iris-versicolor ( 0.00 0.50 0.50 )
## 6) petal.width < 1.75 54 30 Iris-versicolor ( 0.00 0.91 0.09 )
## 12) petal.length < 4.95 48 10 Iris-versicolor ( 0.00 0.98 0.02 ) *
## 13) petal.length > 4.95 6 8 Iris-virginica ( 0.00 0.33 0.67 ) *
## 7) petal.width > 1.75 46 10 Iris-virginica ( 0.00 0.02 0.98 )
## 14) petal.length < 4.95 6 5 Iris-virginica ( 0.00 0.17 0.83 ) *
## 15) petal.length > 4.95 40 0 Iris-virginica ( 0.00 0.00 1.00 ) *
plot(stree)
text(stree)
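The two-variable tree printed above amounts to a few nested threshold tests. As a sketch, the same rules can be written out as a plain Python function; the name classify is hypothetical, and points that fall exactly on a threshold are sent to the right-hand branch here, which is a choice made for this illustration.

```python
def classify(petal_width, petal_length):
    """Label a flower using the thresholds of the printed two-variable tree."""
    if petal_width < 0.8:
        return "Iris-setosa"          # node 2
    if petal_width < 1.75:
        if petal_length < 4.95:
            return "Iris-versicolor"  # node 12
        return "Iris-virginica"       # node 13
    # petal.width > 1.75: nodes 14 and 15 both predict virginica,
    # so petal.length does not change the label in this branch.
    return "Iris-virginica"

print(classify(0.2, 1.4))  # Iris-setosa
print(classify(1.3, 4.0))  # Iris-versicolor
print(classify(2.0, 5.5))  # Iris-virginica
```

Because nodes 14 and 15 share a label, the five leaves produce only four visually distinct regions in the plane.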
Here is a visualization of this two-dimensional decision boundary.
In[7]:
%load_ext rmagic
%R -d iris
from matplotlib import pyplot as plt, mlab
col = {1:'r', 2:'y', 3:'g'}
coln = {'Iris-setosa':'r', 'Iris-versicolor':'y', 'Iris-virginica':'g'}
In[8]:
iris.shape
plt.scatter(iris['petal.length'], iris['petal.width'], c=[col[t] for t in iris['iris.type']])
a = plt.gca()
a.set_xlabel('Petal length')
a.set_ylabel('Petal width')
# Here are the regions as described in R's plot above
# There are five terminal leaves, so there are five regions
xf, yf = mlab.poly_between([0,8],[-0.5,-0.5],[0.8,0.8])
plt.fill(xf, yf, coln['Iris-setosa'], alpha=0.3)
xf, yf = mlab.poly_between([0,4.95],[0.8,0.8],[1.75,1.75])
plt.fill(xf, yf, coln['Iris-versicolor'], alpha=0.3)
xf, yf = mlab.poly_between([4.95,8],[0.8,0.8],[1.75,1.75])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)
xf, yf = mlab.poly_between([4.95,8],[1.75,1.75],[3,3])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)
xf, yf = mlab.poly_between([0,4.95],[1.75,1.75],[3,3])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)
a.set_xlim((0,8))
a.set_ylim((-0.5,3))
Out[8]:
(-0.5, 3)
[Figure: scatter plot of petal width against petal length, with the tree's five regions shaded by predicted class]
Let’s draw the decision boundaries another way, one that will be easier for later classifiers. Recall that a classifier assigns a label to each point in the feature space. We are going to evaluate the label on a dense grid, then show an image of that label.
In[9]:
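import numpy as np
# First axis: petal width from -0.5 to 3.5; second axis: petal length from 0 to 8
grid = np.mgrid[-0.5:3.5:500j,0:8:400j]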
gridT = grid.reshape((2,-1)).T
gridT.shape
Out[9]:
(200000, 2)
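The row order of gridT matters later, when the labels are reshaped back to 500 by 400. The sketch below rebuilds the same layout with plain Python lists, assuming the usual C-order flattening, so that each row is a (petal width, petal length) pair and petal length varies fastest. The names linspace and grid_rows are made up for this illustration.

```python
def linspace(start, stop, num):
    """num evenly spaced values from start to stop, both endpoints included."""
    last = num - 1
    return [start + (stop - start) * i / last for i in range(num)]

widths = linspace(-0.5, 3.5, 500)   # first grid axis
lengths = linspace(0.0, 8.0, 400)   # second grid axis

# Width index is the slow one, length index the fast one.
grid_rows = [(w, l) for w in widths for l in lengths]

print(len(grid_rows))  # 200000
print(grid_rows[0])    # (-0.5, 0.0)
```

With this ordering, reshaping the predicted labels to (500, 400) puts petal width along the image rows and petal length along the columns.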
Having created a grid, we will use R's predict function to evaluate the label on this grid.
In[10]:
%%R -i gridT -o labels
colnames(gridT) = c('petal.width', 'petal.length')
gridT = data.frame(gridT)
labels = predict(stree, gridT, type='class')
We underlay the image beneath the scatter plot.
In[11]:
plt.imshow(labels.reshape((500,400)), interpolation='nearest', origin='lower', alpha=0.4, extent=[0,8,-0.5,3.5], cmap=plt.cm.RdYlGn)
plt.scatter(iris['petal.length'], iris['petal.width'], c=[col[t] for t in iris['iris.type']])
a = plt.gca()
a.set_xlim((0,8))
a.set_ylim((-0.5,3))
Out[11]:
(-0.5, 3)
[Figure: the same scatter plot drawn over the image of predicted labels on the grid]
Finally, let’s form a test set of size 50 and training set of size 100. We will fit the tree on the training set and then evaluate the result on the test set.
test_sample = sample(1:150, 50)
test_data = iris[test_sample, ]
training_data = iris[-test_sample, ]
fit_tree = tree(iris.type ~ ., data = training_data)
test_predictions = predict(fit_tree, test_data, type = "class")
test_error = sum(test_predictions != test_data$iris.type)/nrow(test_data)
test_error
## [1] 0.1
training_error = sum(predict(fit_tree, type = "class") != training_data$iris.type)/nrow(training_data)
training_error
## [1] 0.01
Usually, the test error is slightly higher than the training error, but this does not have to be the case.
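The split and the error rate used above are simple enough to sketch in plain Python. The function names holdout_split and error_rate, and the seed argument, are assumptions made for this illustration; the labels in the last lines are made-up values chosen to reproduce an error of 5 out of 50.

```python
import random

def holdout_split(n_rows, n_test, seed=None):
    """Randomly pick n_test row indices for testing; the rest are for training."""
    rng = random.Random(seed)
    test_idx = sorted(rng.sample(range(n_rows), n_test))
    test_set = set(test_idx)
    train_idx = [i for i in range(n_rows) if i not in test_set]
    return train_idx, test_idx

def error_rate(predicted, actual):
    """Fraction of positions where the predicted label differs from the actual one."""
    wrong = sum(1 for p, a in zip(predicted, actual) if p != a)
    return wrong / len(actual)

train_idx, test_idx = holdout_split(150, 50, seed=0)
print(len(train_idx), len(test_idx))  # 100 50

# Made-up labels: 5 mistakes out of 50 gives an error rate of 0.1.
demo_error = error_rate(["a"] * 45 + ["b"] * 5, ["a"] * 50)
print(demo_error)  # 0.1
```

Since the test rows are drawn at random, a different draw can give a test error below the training error.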
Let’s try the tree classifier on our voting data:
votes = read.table("http://stats202.stanford.edu/data/2011_cleaned_votes.csv",
header = TRUE, sep = ";")
dim(votes)
## [1] 426 948
vtree = tree(party ~ ., data = votes)
vtree
## node), split, n, deviance, yval, (yprob)
## * denotes terminal node
##
## 1) root 426 600 R ( 0.4 0.6 )
## 2) numeric_vote675 < 0.5 189 0 D ( 1.0 0.0 ) *
## 3) numeric_vote675 > 0.5 237 0 R ( 0.0 1.0 ) *
plot(vtree)
text(vtree)
Apparently, bill 675 is very partisan, achieving perfect separation. Here is the bill. We see that all 239 Republicans voted Y, 187 Democrats voted N, and 6 Democrats did not vote. What if it had been 6 Republicans who did not vote? Where would the split be?
The rpart library also has a decision tree fitter. It is more flexible and follows the standard CART approach more closely, though its pruning differs from what is described in the notes. First, we show how to prune a tree as described in the notes.
health = read.table("http://stats202.stanford.edu/data/health.csv", sep = ",",
header = TRUE)
health_no_sample_weight = health[, -c(13)]
health_tree = tree(country ~ ., data = health_no_sample_weight, mindev = 0.001)
summary(health_tree)
##
## Classification tree:
## tree(formula = country ~ ., data = health_no_sample_weight, mindev = 0.001)
## Variables actually used in tree construction:
## [1] "age" "weight" "teeth" "vegetables"
## [5] "hungry" "hands_soap" "fruit" "height"
## [9] "hands_toilet" "bmi"
## Number of terminal nodes: 54
## Residual mean deviance: 1.19 = 25100 / 21000
## Misclassification error rate: 0.245 = 5163 / 21089
plot(health_tree)
text(health_tree)
Now, let’s prune it.
plot(prune.tree(health_tree))
abline(v = 6, col = "red")
We may want to specify a particular value for the cost-complexity parameter. Above, the plot shows that for a cost-complexity parameter of about 390, we should have a tree of size 6.
pruned_tree = prune.tree(health_tree, k = 390)
summary(pruned_tree)
##
## Classification tree:
## snip.tree(tree = health_tree, nodes = c(15, 27, 26, 2, 12, 14
## ))
## Variables actually used in tree construction:
## [1] "age" "weight" "teeth" "hungry"
## Number of terminal nodes: 6
## Residual mean deviance: 1.44 = 30400 / 21100
## Misclassification error rate: 0.288 = 6077 / 21089
plot(pruned_tree)
text(pruned_tree)
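To see why a penalty of about 390 favors the small tree, we can compare penalized costs by hand. The sketch below assumes the cost-complexity criterion has the form deviance + k * (number of terminal nodes) and uses the rounded deviances from the two summaries above. The function name penalized_cost is made up for this illustration.

```python
def penalized_cost(deviance, n_leaves, k):
    """Cost-complexity criterion: fit term plus k times the tree size."""
    return deviance + k * n_leaves

k = 390
full = penalized_cost(25100, 54, k)    # unpruned tree, 54 terminal nodes
pruned = penalized_cost(30400, 6, k)   # pruned tree, 6 terminal nodes
print(full, pruned)  # 46160 32740

# Value of k at which these two particular trees have equal penalized cost.
break_even_k = (30400 - 25100) / (54 - 6)
print(round(break_even_k, 1))  # 110.4
```

This compares only these two trees. Subtrees of intermediate size can have a lower penalized cost for values of k between the break-even point and 390.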
Let’s fit the tree using rpart. We see that it has been pruned already. In fact, no observations are labelled ugh.
library(rpart)
health.tree = rpart(country ~ ., method = "class", data = health_no_sample_weight)
summary(health.tree)
## Call:
## rpart(formula = country ~ ., data = health_no_sample_weight,
## method = "class")
## n= 24662
##
## CP nsplit rel error xerror xstd
## 1 0.06746 0 1.0000 1.0000 0.008495
## 2 0.02728 2 0.8651 0.8752 0.008221
## 3 0.01727 3 0.8378 0.8561 0.008172
## 4 0.01000 7 0.7687 0.7830 0.007962
##
## Node number 1: 24662 observations, complexity param=0.06746
## predicted class=aeh expected loss=0.3597
## class counts: 15790 5657 3215
## probabilities: 0.640 0.229 0.130
## left son=2 (19131 obs) right son=3 (5531 obs)
## Primary splits:
## teeth splits as LLLLRL, improve=909.3, (211 missing)
## age splits as LLLRRR, improve=866.2, (314 missing)
## weight < 55.5 to the right, improve=611.4, (2187 missing)
## hungry splits as LLLRR, improve=595.0, (269 missing)
## bmi < 22.84 to the right, improve=515.6, (2187 missing)
##
## Node number 2: 19131 observations, complexity param=0.01727
## predicted class=aeh expected loss=0.2829
## class counts: 13719 3065 2347
## probabilities: 0.717 0.160 0.123
## left son=4 (5837 obs) right son=5 (13294 obs)
## Primary splits:
## age splits as LLLRRR, improve=449.5, (246 missing)
## hungry splits as LLLRR, improve=342.6, (232 missing)
## hands_soap splits as LLRRR, improve=324.7, (237 missing)
## weight < 56.5 to the right, improve=306.2, (1675 missing)
## bmi < 22.32 to the right, improve=300.3, (1675 missing)
##
## Node number 3: 5531 observations, complexity param=0.06746
## predicted class=pih expected loss=0.5314
## class counts: 2071 2592 868
## probabilities: 0.374 0.469 0.157
## left son=6 (1200 obs) right son=7 (4331 obs)
## Primary splits:
## age splits as LLLRRR, improve=353.0, (68 missing)
## weight < 50.5 to the right, improve=260.5, (512 missing)
## hungry splits as LRLRR, improve=177.3, (37 missing)
## bmi < 23.14 to the right, improve=146.8, (512 missing)
## height < 1.575 to the left, improve=133.0, (512 missing)
##
## Node number 4: 5837 observations
## predicted class=aeh expected loss=0.09474
## class counts: 5284 313 240
## probabilities: 0.905 0.054 0.041
##
## Node number 5: 13294 observations, complexity param=0.01727
## predicted class=aeh expected loss=0.3655
## class counts: 8435 2752 2107
## probabilities: 0.634 0.207 0.158
## left son=10 (4675 obs) right son=11 (8619 obs)
## Primary splits:
## weight < 55.5 to the right, improve=447.2, (1185 missing)
## bmi < 23.21 to the right, improve=343.3, (1185 missing)
## hungry splits as LLLRR, improve=322.4, (155 missing)
## hands_soap splits as LLRRR, improve=311.1, (160 missing)
## vegetables splits as LRRRRRL, improve=251.1, (124 missing)
## Surrogate splits:
## bmi < 21.51 to the right, agree=0.858, adj=0.632, (0 split)
## height < 1.665 to the right, agree=0.697, adj=0.216, (0 split)
##
## Node number 6: 1200 observations
## predicted class=aeh expected loss=0.25
## class counts: 900 224 76
## probabilities: 0.750 0.187 0.063
##
## Node number 7: 4331 observations, complexity param=0.02728
## predicted class=pih expected loss=0.4532
## class counts: 1171 2368 792
## probabilities: 0.270 0.547 0.183
## left son=14 (1574 obs) right son=15 (2757 obs)
## Primary splits:
## weight < 50.5 to the right, improve=263.9, (399 missing)
## height < 1.575 to the left, improve=190.2, (399 missing)
## hungry splits as LRLRR, improve=124.5, (23 missing)
## bmi < 23.14 to the right, improve=115.5, (399 missing)
## vegetables splits as LLRRRLL, improve=101.2, (35 missing)
## Surrogate splits:
## bmi < 20.55 to the right, agree=0.801, adj=0.503, (0 split)
## height < 1.615 to the right, agree=0.731, adj=0.328, (0 split)
##
## Node number 10: 4675 observations
## predicted class=aeh expected loss=0.1765
## class counts: 3850 276 549
## probabilities: 0.824 0.059 0.117
##
## Node number 11: 8619 observations, complexity param=0.01727
## predicted class=aeh expected loss=0.468
## class counts: 4585 2476 1558
## probabilities: 0.532 0.287 0.181
## left son=22 (4832 obs) right son=23 (3787 obs)
## Primary splits:
## hungry splits as LLLRR, improve=265.7, (111 missing)
## vegetables splits as LRRRRRR, improve=226.7, (79 missing)
## teeth splits as LLLR-R, improve=204.6, (101 missing)
## hands_soap splits as LLRRR, improve=189.8, (104 missing)
## fruit splits as LRRRRRR, improve=181.0, (110 missing)
## Surrogate splits:
## hands_soap splits as LLLLR, agree=0.565, adj=0.020, (101 split)
## hands_toilet splits as LRLRR, agree=0.563, adj=0.014, (5 split)
##
## Node number 14: 1574 observations
## predicted class=aeh expected loss=0.5623
## class counts: 689 447 438
## probabilities: 0.438 0.284 0.278
##
## Node number 15: 2757 observations
## predicted class=pih expected loss=0.3032
## class counts: 482 1921 354
## probabilities: 0.175 0.697 0.128
##
## Node number 22: 4832 observations
## predicted class=aeh expected loss=0.3551
## class counts: 3116 857 859
## probabilities: 0.645 0.177 0.178
##
## Node number 23: 3787 observations, complexity param=0.01727
## predicted class=pih expected loss=0.5725
## class counts: 1469 1619 699
## probabilities: 0.388 0.428 0.185
## left son=46 (1744 obs) right son=47 (2043 obs)
## Primary splits:
## vegetables splits as LLRRRLL, improve=150.70, (38 missing)
## teeth splits as LLLR-R, improve=135.50, (43 missing)
## fruit splits as LLRRRRL, improve= 87.00, (49 missing)
## height < 1.525 to the left, improve= 85.22, (491 missing)
## hands_soap splits as LLRRR, improve= 76.15, (38 missing)
## Surrogate splits:
## fruit splits as LLRRRRL, agree=0.652, adj=0.243, (31 split)
## teeth splits as LLLR-R, agree=0.583, adj=0.092, (5 split)
## hands_soap splits as RRLRR, agree=0.544, adj=0.009, (1 split)
## hands_toilet splits as RRLLR, agree=0.541, adj=0.001, (1 split)
##
## Node number 46: 1744 observations
## predicted class=aeh expected loss=0.4633
## class counts: 936 473 335
## probabilities: 0.537 0.271 0.192
##
## Node number 47: 2043 observations
## predicted class=pih expected loss=0.4391
## class counts: 533 1146 364
## probabilities: 0.261 0.561 0.178
plot(health.tree)
text(health.tree)
The parameter cp controls this pruning and defaults to 0.01. A split is not kept unless it improves the overall fit by at least a factor of cp.
health.tree = rpart(country ~ ., method = "class", data = health_no_sample_weight,
cp = 0.005)
plot(health.tree)
text(health.tree)
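The rel error and xerror columns of the CP table in the summary above are scaled so that the root node has error 1. As a sketch in plain Python, multiplying by the root node's error (1 - 15790/24662, from the class counts of node 1) turns them back into misclassification rates. The variable names are made up for this illustration.

```python
# Root node error: fraction of observations not in the majority class (aeh).
root_error = 1 - 15790 / 24662

# (CP, nsplit, rel error, xerror) rows copied from the summary printout.
cp_table = [
    (0.06746, 0, 1.0000, 1.0000),
    (0.02728, 2, 0.8651, 0.8752),
    (0.01727, 3, 0.8378, 0.8561),
    (0.01000, 7, 0.7687, 0.7830),
]

rows = []
for cp, nsplit, rel_error, xerror in cp_table:
    rows.append({
        "cp": cp,
        "leaves": nsplit + 1,                   # a binary tree with s splits has s + 1 leaves
        "train_error": rel_error * root_error,  # training misclassification rate
        "cv_error": xerror * root_error,        # cross-validated misclassification rate
    })

for row in rows:
    print(row["cp"], row["leaves"], round(row["train_error"], 4), round(row["cv_error"], 4))
```

The last row corresponds to the tree plotted with the default cp: 8 terminal nodes and a training error of roughly 0.28.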
For rpart we can specify a loss matrix to emphasize one mistake as worse than others.
loss_matrix = matrix(c(0, 1, 1, 1, 0, 1, 1, 50, 0), 3, 3)
loss_matrix
## [,1] [,2] [,3]
## [1,] 0 1 1
## [2,] 1 0 50
## [3,] 1 1 0
health.tree.cost = rpart(country ~ ., method = "class", data = health_no_sample_weight,
parms = list(loss = loss_matrix), cp = 0.001)
plot(health.tree.cost)
text(health.tree.cost)
health.tree.cost.info = rpart(country ~ ., method = "class", data = health_no_sample_weight,
parms = list(loss = loss_matrix, split = "info"), cp = 0.001)
plot(health.tree.cost.info)
text(health.tree.cost.info)
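A loss matrix changes which label minimizes the expected loss at a node. The sketch below illustrates the idea in plain Python; it is not rpart's implementation. It assumes that rows index the true class and columns the predicted class, a convention worth checking against the rpart documentation. The class probabilities are hypothetical, and the function names are made up.

```python
def expected_losses(class_probs, loss):
    """Expected loss of predicting each class j: sum_i p_i * loss[i][j].

    Assumed convention: loss[i][j] is the cost of predicting j when the truth is i.
    """
    k = len(class_probs)
    return [sum(class_probs[i] * loss[i][j] for i in range(k)) for j in range(k)]

def best_label(class_probs, loss, labels):
    """Label with the smallest expected loss."""
    losses = expected_losses(class_probs, loss)
    return labels[min(range(len(labels)), key=losses.__getitem__)]

labels = ["aeh", "pih", "ugh"]
zero_one = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]   # every mistake costs 1
weighted = [[0, 1, 1], [1, 0, 50], [1, 1, 0]]  # same entries as loss_matrix above

probs = [0.45, 0.05, 0.50]  # hypothetical class proportions in some node

print(best_label(probs, zero_one, labels))  # ugh
print(best_label(probs, weighted, labels))  # aeh
```

With equal costs the node is labelled by its majority class. The single entry of 50 makes one label so costly here that a different label wins.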
table(predict(health.tree.cost, type = "class"), health_no_sample_weight$country)
##
## aeh pih ugh
## aeh 14694 3758 1542
## pih 146 858 0
## ugh 950 1041 1673
Although rpart accepts split as an argument, it does not seem to be using it.
table(predict(health.tree.cost.info, type = "class"), health_no_sample_weight$country)
##
## aeh pih ugh
## aeh 14765 3402 1628
## pih 186 1094 0
## ugh 839 1161 1587
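The two tables above can be summarized by their training misclassification rates. The sketch below, in plain Python, copies the counts from the printouts (rows are predicted classes, columns are true classes) and computes the off-diagonal fraction. The names training_error, cost_table and cost_info_table are made up for this illustration.

```python
def training_error(confusion):
    """Fraction of observations that lie off the diagonal of a confusion table."""
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(len(confusion)))
    return (total - correct) / total

# Rows: predicted aeh, pih, ugh. Columns: true aeh, pih, ugh.
cost_table = [
    [14694, 3758, 1542],
    [146, 858, 0],
    [950, 1041, 1673],
]
cost_info_table = [
    [14765, 3402, 1628],
    [186, 1094, 0],
    [839, 1161, 1587],
]

print(round(training_error(cost_table), 4))       # 0.3016
print(round(training_error(cost_info_table), 4))  # 0.2926
```

Both tables sum to 24662 observations, and the counts differ between the two fits.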
Let’s repeat what we did above for the iris data to see what the training and test errors look like.
A more complicated tree may fit better (or worse).
Let’s make a plot of training error and test error.