Sentiment Analysis with BERT: A Practical Guide

This tutorial demonstrates training and validating a sentiment classifier using BERT in PyTorch Lightning. We'll cover data loading, model setup, training, validation, and inference.

Author: James Lau, Indie App Developer (self-employed)

This blog post walks you through building a sentiment analysis model using BERT (Bidirectional Encoder Representations from Transformers) and PyTorch Lightning. We’ll train the model on the IMDB movie review dataset and then demonstrate how to use it for inference.

Training

The training script leverages the power of transfer learning with pre-trained BERT models. Here’s a breakdown:

  1. Import Libraries: We start by importing necessary libraries like torch, pytorch_lightning, transformers, and datasets.
  2. Device Management: The code checks for GPU availability and utilizes it if present; otherwise, it falls back to CPU. If a TPU is available, the script attempts to install the required packages.
  3. Model Definition (SentimentClassifier): This class inherits from pl.LightningModule. It initializes a BERT model (BertForSequenceClassification) and tokenizer (BertTokenizer). The num_labels=2 indicates we’re classifying sentiment as either positive or negative.
  4. Forward Pass: The forward method takes the input data, handles tokenization, and passes it to the BERT model for classification.
  5. Training Step: This step performs the forward pass, calculates the loss using cross-entropy, logs the training loss, and returns the loss value.
  6. Validation Step: Similar to the training step, this validates the model on a separate dataset and logs the validation loss.
  7. Optimizer: An AdamW optimizer is used with a learning rate of 2e-5.
  8. Dataset Loading & DataLoaders: The IMDB dataset is loaded using datasets.load_dataset, and then DataLoaders are created for training and validation.
  9. Training Loop: A pl.Trainer object orchestrates the training process, iterating through epochs and updating model parameters based on the training data.
  10. Checkpoint Saving: The trained model is saved as a checkpoint (bert_movie_check.ckpt) for later use (see the consolidated sketch after this list).

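To make these steps concrete, here is a minimal sketch of how they might fit together in one script. It is illustrative rather than a copy of the original training code: the helper name _shared_step, the max_length of 256, the batch size of 16, and the single-epoch Trainer are assumptions, and the GPU/TPU detection from step 2 is simplified to accelerator='auto'. Tokenization is done in a shared helper rather than inside forward so that the forward signature matches the inference snippet below. The model name, learning rate, dataset, and checkpoint filename come from the steps above.

# Minimal end-to-end training sketch (illustrative; some settings are assumptions, see above)
import torch
import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer

class SentimentClassifier(pl.LightningModule):
    def __init__(self, model_name='bert-base-uncased', num_labels=2, learning_rate=2e-5):
        super().__init__()
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, **inputs):
        # Delegate to the underlying Hugging Face model
        return self.model(**inputs)

    def _shared_step(self, batch):
        # Tokenize the raw review text on the fly (max_length is an assumed value)
        encoding = self.tokenizer(batch['text'], padding=True, truncation=True,
                                  max_length=256, return_tensors='pt').to(self.device)
        # BertForSequenceClassification returns the cross-entropy loss when labels are passed
        outputs = self(**encoding, labels=batch['label'])
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        # AdamW with the 2e-5 learning rate from step 7
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

# Steps 8-10: load IMDB, build DataLoaders, train, and save the checkpoint
dataset = load_dataset('imdb')
train_loader = DataLoader(dataset['train'], batch_size=16, shuffle=True)
val_loader = DataLoader(dataset['test'], batch_size=16)

model = SentimentClassifier()
trainer = pl.Trainer(max_epochs=1, accelerator='auto')
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.save_checkpoint('bert_movie_check.ckpt')
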
Validation & Inference

The validation script demonstrates how to load the trained model and perform inference on new sentences:

  1. Model Loading: The SentimentClassifier.load_from_checkpoint function loads the saved checkpoint, restoring the model’s state (a short loading sketch follows this list).
  2. Tokenizer Initialization: The BERT tokenizer is initialized again.
  3. Inference Function (predict_sentiment): This function takes a sentence as input, tokenizes it using the tokenizer, and passes it through the loaded BERT model to obtain sentiment predictions.
  4. Prediction & Output: The code predicts the sentiment of a sample sentence and prints the predicted sentiment along with the probabilities for each class (positive or negative).
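A minimal sketch of the first two steps might look like this, assuming the checkpoint path used in the training section:

# Restore the trained model and re-create the tokenizer (sketch)
from transformers import BertTokenizer

# SentimentClassifier as defined in the training sketch above / Code Snippets below
model = SentimentClassifier.load_from_checkpoint('bert_movie_check.ckpt')
model.eval()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')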

Code Snippets

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import BertForSequenceClassification, BertTokenizer

# Model Definition
class SentimentClassifier(pl.LightningModule):
    def __init__(self, model_name='bert-base-uncased', num_labels=2, learning_rate=2e-5):
        super(SentimentClassifier, self).__init__()
        self.save_hyperparameters()
        self.model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
    # Remaining LightningModule hooks are omitted here; see the training sketch above

# Inference Function
def predict_sentiment(model, tokenizer, sentence):
    model.eval()
    # Tokenize the sentence and run the model without tracking gradients
    inputs = tokenizer(sentence, padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1)
    prediction = torch.argmax(probabilities, dim=1).item()  # 0 = negative, 1 = positive
    return prediction, probabilities
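
A usage example might look like the following; the sample sentence is made up, and it assumes the model and tokenizer loaded in the sketch after the validation steps above:

# Example: predict the sentiment of a new sentence
sentence = "This movie was an absolute delight from start to finish."
label, probabilities = predict_sentiment(model, tokenizer, sentence)
print(f"Sentiment: {'positive' if label == 1 else 'negative'}")
print(f"Probabilities (negative, positive): {probabilities.squeeze().tolist()}")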

Conclusion

This tutorial provides a practical guide to sentiment analysis using BERT and PyTorch Lightning. By leveraging transfer learning, we can achieve high accuracy with relatively little training data. The code provided is readily adaptable for other text classification tasks.
