Fine-Tuning GPT-2 for SMS Spam Classification
The recent rise in spam messages has led to increased interest in leveraging advanced language models to counter the problem. This blog explores how to fine-tune a pre-trained GPT-2 model for SMS spam classification, covering the entire process of setting up, training, and evaluating a model that distinguishes between spam and ham (non-spam) messages. Here, we'll provide insights into the project structure, dataset, and visualized results.
Project Overview
This project fine-tunes GPT-2, a popular language model, to classify SMS messages effectively as spam or ham. The fine-tuning process builds on top of GPT-2βs language understanding, adapting it for classification. By the end, the model will be able to label new SMS messages based on the training it received from a well-known SMS spam dataset.
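Conceptually, the adaptation works by reusing the pre-trained transformer as a feature extractor and attaching a new linear head that maps the final token's hidden state to two logits (ham vs. spam). The sketch below uses a tiny stand-in backbone so it runs without GPT-2 weights; it illustrates the pattern, not the repo's exact nets/ implementation:

```python
import torch
import torch.nn as nn

class SpamClassifier(nn.Module):
    """Wraps a transformer backbone with a 2-class head (ham vs. spam)."""

    def __init__(self, backbone: nn.Module, emb_dim: int, num_classes: int = 2):
        super().__init__()
        self.backbone = backbone  # in practice, the pre-trained GPT-2 stack
        self.head = nn.Linear(emb_dim, num_classes)  # new, randomly initialized

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.backbone(token_ids)   # (batch, seq_len, emb_dim)
        last_token = hidden[:, -1, :]       # GPT-2 is causal: use last position
        return self.head(last_token)        # (batch, num_classes) logits

# Tiny stand-in backbone so the sketch runs without downloading GPT-2
backbone = nn.Sequential(nn.Embedding(100, 16))
model = SpamClassifier(backbone, emb_dim=16)
logits = model(torch.randint(0, 100, (4, 12)))  # 4 messages, 12 tokens each
print(logits.shape)  # torch.Size([4, 2])
```

Because only the head is new, fine-tuning can update either the full network or just the last layers, trading accuracy against training cost.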
Directory Structure
Organizing files and code is essential for a streamlined project experience. Here's a structured breakdown of our files:
llm-finetuning-for-classification
├── README.md                  # Project documentation
├── main.py                    # Script for training, validation, and testing
├── review_classifier.pth      # Model checkpoint
├── requirements.txt           # Dependencies
├── sms_spam_collection/       # Dataset folder
│   └── SMSSpamCollection.tsv  # SMS dataset file
├── classifier/                # Spam prediction utilities
│   ├── predictor.py           # Spam prediction functions
│   └── spam_classifier.py     # Classification helper functions
├── gpt2/                      # GPT-2 model files
├── nets/                      # Model architecture and configurations
├── resources/                 # Results resources for visualization
└── utils/                     # Utility scripts
Key Files
- main.py: Core script that manages training, validation, and testing.
- review_classifier.pth: The saved model checkpoint, ready for reuse or deployment.
- predictor.py: Prediction utilities used to classify new messages.
- resources/: Folder containing accuracy and loss plots, visualizing the training process.
Getting Started
To begin, clone the repository and install the necessary packages.
git clone https://github.com/AbhijitMore/llm-finetuning-for-classification.git
cd llm-finetuning-for-classification
pip install -r requirements.txt
The Dataset
This project uses the SMS Spam Collection dataset from the UCI Machine Learning Repository. It contains labeled SMS messages, divided into two columns:
- Label: Specifies if the message is spam or ham.
- Text: The actual SMS message content.
Once downloaded, place the dataset at sms_spam_collection/SMSSpamCollection.tsv.
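Because the file is a simple tab-separated list of label-text pairs (label first, as in the UCI distribution), it can be loaded with the standard library alone. A minimal sketch, demonstrated here on a tiny synthetic file in place of the real dataset:

```python
import os
import tempfile

def load_sms_dataset(path):
    """Read the tab-separated SMS file into (label, text) pairs.

    Each line holds the label, a tab, then the message.
    Labels are mapped to integers: ham -> 0, spam -> 1.
    """
    label_map = {"ham": 0, "spam": 1}
    rows = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            label, text = line.rstrip("\n").split("\t", 1)
            rows.append((label_map[label], text))
    return rows

# Demonstrate with a two-line synthetic file standing in for the dataset
sample = "ham\tSee you at lunch?\nspam\tWIN a FREE prize now!\n"
with tempfile.NamedTemporaryFile("w", suffix=".tsv", delete=False) as tmp:
    tmp.write(sample)
data = load_sms_dataset(tmp.name)
os.unlink(tmp.name)
print(data)  # [(0, 'See you at lunch?'), (1, 'WIN a FREE prize now!')]
```

Splitting on the first tab only keeps messages intact even if they happen to contain further tab characters.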
Model Training & Evaluation
Training the Model
Run the main.py script to train the model on the SMS dataset:
python main.py
This script handles data loading, preprocessing, and fine-tuning of GPT-2 for classification. Loss and accuracy are computed, and graphs are generated throughout training.
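At its core this is a standard supervised fine-tuning loop: forward pass to two-class logits, cross-entropy loss against the spam/ham labels, backward pass, optimizer step. A minimal sketch with a placeholder linear model and synthetic data (in main.py the same loop runs over tokenized SMS batches through GPT-2):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model and synthetic batch; in the real script these would be
# the fine-tuned GPT-2 and tokenized SMS messages.
model = nn.Linear(32, 2)                    # features -> {ham, spam} logits
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fn = nn.CrossEntropyLoss()

features = torch.randn(8, 32)               # batch of 8 examples
labels = torch.randint(0, 2, (8,))          # 0 = ham, 1 = spam

for epoch in range(3):
    optimizer.zero_grad()
    logits = model(features)
    loss = loss_fn(logits, labels)          # cross-entropy on raw logits
    loss.backward()
    optimizer.step()

# Accuracy is the fraction of argmax predictions matching the labels
preds = model(features).argmax(dim=-1)
accuracy = (preds == labels).float().mean().item()
```

Loss and accuracy computed this way per epoch are exactly the values that feed the plots discussed below.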
Testing & Results
After training, the model is automatically evaluated on a held-out test set, with results printed to the console. You'll also find saved plots of accuracy and loss in the resources/ folder:
- Accuracy Plot (accuracy-plot.png): Visual representation of model accuracy over epochs.
- Loss Plot (loss-plot.png): Indicates model loss progression, helping assess convergence.
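Plots like these take only a few lines of matplotlib; the sketch below uses placeholder loss values rather than real training history:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; works without a display
import matplotlib.pyplot as plt

# Placeholder metric history; in practice these come from the training loop.
train_loss = [0.69, 0.41, 0.22, 0.13, 0.09]
val_loss = [0.68, 0.45, 0.30, 0.24, 0.22]

epochs = range(1, len(train_loss) + 1)
plt.figure()
plt.plot(epochs, train_loss, label="train")
plt.plot(epochs, val_loss, label="validation")
plt.xlabel("Epoch")
plt.ylabel("Cross-entropy loss")
plt.legend()
plt.savefig("loss-plot.png")  # written alongside the script

import os
saved = os.path.getsize("loss-plot.png") > 0
```

A widening gap between the two curves is the usual visual cue for overfitting, which is what these plots help catch early.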
Console Output
During training, the console provides real-time updates on metrics like loss and accuracy, giving insights into the model's performance as it improves over epochs.
Results
Here's a look at the final output:
- Console Output: Snapshots of training progression are saved in resources/console_output.png.
- Performance Graphs: The loss and accuracy plots summarize model stability and improvements throughout training.
Contributions
Contributions are welcome! Whether suggesting a feature, reporting a bug, or submitting a pull request, all efforts are valued. Your input will help improve the SMS Spam Classification project and further empower GPT-2's application to real-world classification tasks.