Decision Trees
What are Decision Trees?
Decision trees are a machine learning model most commonly used for classification. Unlike complex modern techniques such as neural networks, which often act like a black box, decision trees rely on simple decision rules (a series of yes or no questions) that are easy to understand and interpret.
Decision trees are a supervised learning algorithm - they require a training dataset with correct labels to learn from. Once trained, they can be used to make predictions on the label of new data.
How Decision Trees Work
A decision tree is simply a series of yes or no questions that lead to a conclusion. Given a piece of data:
- The tree checks whether a feature meets a certain condition
- Based on the answer (yes/no), it follows a branch
- This process repeats until reaching a final prediction (leaf node)
Example: Classifying Pokémon Types
Consider classifying Pokémon as either “grass” or “electric” type based on their stats:
- Input: Pokémon stats (HP, Attack, Defense, Speed, etc.)
- Output: Predicted type (grass or electric)
A simple tree might ask: “Is speed less than 85.5?”
- If yes → predict grass
- If no → predict electric
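Conceptually, that one-question tree is nothing more than an if/else statement. A minimal sketch in Python (the 85.5 threshold comes from the example above; a real tree learns it from the training data):
# A depth-1 decision tree expressed as a plain if/else rule
def classify_pokemon(speed):
    if speed < 85.5:
        return 'Grass'      # "yes" branch
    return 'Electric'       # "no" branch
print(classify_pokemon(45))   # Grass
print(classify_pokemon(110))  # Electric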
Training a Decision Tree
A decision tree is trained by repeatedly splitting the training set into smaller groups based on feature values. At each step, the algorithm examines all possible ways to divide the data and chooses the split that best separates the classes—creating groups that are as “pure” as possible, meaning each group ideally contains mostly one class. This process continues recursively on each resulting subset until the tree reaches a stopping point, such as when nodes are completely pure or a maximum depth is reached.
Purity
Purity describes how homogeneous the classes within a node are, while impurity measures how mixed or heterogeneous they are.
- A pure node contains only one class (impurity = 0)
- An impure node contains a mix of classes (higher impurity)
Example:
A node with 100 samples:
- 100 cats, 0 dogs → pure (impurity = 0)
- 50 cats, 50 dogs → maximally impure (it’s a coin flip)
- 90 cats, 10 dogs → low impurity (mostly one class)
The goal of splitting is to create child nodes that are more pure than the parent—ideally pushing toward nodes that contain only one class, making classification straightforward.
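To make this concrete, here is a small sketch that scores the three example nodes with Gini impurity (defined formally below), using the class counts from the example:
# Gini impurity of a node described by its per-class counts
def gini(counts):
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)
print(gini([100, 0]))  # 0.0    -> pure
print(gini([50, 50]))  # 0.5    -> maximally impure for two classes
print(gini([90, 10]))  # ≈ 0.18 -> low impurity, mostly one class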
Training Process
1. For each feature, consider all possible split points
For a continuous feature, the algorithm typically sorts the values and considers cutoffs between adjacent unique values (often the midpoint). For a feature with values [1, 3, 7, 10], it might evaluate splits at 2, 5, and 8.5.
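As a sketch of how those candidate thresholds could be enumerated (midpoints between adjacent sorted unique values, matching the [1, 3, 7, 10] example):
# Candidate split thresholds: midpoints between adjacent unique values
def candidate_thresholds(values):
    unique_sorted = sorted(set(values))
    return [(a + b) / 2 for a, b in zip(unique_sorted, unique_sorted[1:])]
print(candidate_thresholds([1, 3, 7, 10]))  # [2.0, 5.0, 8.5]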
2. Calculate the impurity reduction for each split
For classification, the two common metrics are:
- Gini impurity: Measures how often a randomly chosen element would be incorrectly classified. For a node: Gini = 1 - Σ(pᵢ²), where pᵢ is the proportion of class i.
- Entropy / Information Gain: Measures the reduction in uncertainty: Entropy = -Σ(pᵢ log₂ pᵢ), where pᵢ is again the proportion of class i.
The algorithm calculates the weighted average impurity of the child nodes after the split and compares it to the parent. The gain is:
Gain = Impurity(parent) - [weighted avg of Impurity(children)]
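As a minimal sketch of that calculation, reusing the gini() helper from the purity example and describing each child node by its per-class counts:
# Impurity reduction (gain) from splitting a parent node into two children
def split_gain(left_counts, right_counts):
    parent_counts = [l + r for l, r in zip(left_counts, right_counts)]
    n_left, n_right = sum(left_counts), sum(right_counts)
    n = n_left + n_right
    weighted_children = (n_left / n) * gini(left_counts) + (n_right / n) * gini(right_counts)
    return gini(parent_counts) - weighted_children
# e.g. a 50/50 parent split into a 40/10 child and a 10/40 child
print(split_gain([40, 10], [10, 40]))  # ≈ 0.18: a useful split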
3. Pick the split with the highest gain
The algorithm greedily selects whichever (feature, threshold) pair produces the largest reduction in impurity.
Quick example:
If you’re splitting on “age > 30” and it results in one node that’s 90% class A and another that’s 85% class B, that’s a good split—both children are more “pure” than the parent was.
In scikit-learn, you can control this with the criterion parameter ('gini' or 'entropy' for classifiers, 'squared_error' for regressors).
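Putting the three steps together, a highly simplified recursive builder might look like the sketch below. It reuses candidate_thresholds() and split_gain() from above, represents the tree as nested dictionaries, and ignores many details that real implementations such as scikit-learn's handle:
# Greedy, recursive tree building: pick the best (feature, threshold) split,
# then recurse on each side until a node is pure or max_depth is reached.
def build_tree(rows, labels, depth=0, max_depth=3):
    classes = sorted(set(labels))
    if len(classes) == 1 or depth == max_depth:
        return {'leaf': max(classes, key=labels.count)}  # majority-class leaf
    # Steps 1 and 2: evaluate every (feature, threshold) pair and its gain
    best = None
    for feature in range(len(rows[0])):
        for threshold in candidate_thresholds([row[feature] for row in rows]):
            left = [lab for row, lab in zip(rows, labels) if row[feature] < threshold]
            right = [lab for row, lab in zip(rows, labels) if row[feature] >= threshold]
            gain = split_gain([left.count(c) for c in classes],
                              [right.count(c) for c in classes])
            # Step 3: keep whichever split has the highest gain
            if best is None or gain > best[0]:
                best = (gain, feature, threshold)
    if best is None:  # no split possible; fall back to a leaf
        return {'leaf': max(classes, key=labels.count)}
    _, feature, threshold = best
    left_rows = [(row, lab) for row, lab in zip(rows, labels) if row[feature] < threshold]
    right_rows = [(row, lab) for row, lab in zip(rows, labels) if row[feature] >= threshold]
    return {'feature': feature, 'threshold': threshold,
            'left': build_tree([r for r, _ in left_rows], [l for _, l in left_rows], depth + 1, max_depth),
            'right': build_tree([r for r, _ in right_rows], [l for _, l in right_rows], depth + 1, max_depth)}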
Gini Impurity
The algorithm uses Gini impurity to measure how mixed a node is:
Gini = 1 - Σ(pᵢ²)
where pᵢ is the proportion of samples belonging to class i at a given node.
- Gini = 0: The node is pure (all samples belong to one class)
- Gini = 0.5: Maximum impurity for binary classification (50/50 split)
The algorithm chooses splits that minimise Gini impurity, creating purer child nodes.
Accuracy Calculation
Accuracy is the fraction of predictions that match the true labels: accuracy = correct predictions ÷ total predictions. In scikit-learn this is computed with accuracy_score, as shown in the implementation below.
Tree Depth
The depth of a decision tree controls how many yes/no questions it can ask along a single path from the root before making a prediction (the same feature may be examined more than once).
- Depth 1: Single question, two possible outcomes
- Depth 2: Up to two questions, more nuanced predictions
- Increasing depth: Generally increases training accuracy
Implementation in Python
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Load data
pokemon = pd.read_csv('pokemon.csv')
# Filter to grass and electric types
data = pokemon[pokemon['Type 1'].isin(['Grass', 'Electric'])]
# Define features and labels
features = data[['HP', 'Attack', 'Defense', 'Speed']]
labels = data['Type 1']
# Train the decision tree
tree = DecisionTreeClassifier(max_depth=2)
tree.fit(features, labels)
# Visualise the tree
plot_tree(tree, feature_names=features.columns, class_names=['Electric', 'Grass'])
plt.show()
# Make predictions
predictions = tree.predict(features)
# Calculate accuracy
accuracy = accuracy_score(labels, predictions)
print(f"Accuracy: {accuracy:.2%}")
Overfitting
Overfitting occurs when a model memorises the training data instead of learning the underlying patterns. This is problematic because it hurts generalisation - the model performs poorly on new data it hasn’t seen before.
Signs of Overfitting
- Training accuracy keeps increasing with depth
- Test accuracy plateaus or decreases
- Large gap between training and test accuracy
The Bias-Variance Trade-off
There is a well-known trade-off in machine learning between:
- Model complexity: a more complex model (such as a deeper tree) fits the training data more closely, reducing bias
- Generalisation: a simpler model is less sensitive to the particular training sample, reducing variance, and so tends to perform better on new, unseen data
The goal is to find a depth that balances the two.
Train-Test Split
To detect overfitting, split your data:
- Training set (80%): Used to train the model
- Test set (20%): Used to evaluate generalisation
# Split data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(
features, labels, test_size=0.2, random_state=42
)
# Train on training data
tree.fit(X_train, y_train)
# Evaluate on both sets
train_accuracy = accuracy_score(y_train, tree.predict(X_train))
test_accuracy = accuracy_score(y_test, tree.predict(X_test))
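One way to see overfitting directly is to sweep over tree depths and compare the two accuracies; the point where test accuracy stalls while training accuracy keeps climbing marks the onset of overfitting. A sketch building on the split above (the depth range is arbitrary):
# Compare training vs test accuracy across a range of tree depths
for depth in range(1, 11):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, model.predict(X_train))
    test_acc = accuracy_score(y_test, model.predict(X_test))
    print(f"depth={depth:2d}  train={train_acc:.2%}  test={test_acc:.2%}")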
Techniques to Prevent Overfitting
1. Limiting Tree Depth
Control the max_depth parameter to prevent the tree from becoming too complex:
tree = DecisionTreeClassifier(max_depth=3)
2. Pruning
A post-processing step that:
- Identifies parts of the tree that don’t contribute much to decisions
- Removes (prunes) those branches
- Results in a simpler, more generalisable tree
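In scikit-learn, one built-in form of this is cost-complexity pruning via the ccp_alpha parameter: larger values prune more aggressively. A sketch (the alpha value here is arbitrary and would normally be tuned, for example with cost_complexity_pruning_path or cross-validation):
# Cost-complexity pruning: a larger ccp_alpha removes more branches
pruned_tree = DecisionTreeClassifier(ccp_alpha=0.01, random_state=42)
pruned_tree.fit(X_train, y_train)
print(accuracy_score(y_test, pruned_tree.predict(X_test)))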
3. Random Forests
Instead of learning one large tree:
- Learn many trees, each trained on a random sample of the data (and random subsets of features)
- Make predictions based on a majority vote across the trees
- More robust and less prone to overfitting
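A sketch using scikit-learn's RandomForestClassifier on the same train/test split as above (the hyperparameter values are only illustrative):
from sklearn.ensemble import RandomForestClassifier
# An ensemble of 100 trees, each trained on a bootstrap sample of the data;
# the forest combines the individual trees' predictions
forest = RandomForestClassifier(n_estimators=100, max_depth=3, random_state=42)
forest.fit(X_train, y_train)
print(accuracy_score(y_test, forest.predict(X_test)))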
Key Takeaways
- Decision trees are interpretable and easy to understand
- They work by asking a series of yes/no questions about features
- The fit method automatically selects features and cutoff values
- Deeper trees have higher training accuracy but risk overfitting
- Use train-test splits to monitor generalisation
- Techniques like limiting depth, pruning, and random forests help prevent overfitting
Related
- Pokémon Decision Tree Notebook - Hands-on implementation with the Pokémon dataset