Implementing a CNN from Scratch: MNIST Digit Classification

In this tutorial, we’ll build a Convolutional Neural Network (CNN) from scratch to classify handwritten digits from the MNIST dataset. This will help you understand the fundamental concepts of computer vision and deep learning.

Why CNNs for Image Classification?

Traditional neural networks treat images as flat vectors, losing spatial information. CNNs preserve spatial relationships through the following mechanisms (a concrete comparison follows the list):

  • Convolution: Detecting local features
  • Pooling: Reducing spatial dimensions
  • Translation Invariance: Recognizing patterns regardless of position
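
To make the parameter-sharing point concrete, here is a minimal sketch comparing a dense layer over a flattened 28x28 image with a convolutional layer (these layer sizes are illustrative, not the tutorial's final model):

import torch.nn as nn

# Dense layer: one weight per (pixel, neuron) pair over the flattened image
fc = nn.Linear(28 * 28, 32)

# Conv layer: 32 shared 3x3 filters slide across the whole image
conv = nn.Conv2d(1, 32, kernel_size=3)

print(sum(p.numel() for p in fc.parameters()))    # 25120 parameters
print(sum(p.numel() for p in conv.parameters()))  # 320 parameters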

Dataset Overview

The MNIST dataset contains 70,000 images of handwritten digits (0-9):

  • Training set: 60,000 images
  • Test set: 10,000 images
  • Image size: 28x28 pixels (grayscale)

First, load the data and create the loaders:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST global mean and std
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1000, shuffle=False
)
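
As a quick sanity check, one training batch should have shape [batch, channels, height, width]:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])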

Building the CNN Architecture

Our CNN will have the following layers:

  1. Convolutional Layer 1: 32 filters, 3x3 kernel
  2. ReLU Activation
  3. Max Pooling: 2x2
  4. Convolutional Layer 2: 64 filters, 3x3 kernel
  5. ReLU Activation
  6. Max Pooling: 2x2
  7. Flatten
  8. Fully Connected Layer: 128 neurons
  9. ReLU Activation
  10. Output Layer: 10 neurons (one for each digit)

In code:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        
        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)
        
        # Fully connected layers
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        
        # Activation and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # First conv block
        x = self.pool(self.relu(self.conv1(x)))
        
        # Second conv block
        x = self.pool(self.relu(self.conv2(x)))
        
        # Flatten for fully connected layers
        x = x.view(-1, 64 * 7 * 7)
        
        # Fully connected layers
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
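
To see where the 64 * 7 * 7 input size of fc1 comes from, trace the shapes: padding=1 keeps each convolution at the same resolution, and each 2x2 pool halves it (28 -> 14 -> 7). A quick check:

model = SimpleCNN()
x = torch.randn(1, 1, 28, 28)
x = model.pool(model.relu(model.conv1(x)))
print(x.shape)  # torch.Size([1, 32, 14, 14])
x = model.pool(model.relu(model.conv2(x)))
print(x.shape)  # torch.Size([1, 64, 7, 7]) -> flattens to 3136 features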

Understanding Each Component

Convolutional Layer

The convolution operation applies filters to detect features:

def visualize_conv_layer(model, layer_name, input_image):
    """Visualize feature maps from a convolutional layer ('conv1' or 'conv2')"""
    activation = {}
    
    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook
    
    # Register hook on the requested layer; keep the handle so it can be removed
    layer = getattr(model, layer_name)
    handle = layer.register_forward_hook(get_activation(layer_name))
    
    # Forward pass (no gradients needed for visualization)
    model.eval()
    with torch.no_grad():
        _ = model(input_image)
    handle.remove()
    
    # Get feature maps for the first sample in the batch
    feature_maps = activation[layer_name][0]
    
    # Visualize up to 32 maps in a 4x8 grid (conv2 has 64 filters; show the first 32)
    num_maps = min(feature_maps.shape[0], 32)
    fig, axes = plt.subplots(4, 8, figsize=(16, 8))
    for i in range(num_maps):
        ax = axes[i//8, i%8]
        ax.imshow(feature_maps[i].cpu().numpy(), cmap='gray')
        ax.axis('off')
    plt.suptitle(f'{layer_name} Feature Maps')
    plt.show()
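
After training, it can be called on a single test image; the unsqueeze adds the batch dimension (this usage sketch assumes the model and dataset defined earlier):

sample, _ = test_dataset[0]
device = next(model.parameters()).device  # put the input on the model's device
visualize_conv_layer(model, 'conv1', sample.unsqueeze(0).to(device))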

Pooling Layer

Max pooling reduces spatial dimensions while preserving important features:

def demonstrate_pooling():
    # Create a sample feature map
    sample_input = torch.randn(1, 1, 8, 8)
    
    # Apply max pooling
    pool = nn.MaxPool2d(2, 2)
    output = pool(sample_input)
    
    print(f"Input shape: {sample_input.shape}")
    print(f"Output shape: {output.shape}")
    
    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
    
    ax1.imshow(sample_input[0, 0].numpy(), cmap='viridis')
    ax1.set_title('Before Pooling (8x8)')
    ax1.axis('off')
    
    ax2.imshow(output[0, 0].numpy(), cmap='viridis')
    ax2.set_title('After Pooling (4x4)')
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
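
The operation itself is simple: each 2x2 window is replaced by its maximum. A hand-checkable example:

t = torch.tensor([[1., 2., 5., 6.],
                  [3., 4., 7., 8.],
                  [0., 1., 2., 3.],
                  [1., 0., 3., 2.]]).reshape(1, 1, 4, 4)
print(nn.MaxPool2d(2, 2)(t))
# tensor([[[[4., 8.],
#           [1., 3.]]]])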

Training the Model

def train_model(model, train_loader, test_loader, epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    train_losses = []
    test_accuracies = []
    
    for epoch in range(epochs):
        # Training
        model.train()
        running_loss = 0.0
        
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Calculate average training loss
        avg_loss = running_loss / len(train_loader)
        train_losses.append(avg_loss)
        
        # Evaluation
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
        
        accuracy = 100 * correct / total
        test_accuracies.append(accuracy)
        
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')
    
    return train_losses, test_accuracies

# Initialize and train the model
model = SimpleCNN()
train_losses, test_accuracies = train_model(model, train_loader, test_loader)

Visualizing Results

def plot_training_results(train_losses, test_accuracies):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Plot training loss
    ax1.plot(train_losses)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.grid(True)
    
    # Plot test accuracy
    ax2.plot(test_accuracies)
    ax2.set_title('Test Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

def visualize_predictions(model, test_loader, num_samples=8):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    
    with torch.no_grad():
        data, target = next(iter(test_loader))
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)
        
        # 2x4 grid shows up to 8 samples
        fig, axes = plt.subplots(2, 4, figsize=(12, 6))
        for i in range(min(num_samples, 8)):
            ax = axes[i//4, i%4]
            ax.imshow(data[i].cpu().squeeze(), cmap='gray')
            ax.set_title(f'True: {target[i].item()}, Pred: {predicted[i].item()}')
            ax.axis('off')
        
        plt.tight_layout()
        plt.show()
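
With the model trained, both helpers can be called directly:

plot_training_results(train_losses, test_accuracies)
visualize_predictions(model, test_loader)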

Model Analysis

Performance Metrics

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def detailed_evaluation(model, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    # Confusion matrix
    cm = confusion_matrix(all_targets, all_predictions)
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
    
    # Classification report
    print(classification_report(all_targets, all_predictions))
    
    return cm
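
The returned confusion matrix also gives per-digit accuracy, one way to spot which digits the model confuses:

cm = detailed_evaluation(model, test_loader)
per_class_acc = cm.diagonal() / cm.sum(axis=1)  # correct predictions per true class
for digit, acc in enumerate(per_class_acc):
    print(f'Digit {digit}: {acc:.3f}')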

Key Insights

What Makes CNNs Effective?

  1. Local Connectivity: Each neuron connects only to a small region of the input
  2. Parameter Sharing: Same filter weights applied across the entire image
  3. Translation Invariance: Can recognize patterns regardless of position (see the sketch after this list)
  4. Hierarchical Feature Learning: Early layers detect edges, later layers detect complex patterns
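
The second and third points can be verified directly: because filter weights are shared, shifting the input shifts the feature map by the same amount (translation equivariance), and pooling then absorbs small shifts. A minimal sketch with a single random 3x3 filter:

conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)

x = torch.zeros(1, 1, 8, 8)
x[0, 0, 2, 2] = 1.0                                  # impulse at (2, 2)
shifted = torch.roll(x, shifts=(1, 1), dims=(2, 3))  # impulse at (3, 3)

# Convolving the shifted input equals shifting the convolved output
print(torch.allclose(torch.roll(conv(x), (1, 1), (2, 3)), conv(shifted)))  # True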

Common Pitfalls and Solutions

  1. Overfitting: Use dropout, data augmentation, and regularization
  2. Vanishing Gradients: Use proper weight initialization and batch normalization
  3. Slow Convergence: Adjust the learning rate and use learning rate scheduling (see the sketch below)
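
For the third point, PyTorch's built-in schedulers can lower the learning rate on a fixed cadence; a minimal sketch (the step_size and gamma values here are illustrative):

model = SimpleCNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(10):
    # ... run one epoch of training here ...
    scheduler.step()  # cuts lr by 10x every 5 epochs
    print(f'Epoch {epoch+1}: lr = {scheduler.get_last_lr()[0]:.6f}')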

Conclusion

This tutorial demonstrated how to build a CNN from scratch for image classification. Key takeaways:

  • CNNs are specifically designed for image data and preserve spatial relationships
  • Convolution detects local features, pooling reduces dimensions
  • Proper training techniques are crucial for good performance
  • Visualization helps understand what the model learns

The complete implementation achieves ~98% accuracy on MNIST, demonstrating the power of convolutional neural networks for computer vision tasks.


For the complete code and interactive Jupyter notebook, visit the GitHub repository. Try experimenting with different architectures and hyperparameters!