Decision Trees
Decision trees are a favorite in ML interviews because they test recursive thinking, understanding of information theory, and ability to build complex data structures. We implement the full algorithm: entropy, Gini impurity, information gain, recursive splitting, pruning, and a random forest ensemble.
Splitting Criteria: Entropy and Gini
The core idea of a decision tree is to find the feature and threshold that best separates the data. We measure "best" using either entropy (information gain) or Gini impurity.
import numpy as np
def entropy(y):
    """Shannon entropy of a label array.

    H(y) = -sum(p_k * log2(p_k)) over the empirical class
    distribution of ``y``. Pure arrays score 0; a 50/50 binary
    split scores 1 bit.
    """
    _, class_counts = np.unique(y, return_counts=True)
    p = class_counts / len(y)
    # Drop zero-probability entries so log2(0) never appears.
    p = p[p > 0]
    return -np.sum(p * np.log2(p))
def gini_impurity(y):
    """Gini impurity of a label array.

    Gini(y) = 1 - sum(p_k^2): the probability that a randomly drawn
    sample is misclassified when labeled by the class distribution.
    """
    _, class_counts = np.unique(y, return_counts=True)
    p = class_counts / len(y)
    return 1 - np.sum(p * p)
def information_gain(y, y_left, y_right, criterion='entropy'):
    """Information gain of a binary split.

    IG = impurity(parent) - weighted_avg(impurity(left), impurity(right)),
    where ``criterion`` picks the impurity measure: 'entropy' (default)
    or Gini for anything else.
    """
    impurity = entropy if criterion == 'entropy' else gini_impurity
    n = len(y)
    frac_left = len(y_left) / n
    frac_right = len(y_right) / n
    children = frac_left * impurity(y_left) + frac_right * impurity(y_right)
    return impurity(y) - children
# ---- Smoke test: impurity measures on pure / mixed / skewed labels ----
labels_pure = np.array([1, 1, 1, 1])
labels_mixed = np.array([0, 0, 1, 1])
labels_skewed = np.array([0, 0, 0, 1])
print(f"Pure entropy: {entropy(labels_pure):.4f}")      # expect 0.0
print(f"Mixed entropy: {entropy(labels_mixed):.4f}")    # expect 1.0
print(f"Skewed entropy: {entropy(labels_skewed):.4f}")  # expect 0.8113
print(f"Pure Gini: {gini_impurity(labels_pure):.4f}")   # expect 0.0
print(f"Mixed Gini: {gini_impurity(labels_mixed):.4f}") # expect 0.5
print(f"Skewed Gini: {gini_impurity(labels_skewed):.4f}")# expect 0.375
Full Decision Tree Implementation
class DecisionTreeNode:
    """One node of a decision tree.

    Internal nodes carry a split (``feature_idx``, ``threshold``) and
    child links; leaves carry only ``value``. A node is a leaf exactly
    when ``value is not None``.
    """

    def __init__(self, feature_idx=None, threshold=None,
                 left=None, right=None, value=None):
        # Split rule (internal nodes): feature <= threshold goes left,
        # everything else goes right.
        self.feature_idx = feature_idx
        self.threshold = threshold
        self.left = left
        self.right = right
        # Predicted class label; non-None marks this node as a leaf.
        self.value = value
class DecisionTreeClassifier:
    """Decision Tree Classifier from scratch.

    Greedy top-down induction: at each node, exhaustively search every
    (feature, threshold) pair for the split with maximal information
    gain, then recurse on both halves. Pre-pruning via ``max_depth``,
    ``min_samples_split`` and ``min_samples_leaf``.
    """

    def __init__(self, max_depth=10, min_samples_split=2,
                 min_samples_leaf=1, criterion='entropy'):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.criterion = criterion
        self.root = None

    def fit(self, X, y):
        """Grow the tree on training data (X, y); returns self."""
        self.n_classes = len(np.unique(y))
        self.root = self._build_tree(X, y, depth=0)
        return self

    def _best_split(self, X, y):
        """Exhaustive split search.

        Returns (gain, feature_idx, threshold) of the best valid split,
        or (-1, None, None) when no split respects min_samples_leaf.
        """
        top_gain, top_feature, top_cut = -1, None, None
        for j in range(X.shape[1]):
            # Candidate thresholds: every distinct value of the feature.
            for cut in np.unique(X[:, j]):
                goes_left = X[:, j] <= cut
                n_left = np.sum(goes_left)
                # Both children must satisfy the minimum leaf size.
                if (n_left < self.min_samples_leaf
                        or len(y) - n_left < self.min_samples_leaf):
                    continue
                gain = information_gain(
                    y, y[goes_left], y[~goes_left], self.criterion
                )
                if gain > top_gain:
                    top_gain, top_feature, top_cut = gain, j, cut
        return top_gain, top_feature, top_cut

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for the samples (X, y)."""
        # Stopping conditions: pure node, depth limit, or too few samples.
        if len(np.unique(y)) == 1:
            return DecisionTreeNode(value=y[0])
        if depth >= self.max_depth or len(y) < self.min_samples_split:
            return DecisionTreeNode(value=self._majority_class(y))

        gain, feature, threshold = self._best_split(X, y)
        # No split improves purity -> make this a leaf.
        if gain <= 0:
            return DecisionTreeNode(value=self._majority_class(y))

        goes_left = X[:, feature] <= threshold
        return DecisionTreeNode(
            feature_idx=feature,
            threshold=threshold,
            left=self._build_tree(X[goes_left], y[goes_left], depth + 1),
            right=self._build_tree(X[~goes_left], y[~goes_left], depth + 1),
        )

    def _majority_class(self, y):
        """Most frequent label in y (first one wins on ties)."""
        labels, freq = np.unique(y, return_counts=True)
        return labels[np.argmax(freq)]

    def _predict_single(self, x, node):
        """Walk one sample from ``node`` down to a leaf."""
        while node.value is None:
            if x[node.feature_idx] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value

    def predict(self, X):
        """Predicted class label for every row of X."""
        return np.array([self._predict_single(x, self.root) for x in X])

    def accuracy(self, X, y):
        """Fraction of rows of X predicted with the correct label."""
        return np.mean(self.predict(X) == y)
# ---- Smoke test: two well-separated Gaussian blobs ----
np.random.seed(42)
blob_a = np.random.randn(100, 2) + np.array([2, 2])    # class 0, centered at (2, 2)
blob_b = np.random.randn(100, 2) + np.array([-2, -2])  # class 1, centered at (-2, -2)
X = np.vstack([blob_a, blob_b])
y = np.array([0] * 100 + [1] * 100)
tree = DecisionTreeClassifier(max_depth=5, criterion='gini')
tree.fit(X, y)
print(f"Decision Tree accuracy: {tree.accuracy(X, y):.4f}")
Pre-Pruning and Post-Pruning
Pruning prevents overfitting. Pre-pruning stops tree growth early (max_depth, min_samples_split, min_samples_leaf). Post-pruning grows the full tree then removes branches that don't improve validation performance.
def post_prune(node, X_val, y_val):
    """Reduced-error pruning: replace a subtree with a majority-class
    leaf whenever that does not hurt validation accuracy.

    Args:
        node: root DecisionTreeNode of the (sub)tree to prune.
        X_val, y_val: the validation samples that reach ``node``.

    Returns:
        The (possibly pruned) subtree root.

    Bug fix vs. the original: each subtree was scored against the FULL
    validation set. Reduced-error pruning must evaluate a node only on
    the validation samples routed to it; deep nodes were being compared
    on data they never see, and an empty validation set produced NaN
    accuracies via np.mean of an empty array.
    """
    if node.value is not None:
        return node  # already a leaf

    # Route validation data exactly as the tree routes samples.
    left_mask = X_val[:, node.feature_idx] <= node.threshold
    right_mask = ~left_mask

    # Prune children bottom-up, each on its own share of the data.
    if node.left:
        node.left = post_prune(node.left, X_val[left_mask], y_val[left_mask])
    if node.right:
        node.right = post_prune(node.right, X_val[right_mask], y_val[right_mask])

    # No validation evidence reaches this node: keep the subtree as-is.
    if len(y_val) == 0:
        return node

    # Accuracy of the current subtree on the samples reaching it.
    tree_preds = np.array([_predict_single_node(x, node) for x in X_val])
    tree_acc = np.mean(tree_preds == y_val)

    # Accuracy if this subtree collapses to its majority-class leaf.
    all_leaves = _get_leaf_values(node)
    majority = max(set(all_leaves), key=all_leaves.count)
    leaf_acc = np.mean(np.full(len(y_val), majority) == y_val)

    # Prune when the leaf is at least as good.
    if leaf_acc >= tree_acc:
        return DecisionTreeNode(value=majority)
    return node
def _predict_single_node(x, node):
if node.value is not None:
return node.value
if x[node.feature_idx] <= node.threshold:
return _predict_single_node(x, node.left)
return _predict_single_node(x, node.right)
def _get_leaf_values(node):
if node.value is not None:
return [node.value]
return _get_leaf_values(node.left) + _get_leaf_values(node.right)
Random Forest
A random forest builds multiple decision trees on bootstrapped samples with random feature subsets, then aggregates their predictions by majority vote.
class RandomForest:
    """Random Forest Classifier from scratch.

    An ensemble of DecisionTreeClassifiers, each trained on a bootstrap
    resample restricted to a random feature subset; prediction is the
    per-sample majority vote across trees.
    """

    def __init__(self, n_trees=10, max_depth=10,
                 min_samples_split=2, max_features='sqrt'):
        self.n_trees = n_trees
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.max_features = max_features
        self.trees = []
        self.feature_indices = []

    def _n_sub_features(self, n_features):
        """Features per tree: 'sqrt', 'log2', or all; at least 1."""
        if self.max_features == 'sqrt':
            k = int(np.sqrt(n_features))
        elif self.max_features == 'log2':
            k = int(np.log2(n_features))
        else:
            k = n_features
        return max(1, k)

    def fit(self, X, y):
        """Train ``n_trees`` trees on bootstrap samples of (X, y)."""
        n_samples, n_features = X.shape
        self.trees = []
        self.feature_indices = []
        k = self._n_sub_features(n_features)
        for _ in range(self.n_trees):
            # Bootstrap: n_samples rows drawn with replacement.
            rows = np.random.choice(n_samples, size=n_samples, replace=True)
            X_boot, y_boot = X[rows], y[rows]
            # Random feature subset, drawn without replacement.
            cols = np.random.choice(n_features, size=k, replace=False)
            self.feature_indices.append(cols)
            member = DecisionTreeClassifier(
                max_depth=self.max_depth,
                min_samples_split=self.min_samples_split
            )
            member.fit(X_boot[:, cols], y_boot)
            self.trees.append(member)
        return self

    def predict(self, X):
        """Majority vote across all trees for every row of X."""
        # votes has shape (n_trees, n_samples); each tree sees only its
        # own feature subset.
        votes = np.array([
            member.predict(X[:, cols])
            for member, cols in zip(self.trees, self.feature_indices)
        ])
        winners = np.zeros(X.shape[0], dtype=int)
        for i in range(X.shape[0]):
            labels, tally = np.unique(votes[:, i], return_counts=True)
            winners[i] = labels[np.argmax(tally)]
        return winners

    def accuracy(self, X, y):
        """Fraction of rows of X predicted with the correct label."""
        return np.mean(self.predict(X) == y)
# ---- Smoke test: the forest on the same blob data ----
rf = RandomForest(n_trees=20, max_depth=5).fit(X, y)  # fit returns self
print(f"Random Forest accuracy: {rf.accuracy(X, y):.4f}")
Key Takeaways
- Entropy: H(y) = -sum(p_k * log2(p_k)) — measures uncertainty in the labels
- Gini impurity: Gini(y) = 1 - sum(p_k^2) — probability of misclassifying a randomly drawn sample
- Information gain: IG = H(parent) - weighted_avg(H(left), H(right)) — the improvement from a split
- Recursion is the natural structure: each node splits the data and recurses on the left and right halves
- Pre-pruning (max_depth, min_samples_split, min_samples_leaf) is simpler; post-pruning requires held-out validation data
- Random forest: bootstrap samples + random feature subsets + majority vote
Lilly Tech Systems