Sentiment Analysis with BERT: A Practical Guide
By James Lau, indie app developer

This blog post walks you through building a sentiment analysis model using BERT (Bidirectional Encoder Representations from Transformers) and PyTorch Lightning. We'll train the model on the IMDB movie review dataset and then demonstrate how to use it for inference, covering data loading, model setup, training, validation, and prediction.
Training
The training script leverages transfer learning with a pre-trained BERT model. Here's a breakdown:
- Import Libraries: We start by importing the necessary libraries: torch, pytorch_lightning, transformers, and datasets.
- Device Management: The code checks for GPU availability and uses it if present; otherwise, it falls back to the CPU. If a TPU is available, the script attempts to install the required packages.
- Model Definition (SentimentClassifier): This class inherits from pl.LightningModule. It initializes a BERT model (BertForSequenceClassification) and a tokenizer (BertTokenizer). Setting num_labels=2 indicates we're classifying sentiment as either positive or negative.
- Forward Pass: The forward method takes the tokenized inputs and passes them through the BERT model for classification.
- Training Step: This step performs the forward pass, calculates the cross-entropy loss, logs the training loss, and returns the loss value (see the first sketch after this list).
- Validation Step: Similar to the training step, this evaluates the model on a held-out split and logs the validation loss.
- Optimizer: An AdamW optimizer is used with a learning rate of 2e-5.
- Dataset Loading & DataLoaders: The IMDB dataset is loaded with datasets.load_dataset, and DataLoaders are then created for training and validation (see the second sketch after this list).
- Training Loop: A pl.Trainer object orchestrates the training process, iterating through epochs and updating the model parameters based on the training data.
- Checkpoint Saving: The trained model is saved as a checkpoint (bert_movie_check.ckpt) for later use.
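To make these steps concrete, here is a minimal sketch of the training-related methods. The __init__ mirrors the class definition in the Code Snippets section below; the method bodies and the batch field names (input_ids, attention_mask, label) are illustrative assumptions, not the post's exact code.

import torch
import pytorch_lightning as pl
from transformers import BertForSequenceClassification, BertTokenizer

class SentimentClassifier(pl.LightningModule):
    def __init__(self, model_name='bert-base-uncased', num_labels=2, learning_rate=2e-5):
        # Mirrors the class definition in the Code Snippets section below.
        super().__init__()
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, **inputs):
        # Delegate to the underlying HF model; it computes cross-entropy
        # internally whenever 'labels' is included in the inputs.
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(input_ids=batch['input_ids'],
                       attention_mask=batch['attention_mask'],
                       labels=batch['label'])
        self.log('train_loss', outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(input_ids=batch['input_ids'],
                       attention_mask=batch['attention_mask'],
                       labels=batch['label'])
        self.log('val_loss', outputs.loss)

    def configure_optimizers(self):
        # AdamW at 2e-5, as described above.
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)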
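And here is a minimal sketch of the data pipeline and training loop. The tokenization settings, batch size, and epoch count are illustrative assumptions; the checkpoint name matches the one referenced above.

import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize(batch):
    # Fixed-length padding keeps the default DataLoader collation simple.
    return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=256)

dataset = load_dataset('imdb')
dataset = dataset.map(tokenize, batched=True)
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

train_loader = DataLoader(dataset['train'], batch_size=16, shuffle=True)
val_loader = DataLoader(dataset['test'], batch_size=16)

model = SentimentClassifier()
trainer = pl.Trainer(max_epochs=1, accelerator='auto')  # uses GPU/TPU if available
trainer.fit(model, train_loader, val_loader)
trainer.save_checkpoint('bert_movie_check.ckpt')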
Validation & Inference
The validation script demonstrates how to load the trained model and perform inference on new sentences:
- Model Loading: The SentimentClassifier.load_from_checkpoint function loads the saved checkpoint, restoring the model's state (see the short sketch after this list).
- Tokenizer Initialization: The BERT tokenizer is initialized again.
- Inference Function (predict_sentiment): This function takes a sentence as input, tokenizes it, and passes it through the loaded BERT model to obtain sentiment predictions.
- Prediction & Output: The code predicts the sentiment of a sample sentence and prints the predicted sentiment along with the probabilities for each class (positive or negative).
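As a minimal sketch, restoring the model and tokenizer comes down to a few lines; load_from_checkpoint recovers the hyperparameters saved via save_hyperparameters during training.

# Restore the trained model and tokenizer for inference.
from transformers import BertTokenizer

model = SentimentClassifier.load_from_checkpoint('bert_movie_check.ckpt')
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')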
Code Snippets
# Imports used by the snippets below
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import BertForSequenceClassification, BertTokenizer

# Model Definition
class SentimentClassifier(pl.LightningModule):
    def __init__(self, model_name='bert-base-uncased', num_labels=2, learning_rate=2e-5):
        super().__init__()
        self.save_hyperparameters()  # exposes learning_rate as self.hparams.learning_rate
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

# Inference Function
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)
    prediction = torch.argmax(probabilities, dim=1).item()
    return prediction, probabilities
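A quick usage example (the sample sentence is made up; the label mapping assumes IMDB's convention of 0 = negative, 1 = positive):

# Classify one sentence with the loaded model and tokenizer from above.
prediction, probabilities = predict_sentiment(model, tokenizer, 'An instant classic. I loved every minute of it.')
print('positive' if prediction == 1 else 'negative')  # assumes IMDB labels: 0 = negative, 1 = positive
print(probabilities.squeeze().tolist())               # [p_negative, p_positive]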
Conclusion
This tutorial provides a practical guide to sentiment analysis using BERT and PyTorch Lightning. By leveraging transfer learning, we can achieve high accuracy with relatively little training data. The code provided is readily adaptable for other text classification tasks.