SmolLM: Implementing, Fine-Tuning, and Aligning an LLM for Grammatical Error Correction
Introduction
In the rapidly evolving field of Natural Language Processing (NLP), language models have become indispensable tools for a variety of applications, from text generation to grammatical error correction. This blog post delves into the journey of implementing a custom version of the SmolLM-135M model, fine-tuning it for grammatical error correction using the Grammarly CoEdIT dataset, and aligning the fine-tuned model through Reinforcement Learning from AI Feedback (RLAIF) using Direct Preference Optimization (DPO).
Overview
The project encompasses three main objectives:
Custom Implementation of SmolLM-135M: Building a custom version of the SmolLM-135M model to gain deeper insights into its architecture and functionalities.
Fine-Tuning for Grammatical Error Correction: Adapting the pre-trained model to correct grammatical errors by training it on a specialized dataset.
Alignment through DPO: Aligning the model's outputs with preferred corrections using Direct Preference Optimization (DPO), a preference-based alignment technique closely related to reinforcement learning from human feedback.
1. Custom Implementation of SmolLM-135M
1.1 Understanding SmolLM-135M
SmolLM-135M is a compact, 135-million-parameter language model from Hugging Face, available on the Hugging Face Hub. Its manageable size and solid capabilities for its scale make it an excellent foundation for experimentation.
1.2 Model Architecture
The custom implementation involves replicating the essential components of SmolLM-135M:
Embedding Layer: Converts input tokens into continuous embeddings.
Transformer Blocks: Stacks of decoder layers, each comprising:
Rotary Positional Embeddings: Incorporate positional information without adding explicit absolute position embeddings to the inputs.
Self-Attention Mechanism: Allows the model to focus on different parts of the input sequence.
Multi-Layer Perceptron (MLP): Processes the outputs from the attention mechanism.
Normalization Layers: Ensure stable training through techniques like RMSNorm.
Output Layer: Maps the processed embeddings back to the token space for prediction.
1.3 Key Components Explained
Rotary Embeddings
Rotary embeddings integrate positional information by rotating the query and key vectors in the self-attention mechanism. This approach avoids the need for absolute positional embeddings and improves extrapolation to longer sequences.
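To make this concrete, here is a minimal sketch of how rotary embeddings can be applied to the query and key vectors. The helper names (build_rope_cache, rotate_half, apply_rotary_pos_emb) are illustrative and follow the common Llama-style formulation rather than the exact code of this project.

```python
import torch

def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    # Precompute cos/sin tables for every position and rotation frequency.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)     # (seq_len, head_dim / 2)
    angles = torch.cat((angles, angles), dim=-1)  # (seq_len, head_dim)
    return angles.cos(), angles.sin()

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two and swap the halves with a sign flip.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Rotate query and key vectors by position-dependent angles.
    # cos/sin have shape (seq_len, head_dim) and broadcast over batch and heads.
    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
```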
RMSNorm
Root Mean Square Layer Normalization (RMSNorm) normalizes the input by its root mean square, skipping the mean-centering and bias of LayerNorm. This makes it a simpler and computationally cheaper alternative that works well in deep networks.
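A minimal RMSNorm module mirroring the description above; the epsilon value and the all-ones weight initialization are typical defaults rather than values confirmed from the original code.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize each feature vector by its root mean square, then rescale.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * rms)
```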
RopeAttention
This custom attention mechanism applies the rotary rotation to the query and key vectors before computing attention scores, so relative positional information is encoded directly in the attention dot products.
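Here is a sketch of how the rotation can be folded into the attention computation, reusing the rotary helpers above. It shows plain multi-head attention with a causal mask; details of the actual SmolLM attention configuration (e.g., grouped-query attention and key-value caching) are omitted.

```python
import torch
import torch.nn.functional as F

def rope_attention(q, k, v, cos, sin, causal_mask):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    # causal_mask: boolean tensor, True where attention should be blocked.
    q, k = apply_rotary_pos_emb(q, k, cos, sin)              # inject positions by rotation
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # scaled dot-product scores
    scores = scores.masked_fill(causal_mask, float("-inf"))  # no attention to future tokens
    return F.softmax(scores, dim=-1) @ v
```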
1.4 Loading Pre-Trained Weights
After defining the custom architecture, we load the pre-trained weights from the official SmolLM-135M model. This step initializes the model with learned parameters, ensuring that our custom implementation aligns with the original model's performance.
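A sketch of the weight-loading step, assuming the custom module's parameter names mirror those of the HuggingFaceTB/SmolLM-135M checkpoint; if the names differ, a mapping between state-dict keys is needed first. Here custom_model stands for the re-implementation defined above.

```python
from transformers import AutoModelForCausalLM

# Load the official checkpoint to obtain its trained parameters.
reference = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

# strict=False reports any parameters that could not be matched instead of raising.
missing, unexpected = custom_model.load_state_dict(reference.state_dict(), strict=False)
print("missing keys:", missing)
print("unexpected keys:", unexpected)
```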
2. Testing the Custom Model
Before proceeding to fine-tuning, it's crucial to validate the custom model's correctness.
2.1 Verification Process
Token-by-Token Generation: We generate text outputs for given prompts, comparing them between the custom model and the reference model (a minimal comparison sketch follows this list).
Prompt Examples: Testing with diverse prompts like questions and statements to ensure consistency.
Output Analysis: Checking for alignment in generated tokens and overall coherence.
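The following sketch illustrates the comparison, assuming the custom model maps input IDs to next-token logits; the prompts are illustrative examples rather than the exact ones used in the project, and custom_model refers to the re-implementation with the pre-trained weights loaded.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
reference = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

@torch.no_grad()
def greedy_generate(model, prompt: str, max_new_tokens: int = 30) -> str:
    # Assumes the model maps input_ids to logits of shape (batch, seq_len, vocab_size).
    ids = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(max_new_tokens):
        out = model(ids)
        logits = out.logits if hasattr(out, "logits") else out
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=-1)
    return tokenizer.decode(ids[0], skip_special_tokens=True)

for prompt in ["What is the capital of France?", "Gravity is"]:
    assert greedy_generate(custom_model, prompt) == greedy_generate(reference, prompt)
```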
2.2 Results
The custom model's outputs closely match those of the reference model, confirming that the implementation is correct and ready for fine-tuning.
3. Fine-Tuning for Grammatical Error Correction
3.1 Dataset Preparation
We utilize the Grammarly CoEdIT dataset, specifically focusing on grammatical error correction (GEC) tasks.
Filtering the Dataset
Task Selection: Extracting samples labeled as 'gec' to target grammatical corrections.
Data Formatting: Structuring the input-output pairs to fit the model's expected format (see the sketch after this list).
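A sketch of the filtering and formatting steps. The field names ("task", "src", "tgt") follow the grammarly/coedit dataset card, and the prompt template with a "Corrected:" suffix is one possible formatting choice rather than the exact one used in the original post.

```python
from datasets import load_dataset

# Load CoEdIT and keep only the grammatical error correction samples.
coedit = load_dataset("grammarly/coedit")
gec_train = coedit["train"].filter(lambda ex: ex["task"] == "gec")
gec_val = coedit["validation"].filter(lambda ex: ex["task"] == "gec")

def format_example(ex):
    # Turn each pair into a single prompt/completion string for causal LM fine-tuning.
    return {"text": f"{ex['src']}\nCorrected: {ex['tgt']}"}

gec_train = gec_train.map(format_example)
gec_val = gec_val.map(format_example)
```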
3.2 Tokenization and Data Collation
Tokenizer Initialization: Using the tokenizer associated with SmolLM-135M and setting an appropriate padding token.
Custom Data Collator: Ensuring that the data batches are correctly formatted, handling padding and special tokens (a sketch using a standard library collator follows).
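A sketch of the tokenization setup. The original post uses a custom data collator; here the standard DataCollatorForLanguageModeling with mlm=False stands in for it, which pads each batch dynamically and copies the padded input IDs to the labels while masking padding positions. The 256-token limit is an illustrative choice.

```python
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M")
# If no dedicated padding token is defined, reuse the EOS token for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize(example):
    # Truncate overly long samples to keep batches a manageable size.
    return tokenizer(example["text"], truncation=True, max_length=256)

# Causal-LM collator: pads each batch and derives labels from input_ids.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
```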
3.3 Training Configuration
Hyperparameters:
Learning rate: 3e-5
Batch size: 16 (with gradient accumulation)
Epochs: 1 (can be increased for better performance)
Training Loop: Utilizing the SFTTrainer for supervised fine-tuning, as sketched below.
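A training sketch using TRL's SFTTrainer with the hyperparameters listed above. Splitting the effective batch size of 16 into a per-device batch size of 4 with 4 gradient-accumulation steps is one possible realization, and the exact argument names (tokenizer vs. processing_class, dataset_text_field on the config vs. the trainer) vary across TRL versions.

```python
from trl import SFTConfig, SFTTrainer

config = SFTConfig(
    output_dir="smollm-135m-gec",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # 4 x 4 = effective batch size of 16
    num_train_epochs=1,
    logging_steps=50,
    dataset_text_field="text",       # the formatted column created during dataset preparation
)

trainer = SFTTrainer(
    model=custom_model,              # the custom model with pre-trained weights loaded
    args=config,
    train_dataset=gec_train,
    eval_dataset=gec_val,
    data_collator=collator,
    tokenizer=tokenizer,             # newer TRL releases name this argument processing_class
)
trainer.train()
```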
3.4 Fine-Tuning Process
Model Training: The model learns to correct grammatical errors by minimizing the loss between its predictions and the ground truth.
Validation: Monitoring the model's performance on a separate validation set to prevent overfitting.
3.5 Inference and Testing
Inference Function: Creating a function that takes a raw sentence and returns the corrected output (see the sketch below).
Sample Test: Inputting sentences with grammatical errors and observing the corrections made by the model.
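A minimal inference sketch. The instruction template mirrors the CoEdIT-style prompts but is an illustrative assumption, as is the test sentence; model refers to the fine-tuned model from the step above.

```python
import torch

model = trainer.model  # the fine-tuned model from the SFT step

@torch.no_grad()
def correct(sentence: str, max_new_tokens: int = 64) -> str:
    # Wrap the raw sentence in a CoEdIT-style instruction (template is illustrative).
    prompt = f"Fix grammatical errors in this sentence: {sentence}\nCorrected:"
    inputs = tokenizer(prompt, return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False,
                            pad_token_id=tokenizer.pad_token_id)
    # Keep only the tokens generated after the prompt.
    generated = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()

print(correct("She go to school every days."))
```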
3.6 Evaluation
Metric Used: BLEU score, a common metric for evaluating text generation tasks (a scoring sketch follows this list).
Results: After fine-tuning, the model achieves a BLEU score of approximately 0.48 on the validation set, indicating effective grammatical corrections.
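One way to compute the score, using the Hugging Face evaluate library as a stand-in for whatever BLEU implementation the original post used. It reuses the correct function and the gec_val split from earlier sketches and assumes ex["src"] holds the sentence to correct (strip any instruction prefix first if one is already present).

```python
import evaluate

bleu = evaluate.load("bleu")

# Generate corrections for the validation set and compare against the references.
predictions = [correct(ex["src"]) for ex in gec_val]
references = [[ex["tgt"]] for ex in gec_val]

score = bleu.compute(predictions=predictions, references=references)
print(f"Validation BLEU: {score['bleu']:.2f}")
```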
4. Creating a Preference Optimization Dataset
To further refine the model's outputs based on human preferences, we create a dataset that reflects preferred corrections.
4.1 Generating Variants
Output Variants: For each input sentence, we generate two corrected versions using different decoding strategies (greedy decoding and sampling); a generation sketch follows this list.
Diversity: This approach introduces variations in the corrections, some closer to the ground truth than others.
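A sketch of the variant generation, reusing the fine-tuned model and tokenizer from the previous sections; the temperature and top_p values are illustrative assumptions.

```python
import torch

@torch.no_grad()
def generate_variants(prompt: str, max_new_tokens: int = 64):
    """Return two candidate corrections: one greedy, one sampled."""
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    def decode(output):
        # Keep only the tokens generated after the prompt.
        return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()

    # Variant 1: deterministic greedy decoding.
    greedy = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Variant 2: stochastic sampling for diversity.
    sampled = model.generate(**inputs, max_new_tokens=max_new_tokens,
                             do_sample=True, temperature=0.9, top_p=0.95)
    return decode(greedy), decode(sampled)
```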
4.2 Annotating Preferences
Edit Distance Calculation: Measuring the similarity between each variant and the ground truth correction.
Preference Assignment: Choosing the variant closer to the ground truth as the preferred output.
Dataset Structure: Compiling the input sentence, the preferred correction, and the less preferred correction into a new dataset (see the sketch below).
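A sketch of the annotation step. The original post uses edit distance; here difflib's similarity ratio serves as a simple stand-in (a Levenshtein library could be substituted), and the choice of split to annotate is an assumption.

```python
from difflib import SequenceMatcher

def distance(a: str, b: str) -> float:
    # Similarity-based stand-in for edit distance: lower means closer to the reference.
    return 1.0 - SequenceMatcher(None, a, b).ratio()

preference_rows = []
for ex in gec_train:  # or whichever split is reserved for preference data
    variant_a, variant_b = generate_variants(ex["src"])
    # The variant closer to the ground-truth correction becomes the "chosen" answer.
    if distance(variant_a, ex["tgt"]) <= distance(variant_b, ex["tgt"]):
        chosen, rejected = variant_a, variant_b
    else:
        chosen, rejected = variant_b, variant_a
    preference_rows.append({"prompt": ex["src"], "chosen": chosen, "rejected": rejected})
```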
4.3 Saving the Dataset
The preference optimization dataset is saved for reuse in the alignment phase, ensuring reproducibility and ease of access.
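For example, the rows can be converted into a datasets object and written to disk (the path name below is illustrative):

```python
from datasets import Dataset

# Persist the preference pairs so the DPO step can reload them without regeneration.
preference_dataset = Dataset.from_list(preference_rows)
preference_dataset.save_to_disk("coedit_gec_preferences")
```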
5. Aligning the Model through Direct Preference Optimization
5.1 Understanding DPO
Direct Preference Optimization (DPO) aligns a model's outputs with annotated preferences by training directly on pairs of preferred and rejected responses, without fitting an explicit reward model or running a full reinforcement learning loop.
5.2 Training Process
Reference Model: Using the fine-tuned model as a baseline for comparison.
DPO Trainer Configuration:
Learning rate: 5e-6
Batch size: 16
Epochs: 1
Optimization: The model learns to prefer outputs that are closer to the preferred corrections defined in the preference dataset; a training sketch follows.
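A training sketch built on TRL's DPOTrainer with the hyperparameters above. The split of the batch size into per-device size and gradient accumulation, the beta value, and the argument names (which shift between TRL versions) are assumptions rather than values confirmed by the original post.

```python
from trl import DPOConfig, DPOTrainer

dpo_config = DPOConfig(
    output_dir="smollm-135m-gec-dpo",
    learning_rate=5e-6,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch size of 16
    num_train_epochs=1,
    beta=0.1,                        # strength of the implicit KL penalty toward the reference
)

dpo_trainer = DPOTrainer(
    model=model,                 # the SFT fine-tuned model to align
    ref_model=None,              # with None, TRL clones the model as the frozen reference
    args=dpo_config,
    train_dataset=preference_dataset,
    tokenizer=tokenizer,         # newer TRL releases name this argument processing_class
)
dpo_trainer.train()
```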
5.3 Evaluation
Post-DPO BLEU Score: The model achieves an improved BLEU score of approximately 0.50.
Performance Comparison: Shows a modest but measurable improvement over the fine-tuned model without DPO.
Conclusion
This project showcases the end-to-end process of implementing, fine-tuning, and aligning a language model for a specific NLP task.
Custom Implementation: Building the model from scratch provides a deeper understanding of its inner workings.
Fine-Tuning: Adapting the model to a specific task improves its performance in that domain.
Alignment with Human Preferences: Utilizing techniques like DPO ensures that the model's outputs are not only correct but also align with human expectations.
By following these steps, we enhance the SmolLM-135M model's ability to correct grammatical errors effectively, making it a valuable tool for applications that require high-quality text generation and correction.
Future Work
Extended Training: Increasing the number of training epochs and experimenting with different hyperparameters could further improve performance.
Dataset Expansion: Incorporating additional datasets or augmenting existing ones may enhance the model's generalization capabilities.
Real-world Deployment: Integrating the model into applications for live grammatical error correction, monitoring its performance, and collecting user feedback for continuous improvement.