# Implementing a CNN from Scratch: MNIST Digit Classification
In this tutorial, we’ll build a Convolutional Neural Network (CNN) from scratch to classify handwritten digits from the MNIST dataset. This will help you understand the fundamental concepts of computer vision and deep learning.
## Why CNNs for Image Classification?

Traditional neural networks treat images as flat vectors, losing spatial information. CNNs preserve spatial relationships through:

- Convolution: Detecting local features with small, shared filters (see the parameter comparison below)
- Pooling: Reducing spatial dimensions
- Translation Invariance: Recognizing patterns regardless of their position in the image
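A quick way to see why shared filters matter is to count parameters. This is a minimal sketch; the layer sizes are illustrative, chosen only for the comparison:

```python
import torch.nn as nn

# Fully connected: every output unit sees every pixel of a flattened 28x28 image
fc = nn.Linear(28 * 28, 128)
# Convolutional: 32 small 3x3 filters shared across all image positions
conv = nn.Conv2d(1, 32, kernel_size=3)

print(sum(p.numel() for p in fc.parameters()))    # 100480 (784*128 weights + 128 biases)
print(sum(p.numel() for p in conv.parameters()))  # 320 (32*3*3 weights + 32 biases)
```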
## Dataset Overview
The MNIST dataset contains 70,000 images of handwritten digits (0-9):
- Training set: 60,000 images
- Test set: 10,000 images
- Image size: 28x28 pixels (grayscale)
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST mean and std
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=False
)
```
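As a quick sanity check (optional, but it catches shape mistakes early), we can pull one batch and confirm its dimensions:

```python
# Each batch is [batch_size, channels, height, width]
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])
```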
## Building the CNN Architecture
Our CNN will have the following layers:
- Convolutional Layer 1: 32 filters, 3x3 kernel
- ReLU Activation
- Max Pooling: 2x2
- Convolutional Layer 2: 64 filters, 3x3 kernel
- ReLU Activation
- Max Pooling: 2x2
- Flatten
- Fully Connected Layer: 128 neurons
- ReLU Activation
- Output Layer: 10 neurons (one for each digit)
```python
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        # First conv block
        x = self.pool(self.relu(self.conv1(x)))
        # Second conv block
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten for fully connected layers
        x = x.view(-1, 64 * 7 * 7)
        # Fully connected layers
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
```
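The `64 * 7 * 7` in `fc1` follows from the layer arithmetic: with `padding=1`, each 3x3 convolution preserves the spatial size, and each 2x2 pooling halves it (28 → 14 → 7). A short trace with a dummy input confirms this:

```python
model = SimpleCNN()
x = torch.randn(1, 1, 28, 28)  # one dummy grayscale image

x = model.pool(model.relu(model.conv1(x)))
print(x.shape)  # torch.Size([1, 32, 14, 14])

x = model.pool(model.relu(model.conv2(x)))
print(x.shape)  # torch.Size([1, 64, 7, 7]) -> flattens to 64 * 7 * 7 = 3136
```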
## Understanding Each Component

### Convolutional Layer

The convolution operation slides learned filters across the image, producing a feature map for each filter:
```python
def visualize_conv_layer(model, layer_name, input_image):
    """Visualize feature maps from a convolutional layer"""
    activation = {}

    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    # Register a hook on the requested layer; keep the handle so it can be removed
    layer = getattr(model, layer_name)  # 'conv1' or 'conv2'
    handle = layer.register_forward_hook(get_activation(layer_name))

    # Forward pass, then clean up the hook
    _ = model(input_image)
    handle.remove()

    # Get feature maps for the first sample in the batch
    feature_maps = activation[layer_name][0]

    # Visualize the first 32 maps (conv1 has exactly 32; conv2 has 64, so this shows a subset)
    fig, axes = plt.subplots(4, 8, figsize=(16, 8))
    for i in range(32):
        ax = axes[i // 8, i % 8]
        ax.imshow(feature_maps[i].cpu().numpy(), cmap='gray')
        ax.axis('off')
    plt.suptitle(f'{layer_name} Feature Maps')
    plt.show()
```
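To make the filter idea concrete, here is a minimal, self-contained sketch (separate from the model, using a hand-crafted kernel rather than a learned one) showing a vertical-edge filter responding where dark meets bright:

```python
import torch
import torch.nn.functional as F

# A tiny 6x6 image: dark left half, bright right half
image = torch.zeros(1, 1, 6, 6)
image[:, :, :, 3:] = 1.0

# A hand-crafted Sobel-style vertical-edge kernel (shape: out=1, in=1, 3, 3)
kernel = torch.tensor([[[[-1., 0., 1.],
                         [-2., 0., 2.],
                         [-1., 0., 1.]]]])

response = F.conv2d(image, kernel, padding=1)
print(response[0, 0])  # strong responses along the dark-to-bright boundary
```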
### Pooling Layer
Max pooling reduces spatial dimensions while preserving important features:
```python
def demonstrate_pooling():
    # Create a sample feature map
    sample_input = torch.randn(1, 1, 8, 8)

    # Apply max pooling
    pool = nn.MaxPool2d(2, 2)
    output = pool(sample_input)

    print(f"Input shape: {sample_input.shape}")
    print(f"Output shape: {output.shape}")

    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    ax1.imshow(sample_input[0, 0].numpy(), cmap='viridis')
    ax1.set_title('Before Pooling (8x8)')
    ax1.axis('off')
    ax2.imshow(output[0, 0].numpy(), cmap='viridis')
    ax2.set_title('After Pooling (4x4)')
    ax2.axis('off')
    plt.tight_layout()
    plt.show()
```
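For the arithmetic itself, a hand-written 4x4 example makes the operation explicit: each non-overlapping 2x2 block is replaced by its maximum.

```python
x = torch.tensor([[[[1., 2., 5., 6.],
                    [3., 4., 7., 8.],
                    [9., 8., 3., 2.],
                    [7., 6., 1., 0.]]]])
print(nn.MaxPool2d(2, 2)(x))
# tensor([[[[4., 8.],
#           [9., 3.]]]])
```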
## Training the Model
```python
def train_model(model, train_loader, test_loader, epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Calculate average training loss
        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)

        # Evaluation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        accuracy = 100 * correct / total
        test_accuracies.append(accuracy)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    return train_losses, test_accuracies

# Initialize and train the model
model = SimpleCNN()
train_losses, test_accuracies = train_model(model, train_loader, test_loader)
```
## Visualizing Results
```python
def plot_training_results(train_losses, test_accuracies):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # Plot training loss
    ax1.plot(train_losses)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)

    # Plot test accuracy
    ax2.plot(test_accuracies)
    ax2.set_title('Test Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True)

    plt.tight_layout()
    plt.show()

def visualize_predictions(model, test_loader, num_samples=8):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    with torch.no_grad():
        data, target = next(iter(test_loader))
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)

    # The fixed 2x4 grid assumes num_samples <= 8
    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    for i in range(num_samples):
        ax = axes[i // 4, i % 4]
        ax.imshow(data[i].cpu().squeeze(), cmap='gray')
        ax.set_title(f'True: {target[i].item()}, Pred: {predicted[i].item()}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()
```
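With training finished, both helpers can be called directly on the values returned by `train_model`:

```python
plot_training_results(train_losses, test_accuracies)
visualize_predictions(model, test_loader)
```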
## Model Analysis

### Performance Metrics
```python
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def detailed_evaluation(model, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()

    # Classification report
    print(classification_report(all_targets, all_predictions))

    return cm
```
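Running the evaluation is a one-liner, and the returned matrix can be inspected directly:

```python
cm = detailed_evaluation(model, test_loader)
# Off-diagonal entries of cm show which digit pairs the model confuses
```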
## Key Insights

### What Makes CNNs Effective?
- Local Connectivity: Each neuron connects only to a small region of the input
- Parameter Sharing: Same filter weights applied across the entire image
- Translation Invariance: Can recognize patterns regardless of position
- Hierarchical Feature Learning: Early layers detect edges, later layers detect complex patterns
### Common Pitfalls and Solutions

- Overfitting: Use dropout, data augmentation, and regularization
- Vanishing Gradients: Use proper weight initialization and batch normalization
- Slow Convergence: Adjust the learning rate and use learning rate scheduling (a sketch of augmentation and scheduling follows this list)
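Here is a minimal sketch of two of these remedies; the augmentation ranges and schedule values below are illustrative, not tuned:

```python
# Data augmentation: small random rotations and shifts on the training images
augmented_transform = transforms.Compose([
    transforms.RandomRotation(10),                     # rotate up to +/-10 degrees
    transforms.RandomAffine(0, translate=(0.1, 0.1)),  # shift up to 10% each way
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Learning rate scheduling: halve the learning rate every 5 epochs
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Call scheduler.step() once per epoch at the end of the training loop
```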
## Conclusion
This tutorial demonstrated how to build a CNN from scratch for image classification. Key takeaways:
- CNNs are specifically designed for image data and preserve spatial relationships
- Convolution detects local features, pooling reduces dimensions
- Proper training techniques are crucial for good performance
- Visualization helps understand what the model learns
The complete implementation achieves ~98% accuracy on MNIST, demonstrating the power of convolutional neural networks for computer vision tasks.
For the complete code and interactive Jupyter notebook, visit the GitHub repository. Try experimenting with different architectures and hyperparameters!