Decision Trees

Decision Trees (CART)

Tree-based methods partition the feature space into a set of rectangles and then fit a simple model (e.g., a constant) in each one. This post focuses on CART (Classification and Regression Trees).

Suppose we have a dataset \(D = \{(\mathbf{x}_1, y_1), \ldots, (\mathbf{x}_N, y_N)\}\) with \(\mathbf{x}_i \in \mathbb{R}^d\). The algorithm needs to decide automatically on the splitting variables, the splitting points, and the shape of the tree.

Regression Trees

In this scenario, the response variable \(Y\) is continuous. Suppose first that we have a partition of the feature space into \(M\) regions \(R_1, \ldots, R_M\), and define the model prediction as:

\[\hat{y}_i = \sum^{M}_{m=1} c_m I[\mathbf{x}_i \in R_m]\]

Minimizing the mean squared error \(L = \frac{1}{2N} \sum^{N}_{i=1} (y_i - \hat{y}_i)^2\) with respect to each \(c_m\):

\[\begin{aligned} \frac{\partial L}{\partial c_m} &= -\frac{1}{N}\sum^{N}_{i=1} \Big(y_i - \sum^{M}_{m^{\prime}=1} c_{m^{\prime}} I[\mathbf{x}_i \in R_{m^{\prime}}]\Big) I[\mathbf{x}_i \in R_m]\\ &= -\frac{1}{N}\sum^{N}_{i=1} (y_i - c_m)\, I[\mathbf{x}_i \in R_m] \end{aligned}\]

Setting this derivative to zero gives \(\sum^{N}_{i=1} y_i \, I[\mathbf{x}_i \in R_m] = c_m \sum^{N}_{i=1} I[\mathbf{x}_i \in R_m]\).

Thus, with respect to mean squared error, the best estimate \(\hat{c}_m\) for each region is the average of the training responses in that region:

\[\hat{c}_m = \frac{1}{N_m} \sum^{N}_{i=1} y_i I[\mathbf{x}_i \in R_m]\]

where \(N_m = \sum^{N}_{i=1} I[\mathbf{x}_i \in R_m]\) is the number of training examples in region \(R_m\).
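
To make this concrete, here is a minimal NumPy sketch (the variable names are illustrative, not part of the text) that computes the per-region constants \(\hat{c}_m\) given region assignments:

import numpy as np

# Toy data: responses and an (illustrative) region index for each sample.
y = np.array([1.0, 2.0, 3.0, 10.0, 11.0])
region = np.array([0, 0, 0, 1, 1])   # index m of the region containing x_i

# c_hat[m] = average response of the training points falling in region m.
c_hat = np.array([y[region == m].mean() for m in np.unique(region)])
print(c_hat)   # [ 2.  10.5]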

Best Splitting Point

Finding the binary partition that minimizes the overall sum of squares is generally computationally infeasible. Instead, we use a greedy algorithm: starting with all of the data, consider every splitting variable \(j\) and split point \(s\), and find the pair that minimizes the following loss:

  • Each pair of splitting variable \(j\) and split point \(s\) defines a pair of half-planes:

    • \(R_1(j, s) = \{\mathbf{x} | x_j \leq s\}\) (all samples that have \(j\)th feature less than or equal to \(s\))
    • \(R_2(j, s) = \{\mathbf{x} | x_j > s\}\) (all samples that have \(j\)th feature greater than \(s\))
  • For each half-plane, we find the best estimate that will minimize the mean square loss in that region: \[\hat{c}_1 = \underset{c_1}{\arg\min} \frac{1}{N_{1}}\sum_{x_i \in R_1 (j, s)} (y_i - c_1)^2\] \[\hat{c}_2 = \underset{c_2}{\arg\min} \frac{1}{N_{2}}\sum_{x_i \in R_2 (j, s)} (y_i - c_2)^2\]

    From previous results, we know that the minimizers are the average training responses in these regions, that is:

    \[\hat{c}_1 = \frac{1}{N_1} \sum_{x_i \in R_1 (j, s)} y_i \quad \quad \hat{c}_2 = \frac{1}{N_2} \sum_{x_i \in R_2 (j, s)} y_i\]

  • Given these inner solutions, we seek the pair \((j, s)\) that minimizes the overall objective:

    \[\min_{j, s} [\frac{1}{N_{1}}\sum_{x_i \in R_1 (j, s)} (y_i - \hat{c}_1)^2 + \frac{1}{N_{2}}\sum_{x_i \in R_2 (j, s)} (y_i - \hat{c}_2)^2]\]

  • For each splitting variable \(j\), determining the best split point \(s\) requires only a scan through the finite set of candidate values, so all pairs \((j, s)\) can be evaluated very quickly. Having found the best split, we partition the data into the two resulting regions and repeat the splitting procedure on each region until a stopping criterion is met (see the sketch after this list).
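
The following is a minimal sketch of this greedy search for a single node, using NumPy; the function and variable names are illustrative, not from the text:

import numpy as np

def best_split(X, y):
    # Greedy search over all (feature j, threshold s) pairs for one node.
    best = (None, None, np.inf)          # (j, s, summed squared error)
    for j in range(X.shape[1]):
        for s in np.unique(X[:, j]):
            left, right = y[X[:, j] <= s], y[X[:, j] > s]
            if len(left) == 0 or len(right) == 0:
                continue                 # degenerate split, skip it
            sse = ((left - left.mean()) ** 2).sum() + ((right - right.mean()) ** 2).sum()
            if sse < best[2]:
                best = (j, s, sse)
    return best

X = np.array([[1.0], [2.0], [3.0], [10.0], [11.0]])
y = np.array([1.2, 0.8, 1.1, 5.0, 5.2])
print(best_split(X, y))   # best split is on feature 0 at s = 3.0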


Cost Complexity Pruning

How large should the tree grow? Clearly, a large tree will overfit the training set while a small tree might not capture the important structure. Tree size is a tuning parameter governing the model's complexity, and the optimal tree size should be adaptively chosen from the data.

The preferred strategy is to grow a large tree \(T_0\), stopping the splitting process only when some stopping rule is reached (e.g., a minimum node size), and then prune the large tree using cost-complexity pruning.

We define a subtree \(T \subseteq T_0\) to be any tree that can be obtained by pruning \(T_0\), that is, by collapsing any number of its internal (non-terminal) nodes into leaves. The idea is to find, for each \(\alpha \geq 0\), the subtree \(T_\alpha \subseteq T_0\) that minimizes the objective:

\[C_{\alpha} (T) = \sum^{|T|}_{m=1} \sum_{x_i \in R_m} (y_i - \hat{c}_m)^2 + \alpha |T|\]

where \(|T|\) denotes the number of terminal nodes (leaves) of \(T\), and \(\alpha |T|\) is the penalty term that trades off tree size against goodness of fit.
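
As a sanity check, the objective is easy to evaluate for a fixed tree; a minimal sketch, where the per-leaf arrays of responses are an illustrative stand-in for the regions \(R_m\):

import numpy as np

def cost_complexity(leaf_targets, alpha):
    # C_alpha(T) = sum of within-leaf squared errors + alpha * number of leaves.
    sse = sum(((y - y.mean()) ** 2).sum() for y in leaf_targets)
    return sse + alpha * len(leaf_targets)

leaves = [np.array([1.0, 1.2, 0.9]), np.array([5.0, 5.3])]
print(cost_complexity(leaves, alpha=0.0))   # pure goodness of fit
print(cost_complexity(leaves, alpha=1.0))   # penalized by |T| = 2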

Hyperparameters:

  1. Tree size: \(\;\) Governs the tree's complexity (e.g., via a maximum depth).
  2. Minimum decrease in loss from a split: \(\;\) Split a node only if the decrease in the sum of squares due to the split exceeds some threshold. This approach is short-sighted, however, because a seemingly worthless split might lead to a very good split below it.
  3. Minimum leaf size / maximum number of leaves: \(\;\) The minimum number of training samples allowed in a leaf, or a cap on the total number of leaves.
  4. \(\alpha\): Controls the strength of the penalty in cost-complexity pruning. (A rough mapping of these knobs to scikit-learn parameters is sketched after this list.)
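
The sketch below shows how these knobs correspond, roughly, to scikit-learn's tree parameters; the mapping is given only for illustration, and the import is aliased to avoid clashing with the class implemented later in this post:

from sklearn.tree import DecisionTreeClassifier as SkDecisionTree

clf = SkDecisionTree(
    max_depth=5,                 # 1. caps the tree's size / complexity
    min_impurity_decrease=1e-3,  # 2. minimum decrease in loss required to split
    min_samples_leaf=10,         # 3. minimum number of samples in a leaf
    ccp_alpha=0.01,              # 4. alpha used for cost-complexity pruning
)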

Classification Trees

If \(Y\) is a classification outcome taking values \(1, 2, \ldots, K\), the only changes needed in the tree algorithm concern the criteria for splitting nodes and pruning the tree. For regression, we used the squared error as the splitting criterion and the average response \(\hat{c}_m\) at each leaf as the prediction. In the classification case, for each region \(m\) we use the class proportions to make predictions:

\[\hat{p}_{mk} = \frac{1}{N_m} \sum_{x_i \in R_m} I[y_i = k]\]

\[\hat{y}_m = \underset{k}{\arg\max} \; \hat{p}_{mk} \]

Thus, the prediction at each region is the majority class in that region.
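
A minimal sketch of this prediction rule for a single region (the data and names are illustrative):

import numpy as np

# Class labels of the training points that fall in one region R_m.
y_m = np.array([0, 2, 2, 1, 2, 2])
classes = np.unique(y_m)

# p_mk: proportion of each class k in the region; the prediction is the argmax.
p_mk = np.array([(y_m == k).mean() for k in classes])
print(p_mk)                      # approximately [0.167 0.167 0.667]
print(classes[np.argmax(p_mk)])  # 2, the majority class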

Impurity Measure

In the classification case, the splitting criterion is called an impurity measure. There are several common choices:

  1. Misclassification Error: \[\frac{1}{N_m} \sum_{x_i \in R_m} I[y_i \neq \hat{y}_m] = 1 - \hat{p}_{m \hat{y}_m}\]

  2. Gini Index: \[\sum_{k \neq k^{\prime}} \hat{p}_{mk} \hat{p}_{mk^{\prime}} = \sum^{K}_{k=1} \hat{p}_{mk} (1 - \hat{p}_{mk})\]

    Notice that if the region contains only one class, the Gini index is \(0\), its minimum value. That is, the Gini index favors purer nodes.

  3. Cross-entropy: \[-\sum^{K}_{k=1} \hat{p}_{mk} \log \hat{p}_{mk}\]

    Notice that if the region contains only one class, the cross-entropy is \(0\), its minimum value. Thus, cross-entropy also favors purer nodes.

In practice, cross-entropy or the Gini index should be preferred over the misclassification rate for growing the tree, because they are more sensitive to changes in node purity: they keep decreasing as a node becomes purer even when the majority class does not change. When comparing candidate splits, each child's impurity should also be weighted by the number of training instances it contains.
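
All three measures can be computed directly from the class proportions of a node; a minimal NumPy sketch (illustrative names):

import numpy as np

def impurities(p):
    # Misclassification error, Gini index, and cross-entropy for class proportions p.
    p = np.asarray(p, dtype=float)
    misclass = 1.0 - p.max()
    gini = np.sum(p * (1.0 - p))
    entropy = -np.sum(p[p > 0] * np.log(p[p > 0]))   # treat 0 log 0 as 0
    return misclass, gini, entropy

print(impurities([1.0, 0.0]))   # pure node: (0.0, 0.0, 0.0)
print(impurities([0.5, 0.5]))   # maximally impure: (0.5, 0.5, ~0.693)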

Other Issues

Instability of Trees

One major problem with trees is their high variance: often a small change in the data results in a very different sequence of splits. The main reason for this instability is the hierarchical nature of the process: the effect of an error in the top split is propagated down to all of the splits below it.

Difficulty in Capturing Additive Structure

Another problem with trees is their difficulty in modeling additive structure. For a regression model \(Y = c_1 I[X_1 < t_1] + c_2 I[X_2 < t_2] + \epsilon\), a binary tree must first split near \(t_1\) on \(X_1\), and then split both resulting nodes near \(t_2\) on \(X_2\). This might happen with sufficient data, but the model is given no special encouragement to find such structure.

Linear Combination Splits

One way to mitigate this limitation is, rather than restricting splits to the form \(X_j \leq s\), to allow splits along linear combinations of features:

\[\sum_{j} a_j X_j \leq s\]

While this can improve the predictive power of the tree, it can hurt interpretability.
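
A linear-combination (oblique) split is evaluated just like an axis-aligned one, only on a projection of the features; a minimal sketch with illustrative weights:

import numpy as np

X = np.array([[1.0, 2.0], [3.0, 1.0], [0.5, 0.5], [2.0, 2.5]])
a = np.array([0.6, 0.4])   # illustrative weights a_j
s = 1.5                    # illustrative threshold

# Split on the projected value sum_j a_j * X_j instead of a single feature.
goes_left = X @ a <= s
print(goes_left)           # boolean mask defining the two child regions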

Implementation

import numpy as np
import pandas as pd
from collections import Counter


class DecisionTreeNode:
    """A single tree node: an internal node with two children, or a leaf if both children are None."""

    def __init__(self, split_point=None, split_feature=None, gini_index=None,
                 num_instance=None, left_child=None, right_child=None, prediction=None):
        self.split_point = split_point        # threshold s
        self.split_feature = split_feature    # feature index j
        self.gini_index = gini_index          # weighted impurity of the chosen split
        self.num_instance = num_instance      # number of training samples at this node
        self.left_child = left_child
        self.right_child = right_child
        self.prediction = prediction          # majority class at this node


class DecisionTreeClassifier:

    def __init__(self, max_depth=10, criterion='gini_index', min_sample_leave=1):
        self.max_depth = max_depth
        self.criterion = criterion
        self.min_sample_leave = min_sample_leave
        self._tree = None
        self._num_features = None
        self.feature_names = None
        self._num_samples = None
        self._classes = None

    def fit(self, x_train, y_train):
        self._num_samples, self._num_features = x_train.shape
        self._classes = np.unique(y_train)
        if isinstance(x_train, pd.DataFrame):
            self.feature_names = x_train.columns
            x_train = x_train.values
        else:
            self.feature_names = list(range(self._num_features))

        # Stack the responses as the last column so each node works on (X, y) together.
        self._tree = self._grow_tree(np.column_stack([x_train, y_train]))
        return self

    def _grow_tree(self, train, curr_depth=0, stop=False):
        if stop:
            return None

        # Evaluate every (feature j, split point s) pair and keep the lowest impurity.
        impurity_dict = {}
        for j in range(self._num_features):
            for s in np.unique(train[:, j]):
                impurity = self._cal_impurity(train, j, s)
                impurity_dict.setdefault(impurity, []).append((j, s))

        min_impurity = min(impurity_dict.keys())
        j_hat, s_hat = impurity_dict[min_impurity][0]
        prediction = Counter(train[:, -1]).most_common()[0][0]
        R_l = train[train[:, j_hat] < s_hat]
        R_r = train[train[:, j_hat] >= s_hat]
        sample_split = [len(R_l), len(R_r)]

        # If the children would violate a stopping rule, this node becomes a leaf.
        if_stop = self._check_stopping_criterion(curr_depth, sample_split)

        return DecisionTreeNode(split_point=s_hat,
                                split_feature=j_hat,
                                prediction=prediction,
                                gini_index=min_impurity,
                                num_instance=len(train),
                                left_child=self._grow_tree(train=R_l,
                                                           curr_depth=curr_depth + 1,
                                                           stop=if_stop),
                                right_child=self._grow_tree(train=R_r,
                                                            curr_depth=curr_depth + 1,
                                                            stop=if_stop))

    def _check_stopping_criterion(self, curr_depth, sample_split):
        if curr_depth > self.max_depth:
            return True
        # Stop if either child region would contain fewer samples than allowed.
        if sample_split[0] < self.min_sample_leave or sample_split[1] < self.min_sample_leave:
            return True
        return False

    def _cal_impurity(self, train, j, s):
        # Candidate split: R_l = {x : x_j < s}, R_r = {x : x_j >= s}.
        R_l_y = train[train[:, j] < s][:, -1]
        R_r_y = train[train[:, j] >= s][:, -1]
        N_l = len(R_l_y) + 1e-10   # guard against division by zero for degenerate splits
        N_r = len(R_r_y) + 1e-10

        if self.criterion != 'gini_index':
            raise ValueError(f'Unsupported criterion: {self.criterion}')

        gini_l = 0.0
        gini_r = 0.0
        for k in self._classes:
            p_l = np.sum(R_l_y == k) / N_l
            p_r = np.sum(R_r_y == k) / N_r
            gini_l += p_l * (1 - p_l)
            gini_r += p_r * (1 - p_r)

        # Weight each child's Gini index by the number of samples it contains.
        return gini_l * N_l + gini_r * N_r

    def predict(self, x_test):
        if isinstance(x_test, pd.DataFrame):
            x_test = x_test.values
        return [self._traverse_tree(x) for x in x_test]

    def _traverse_tree(self, x, tree=None):
        if tree is None:
            tree = self._tree

        # A node with no children is a leaf; return its majority-class prediction.
        if tree.left_child is None and tree.right_child is None:
            return tree.prediction
        if x[tree.split_feature] < tree.split_point:
            return self._traverse_tree(x, tree.left_child)
        return self._traverse_tree(x, tree.right_child)
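
A quick usage example of the classifier above on a small synthetic dataset (the data is made up purely for illustration):

import numpy as np

rng = np.random.RandomState(0)
# Two Gaussian blobs, one per class.
X = np.vstack([rng.normal(0.0, 1.0, size=(50, 2)),
               rng.normal(3.0, 1.0, size=(50, 2))])
y = np.array([0] * 50 + [1] * 50)

clf = DecisionTreeClassifier(max_depth=3, min_sample_leave=5).fit(X, y)
pred = clf.predict(X)
print(np.mean(np.array(pred) == y))   # training accuracy, close to 1.0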

Ref

ESLII (Hastie, Tibshirani & Friedman, The Elements of Statistical Learning, 2nd ed.), Chapter 9