In our previous post, we explored computational graphs and autograd in PyTorch. Now, we arrive at the final part of this four-part PyTorch tutorial: the training loop. This is where everything comes together—data, model, and computational graph—to create a functioning neural network.
The Training Process: A Conceptual Overview
Training a model essentially means adjusting its parameters to minimize a target loss function on the provided dataset. Let's revisit the FashionMNIST dataset example from the first part of this series. In this dataset, we have grayscale images of size 28×28 pixels, with each image mapping to one of 10 unique labels, such as Dress, Coat, Sandal, etc.
When we initialize a PyTorch model, all weights start with random values (though there are various initialization strategies to promote faster/more efficient convergence). At this point, if the model processes an image of a Dress, it likely won't output "Dress" as the class label since no training has occurred yet. If it mistakenly predicts the image as "Sandal," this error contributes to the loss. The goal of training is to iteratively update the network parameters to reduce this loss.
This process is accomplished through the backpropagation algorithm, which we covered in the previous blog post on computational graphs. In essence, the algorithm consists of:
- Performing a forward pass through the network, building the computational graph
- Computing the loss at the output
- Computing gradients of the loss with respect to each parameter by traversing the graph backward
With these gradients available, an optimizer can update the network parameters to reduce the training loss.
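To make this concrete before we move to the full FashionMNIST pipeline, here is a minimal, self-contained sketch of one such update performed by hand with autograd (the toy tensors below are hypothetical, not part of the actual model):

import torch

# Hypothetical toy example (not our FashionMNIST model): one weight vector, one input
w = torch.randn(3, requires_grad=True)   # randomly initialized parameter
x = torch.tensor([1.0, 2.0, 3.0])        # one input example
target = torch.tensor(10.0)              # its ground-truth value

pred = (w * x).sum()                     # forward pass builds the computational graph
loss = (pred - target) ** 2              # squared-error loss at the output
loss.backward()                          # backward pass populates w.grad

lr = 0.01                                # learning rate
with torch.no_grad():
    w -= lr * w.grad                     # the update an optimizer performs for us
    w.grad.zero_()                       # reset the gradient before the next step

An optimizer packages exactly this update (and more sophisticated variants of it) behind step() and zero_grad(), which we use throughout the rest of this post.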
Loss Functions and Optimizers
Before implementing the training loop, let's discuss two critical components: loss functions and optimizers.
Common Loss Functions
The loss function quantifies how far our model's predictions are from the ground truth. Here are some commonly used loss functions in PyTorch:
- Cross-Entropy Loss (nn.CrossEntropyLoss): Used for classification problems. It combines nn.LogSoftmax and nn.NLLLoss in a single class, making it ideal for multi-class classification tasks like our FashionMNIST example.
- Mean Squared Error (nn.MSELoss): Used for regression problems. It measures the average squared difference between the estimated values and the actual values.
- Binary Cross-Entropy (nn.BCELoss): Used for binary classification problems where each output is independently classified.
- L1 Loss (nn.L1Loss): Also known as Mean Absolute Error, it's less sensitive to outliers than MSE.
- Kullback-Leibler Divergence (nn.KLDivLoss): Measures how one probability distribution diverges from a second expected probability distribution.
For our FashionMNIST classification task, the Cross-Entropy Loss is most appropriate as it's designed for multi-class problems.
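As a quick illustration of how this loss is called (the logits and labels below are random stand-ins, not real model outputs), note that nn.CrossEntropyLoss expects raw, unnormalized logits and integer class indices:

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

# Hypothetical mini-batch: 4 examples, 10 FashionMNIST classes
logits = torch.randn(4, 10)          # raw model outputs (no softmax applied)
labels = torch.tensor([0, 3, 9, 3])  # ground-truth class indices

loss = loss_fn(logits, labels)
print(loss.item())                   # a single scalar, averaged over the batch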
Common Optimizers
Optimizers update the model parameters based on the computed gradients. PyTorch offers several optimizers, each with different characteristics:
- Stochastic Gradient Descent (SGD) (torch.optim.SGD): The simplest optimizer; it updates parameters in the opposite direction of the gradient. It can include momentum to accelerate training and dampen oscillations.
- Adagrad (torch.optim.Adagrad): Adapts the learning rate to the parameters, performing larger updates for infrequent parameters and smaller updates for frequent ones.
- RMSprop (torch.optim.RMSprop): Similar to Adagrad but addresses its radically diminishing learning rates by using a moving average of squared gradients.
- Adam (torch.optim.Adam): Combines the benefits of two other extensions of SGD, Adagrad and RMSprop. It adapts the learning rate for each parameter, usually leading to faster convergence.
For many tasks, Adam is a good default choice due to its adaptive learning rate and generally good performance across different types of neural networks and datasets.
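Whichever optimizer you pick, construction follows the same pattern: hand it the model's parameters and the hyperparameters you want. A minimal sketch, assuming a model like the one we define below (the learning rates here are illustrative, not tuned):

import torch

# Assuming `model` is the nn.Module we build later in this post
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

# Swapping in Adam only requires changing this one line:
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)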
Implementing the Training Loop
With an understanding of loss functions and optimizers, let's implement the training loop step by step:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
Let's break down this function:
- We begin by setting our model to training mode with model.train(). This is crucial because certain layers like BatchNorm and Dropout behave differently during training versus evaluation. During training, BatchNorm updates its running statistics and Dropout randomly zeros elements, while during evaluation, BatchNorm uses its accumulated statistics without updates and Dropout doesn't drop any elements.
- We iterate through our training data batch by batch (a batch is a collection of training examples). For each batch:
  - We move the input features and labels to the target device (CPU or GPU)
  - We perform a forward pass through the model to get the predictions
  - We calculate the loss by comparing the predictions to the true labels
  - We call loss.backward() to compute gradients via backpropagation
  - We call optimizer.step() to update the model parameters based on the computed gradients
  - We call optimizer.zero_grad() to reset the gradients for the next iteration
- Periodically, we print the current loss and progress to monitor training.
The Importance of optimizer.zero_grad()
One common question is why we need to call optimizer.zero_grad() after each parameter update. This is because PyTorch accumulates gradients by default. If we don't reset them, the gradients from the previous batch would be added to the current batch's gradients, which isn't what we want for standard training. We want each batch's gradients to be calculated independently.
Some advanced training techniques do leverage gradient accumulation (by updating parameters only after several batches), but for standard training, we reset gradients after each update.
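For reference, here is a sketch of what that gradient-accumulation variant of our train loop could look like (accum_steps is a hypothetical hyperparameter; the effective batch size becomes batch_size * accum_steps):

accum_steps = 4  # hypothetical: accumulate gradients over 4 batches before each update

for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y) / accum_steps  # scale so the accumulated gradient matches one large batch
    loss.backward()                        # gradients keep adding up across iterations

    if (batch + 1) % accum_steps == 0:
        optimizer.step()                   # update only every accum_steps batches
        optimizer.zero_grad()              # then reset for the next accumulation window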
Validation and Preventing Overfitting
A common pitfall in model training is overfitting, where the model performs extremely well on the training data but poorly on unseen data. To detect and prevent overfitting, we monitor the model's performance on a validation dataset.
Here's how we can implement a validation function:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
Key aspects of this function:
- We set the model to evaluation mode with model.eval(). As mentioned earlier, this affects the behavior of layers like BatchNorm and Dropout, ensuring they act appropriately for evaluation rather than training.
- We use torch.no_grad() to disable gradient computation, which:
  - Reduces memory usage
  - Speeds up computation
  - Is appropriate during evaluation, since we aren't updating parameters and therefore don't need gradients
- We calculate the average loss and accuracy across all batches to measure the model's performance.
By running this validation after each training epoch, we can monitor for signs of overfitting—specifically, when the training loss continues to decrease while the validation loss plateaus or increases.
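One simple way to make this monitoring concrete is to record both losses each epoch and watch the gap between them. The sketch below assumes train and test are modified to return their average losses, which the versions above don't yet do:

history = {"train_loss": [], "val_loss": []}

for epoch in range(epochs):
    train_loss = train(train_dataloader, model, loss_fn, optimizer)  # assumed to return avg training loss
    val_loss = test(test_dataloader, model, loss_fn)                 # assumed to return avg validation loss
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)

    # A widening gap (train_loss falling while val_loss rises) is the classic sign of overfitting
    print(f"epoch {epoch + 1}: train {train_loss:.4f}, val {val_loss:.4f}")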
Putting It All Together
Now, let's combine these functions into a complete training pipeline:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
This simple loop:
- Iterates through the specified number of epochs
- For each epoch, runs the training function on the training data
- Evaluates the model on the test data to monitor performance
- Prints progress information
Understanding Epochs
An epoch represents one complete pass through the entire training dataset. Training typically requires multiple epochs because:
- The optimizer takes small steps in the direction of the negative gradient, so multiple passes help converge to a good minimum
- With shuffling, each epoch presents the data in a different order, helping the model learn more robust patterns
- Some patterns may only become apparent after the model has learned other foundational patterns
The number of epochs is a hyperparameter that needs tuning for each specific dataset and model architecture. Too few epochs may result in underfitting, while too many can lead to overfitting.
Complete Implementation
Here's a complete implementation that brings together everything we've discussed across all four parts of this series:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64

# Create data loaders
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

# Get device for training
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
Hardware Acceleration: Understanding the "device" Parameter
One crucial aspect of the training loop is the device parameter. PyTorch can run computations on different hardware:
- CPU: The default option, available on all machines.
- CUDA: For NVIDIA GPUs, offering significant speedups for most deep learning workloads.
- MPS: For Apple M-series chips, providing GPU acceleration on Mac computers.
Matrix multiplication operations, which are fundamental to neural networks, are typically much faster on GPUs due to their massively parallel architecture. GPUs contain thousands of cores designed specifically for handling parallel computations, making them ideal for the matrix operations that dominate deep learning workloads.
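If you want to see this difference on your own machine, a rough timing sketch like the one below works (it assumes a CUDA GPU; actual numbers vary widely with hardware and matrix size):

import time
import torch

a = torch.randn(4096, 4096)
b = torch.randn(4096, 4096)

start = time.time()
_ = a @ b                                 # matrix multiplication on the CPU
print(f"CPU matmul: {time.time() - start:.3f}s")

if torch.cuda.is_available():
    a_gpu, b_gpu = a.to("cuda"), b.to("cuda")
    torch.cuda.synchronize()              # start timing only after the transfers finish
    start = time.time()
    _ = a_gpu @ b_gpu                     # the same multiplication on the GPU
    torch.cuda.synchronize()              # wait for the kernel to complete before stopping the clock
    print(f"GPU matmul: {time.time() - start:.3f}s")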
The device parameter helps us determine the best available hardware and move our model and data accordingly:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
Once we determine the device, we:
- Move the model to that device with model.to(device)
- For each batch, move the input features and labels to the same device
This ensures that all computations—forward pass, loss calculation, and backpropagation—happen on the same hardware, leveraging its specific optimizations.
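A quick sanity check, assuming the model, device, and a batch (X, y) from the code above, is to print where the parameters and the data actually live; a device mismatch raises a runtime error as soon as the forward pass runs:

model = model.to(device)
X, y = X.to(device), y.to(device)

print(next(model.parameters()).device)  # e.g. cuda:0, mps:0, or cpu
print(X.device)                         # must match the model's device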
Performance Considerations
While GPUs typically offer significant speedups, there are cases where using a GPU might not be beneficial:
- For very small models, the overhead of transferring data between CPU and GPU might outweigh the benefits of faster computation.
- If your batches are very small, the GPU might not be fully utilized.
- Some operations might not be optimized for GPU execution in certain frameworks.
For large models and datasets, however, GPU acceleration can reduce training time from days to hours or even minutes.
Beyond the Basic Training Loop
The beauty of PyTorch's explicit, "Pythonic" approach is that you can customize the training loop to your heart's content. Here are some common enhancements:
Model Checkpointing
It's often useful to save model snapshots during training:
# After validation at the end of each epoch
if current_val_loss < best_val_loss:
    best_val_loss = current_val_loss
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': best_val_loss,
    }, 'best_model.pth')
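To resume training or evaluate later, the saved dictionary can be restored with torch.load and load_state_dict; a sketch assuming the same 'best_model.pth' file and the model and optimizer classes used above:

checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']       # where training left off
best_val_loss = checkpoint['loss']      # best validation loss seen so far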
Learning Rate Scheduling
Adjusting the learning rate during training can improve performance:
# Define scheduler after optimizer
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
# Add after validation in the epoch loop
scheduler.step(val_loss)
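ReduceLROnPlateau is driven by the validation loss; other schedulers, such as StepLR, follow a fixed schedule and don't need a metric at all. A brief sketch of that alternative (the step_size and gamma values are illustrative):

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

for epoch in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
    scheduler.step()  # multiply the learning rate by gamma every step_size epochs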
Early Stopping
To prevent overfitting, you might want to stop training when validation performance stops improving:
patience = 5
early_stopping_counter = 0
best_val_loss = float('inf')

for epoch in range(epochs):
    train(...)
    val_loss = test(...)  # assumes test() is modified to return the average validation loss

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
        # Save best model here
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break
Gradient Clipping
To prevent exploding gradients, especially in recurrent neural networks:
# After loss.backward() but before optimizer.step()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
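In context, the backpropagation section of our train function would look like this (clipping rescales the gradients whenever their combined norm exceeds max_norm):

# Backpropagation with gradient clipping
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # rescale gradients whose global norm exceeds 1.0
optimizer.step()
optimizer.zero_grad()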
These enhancements are easy to implement because PyTorch gives you complete control over the training process.
Conclusion
In this four-part series, we've covered the essential components of PyTorch:
- In Part 1, we explored Data Handling with the Dataset and DataLoader classes, learning how to efficiently load and preprocess data.
- In Part 2, we delved into Model Building, understanding how to define neural network architectures using the nn.Module class.
- In Part 3, we uncovered the magic of Computational Graphs and Autograd, which enable automatic differentiation for training deep learning models.
- Finally, in this post, we've brought everything together with the Training Loop, where data, model, and computational graph work in harmony to train effective neural networks.
What makes PyTorch special is how these components fit together so naturally. The framework strikes a rare balance between flexibility and usability—it offers low-level control when you need it while providing high-level abstractions that handle the complex details of deep learning.
Whether you're implementing standard architectures or pushing the boundaries with novel approaches, the concepts we've covered provide the foundation for your PyTorch journey. The explicit nature of PyTorch not only makes it an excellent learning tool but also a powerful platform for research and production.
I hope this series has given you the confidence to explore PyTorch further and build your own deep learning applications. Remember that deep learning is as much an art as it is a science—experimentation is key, and PyTorch makes that experimentation both intuitive and enjoyable.