Building a CIFAR-10 CNN with PyTorch Lightning: A Practical Guide
This tutorial demonstrates building and training a Convolutional Neural Network (CNN) for the CIFAR-10 dataset using PyTorch Lightning. It covers data loading, model definition, training loop setup, TensorBoard integration, and early stopping.
By James Lau, Indie App Developer
Table of Contents
- Building a CIFAR-10 CNN with PyTorch Lightning: A Practical Guide
- Introduction
- Prerequisites
- Code Implementation
- 1. Imports and Setup
- 2. Model Definition (CIFAR10CNN)
- 3. Training, Validation, and Testing Steps
- 4. Optimizer and Learning Rate Scheduler
- 5. On Validation Epoch End (Confusion Matrix)
- 6. Data Loading and Preprocessing
- 7. Training Loop with PyTorch Lightning
- Running the Code
Building a CIFAR-10 CNN with PyTorch Lightning: A Practical Guide
Introduction
This blog post provides a practical guide to building and training a Convolutional Neural Network (CNN) for the CIFAR-10 dataset using PyTorch Lightning. We’ll cover everything from data loading and model definition to setting up the training loop, integrating TensorBoard for visualization, and implementing early stopping to prevent overfitting.
Prerequisites
Before we begin, make sure you have the following libraries installed:
- torch: PyTorch deep learning framework
- torchvision: Datasets, transforms, and models for computer vision
- lightning: PyTorch Lightning for simplified training
- tensorboardx: TensorBoard integration
- scikit-learn: For confusion matrix calculation
- seaborn: For visualizing the confusion matrix
- matplotlib: For plotting
You can install them using pip:
pip install torch torchvision lightning scikit-learn seaborn matplotlib tensorboardX
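To confirm the installation worked, you can optionally run a quick import check from your terminal:
python -c "import torch, torchvision, lightning, sklearn, seaborn, matplotlib; print(torch.__version__, lightning.__version__)"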
Code Implementation
Let’s dive into the code. We’ll break it down step by step.
1. Imports and Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
# pip install "tensorboardX" "tensorboard" lightning scikit-learn seaborn torchvision
L.seed_everything(1121218)
This section imports all necessary libraries and sets a seed for reproducibility.
2. Model Definition (CIFAR10CNN)
The CIFAR10CNN class defines our CNN architecture:
class CIFAR10CNN(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.validation_step_outputs = []

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
It consists of three convolutional layers (conv1, conv2, conv3), a 2x2 max-pooling layer (pool) applied after each convolution, and two fully connected layers (fc1, fc2). The forward method defines the forward pass through the network.
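Before wiring up training, it can help to verify that the layer dimensions line up by pushing a dummy batch through the network (an optional sanity check, not part of the original code):
model = CIFAR10CNN()
dummy = torch.randn(4, 3, 32, 32)   # fake batch of four 32x32 RGB images
logits = model(dummy)
print(logits.shape)                  # torch.Size([4, 10]) -- one score per CIFAR-10 class
Each max-pooling step halves the spatial resolution (32 -> 16 -> 8 -> 4), which is why fc1 expects 64 * 4 * 4 input features.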
3. Training, Validation, and Testing Steps
The training_step, validation_step, and test_step methods define how data is processed during training, validation, and testing respectively:
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        # Log the loss at each training step and epoch, create a progress bar
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        pred = y_hat.argmax(1)
        acc = (pred == y).float().mean()
        self.validation_step_outputs.append((pred, y))
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(1) == y).float().mean()
        self.log('test_loss', loss)
        self.log('test_acc', acc)
The training_step calculates the cross-entropy loss and accuracy for each batch during training. The validation_step does the same for validation, storing predictions and labels for confusion matrix calculation later. The test_step computes loss and accuracy on the test set.
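One detail worth noting: F.cross_entropy takes the raw logits returned by forward and applies log-softmax internally, so the model does not need a softmax layer. A minimal illustration with random tensors:
logits = torch.randn(8, 10)                         # fake model outputs for a batch of 8
labels = torch.randint(0, 10, (8,))                 # fake ground-truth class indices
loss = F.cross_entropy(logits, labels)              # scalar loss; softmax is applied internally
acc = (logits.argmax(1) == labels).float().mean()   # same accuracy formula used in the steps above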
4. Optimizer and Learning Rate Scheduler
The configure_optimizers method configures the optimizer and learning rate scheduler:
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=5
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }
An Adam optimizer is used with a learning rate of 0.001, and a ReduceLROnPlateau scheduler cuts the learning rate by a factor of 10 when the validation loss has not improved for 5 epochs.
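To see what this configuration implies, here is a standalone sketch (illustrative only, using a dummy parameter) of how ReduceLROnPlateau reacts when the monitored metric stops improving. In the actual training run, Lightning feeds the logged val_loss to the scheduler automatically because of the "monitor" key above:
param = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5
)

for epoch in range(8):
    scheduler.step(1.0)  # feed a "validation loss" that never improves
    print(epoch, optimizer.param_groups[0]["lr"])
# The learning rate stays at 1e-3 until the metric has failed to improve for
# more than `patience` epochs, then it is multiplied by `factor` (1e-3 -> 1e-4).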
5. On Validation Epoch End (Confusion Matrix)
The on_validation_epoch_end method calculates and logs the confusion matrix at the end of each validation epoch:
    def on_validation_epoch_end(self):
        all_preds = torch.cat([x[0] for x in self.validation_step_outputs])
        all_labels = torch.cat([x[1] for x in self.validation_step_outputs])
        conf_matrix = confusion_matrix(all_labels.cpu().numpy(), all_preds.cpu().numpy())
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(conf_matrix, annot=True, fmt="d", cmap="Blues", ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix')
        # Log the confusion matrix figure to TensorBoard
        self.logger.experiment.add_figure("Confusion Matrix", fig, self.current_epoch)
        plt.close(fig)  # Close the figure to avoid memory buildup
        self.validation_step_outputs.clear()  # Free memory
It aggregates predictions and labels from the validation steps, calculates the confusion matrix using sklearn.metrics, visualizes it with Seaborn, and logs it to TensorBoard.
6. Data Loading and Preprocessing
The data loading and preprocessing steps are standard for CIFAR-10:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
test_dataset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)
batch_size = 128  # batch size for both loaders (assumed value; not defined in the original snippet)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=8
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=8
)
Data augmentation (random crop and horizontal flip) is applied to the training set. Both sets are normalized.
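As a quick check that the pipeline produces what the model expects, you can pull a single batch from the loader (optional, not part of the original script):
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32]) with a batch size of 128
print(labels.shape)  # torch.Size([128])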
7. Training Loop with PyTorch Lightning
The main function orchestrates the entire training process using PyTorch Lightning:
def main():
    # Initialize the model
    model = CIFAR10CNN()

    # Initialize the Trainer
    trainer = L.Trainer(
        max_epochs=num_epochs,
        callbacks=[checkpoint_callback, early_stopping],
        logger=logger,
        accelerator="auto",
        devices="auto",
    )
    trainer.fit(model, train_loader, test_loader)

if __name__ == "__main__":
    main()
The L.Trainer handles the training loop, callbacks (ModelCheckpoint and EarlyStopping), logging (TensorBoardLogger), and device management.
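Note that num_epochs, checkpoint_callback, early_stopping, and logger are referenced in main() but not defined in the snippet above. A minimal sketch of how they could be set up (the monitored metric, patience, and log directory here are assumptions, not values from the original post):
num_epochs = 50  # assumed maximum number of epochs

# Keep the checkpoint with the lowest validation loss
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

# Stop training early if validation loss stops improving (patience is an assumed value)
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10)

# Write metrics and the confusion-matrix figures under ./lightning_logs/
logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")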
Running the Code
- Save: Save the code as a Python file (e.g., cifar10_cnn.py).
- Run: Execute the script from your terminal: python cifar10_cnn.py
- TensorBoard: After training, launch TensorBoard to visualize the results: tensorboard --logdir lightning_logs/ (then open http://localhost:6006 in your browser).
## Conclusion
This tutorial demonstrated how to build and train a CNN for CIFAR-10 using PyTorch Lightning. The combination of PyTorch Lightning's streamlined training loop, TensorBoard integration, and callbacks like ModelCheckpoint and EarlyStopping makes it easier to develop and optimize deep learning models.
## Further Exploration
* Experiment with different CNN architectures.
* Try various data augmentation techniques.
* Explore different optimizers and learning rate schedules.
* Implement more advanced regularization methods.