The Core Objective
Transformers like BERT are mathematically brilliant, but their heavy reliance on dense matrix multiplications makes them slow, memory-intensive, and expensive to deploy in the real world. DistilBERT is already a compressed version of BERT, but can we push it further? The goal of this research is to apply a three-stage compression pipeline (pruning, knowledge distillation, and INT8 quantization) to distilbert-base-uncased for news topic classification on the AG News dataset, and to see how small and fast we can make the model before its accuracy starts to break down.
Step 1: Baseline Training
We begin by establishing a control. We fine-tune the base DistilBERT model on the AG News dataset (categorizing text into World, Sports, Business, or Sci/Tech). This fully trained baseline achieves an accuracy of 88.30%. Crucially, this robust model will act as the "Teacher" in our distillation phase later on.
# Standard fine-tuning of the baseline model
import torch
from transformers import DistilBertForSequenceClassification, TrainingArguments

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_baseline = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=4
).to(device)

# Trained with fp16 mixed precision for GPU acceleration
training_args = TrainingArguments(
    output_dir="./results_baseline",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    fp16=True,
)
Step 2: Dual-Stage Pruning
Pruning physically removes the least critical mathematical operations from the network. In a Transformer architecture, not all attention heads or feed-forward weights contribute equally to the final prediction. We apply two targeted pruning methods:
1. Attention Head Pruning
We pass 20 mini-batches of data through the model and compute an importance score for each attention head by averaging its absolute attention weights over those batches. Once the heads are ranked, we physically delete the slices of the Query, Key, Value, and Output projection matrices belonging to the bottom 60% of heads (a sketch of the importance computation follows the snippet below).
2. Magnitude Pruning
Next, we target the Feed-Forward Network (FFN) layers and the classification head. We use L1 unstructured pruning to zero out the smallest-magnitude weights across these layers. This leaves a highly sparse weight configuration, achieving a combined 31.1% sparsity across roughly 58 million parameters.
# Computing head importance via absolute mean (full sketch below)
def compute_head_importance(model, dataloader, n_batches=20):
    # ... calculates average absolute attention weights per head ...
    return head_importance

# Prune FFN layers using PyTorch's built-in L1 unstructured pruning
import torch.nn.utils.prune as prune

# `module` iterates over the targeted FFN / classifier Linear layers (loop omitted)
prune.l1_unstructured(module, name="weight", amount=0.6)  # zero the smallest 60% of this module's weights
prune.remove(module, "weight")  # make the pruning permanent by baking the zeros into the weight tensor
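The elided averaging step could look roughly like the following. This is an illustrative sketch rather than the exact implementation: it assumes batches expose input_ids and attention_mask, and it relies on DistilBERT's config fields n_layers and n_heads.

import torch

def compute_head_importance(model, dataloader, n_batches=20, device="cuda"):
    # Accumulate the mean absolute attention weight for each (layer, head) pair
    n_layers, n_heads = model.config.n_layers, model.config.n_heads
    importance = torch.zeros(n_layers, n_heads, device=device)
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= n_batches:
                break
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                output_attentions=True,
            )
            # outputs.attentions: one (batch, heads, seq, seq) tensor per layer
            for layer, attn in enumerate(outputs.attentions):
                importance[layer] += attn.abs().mean(dim=(0, 2, 3))
    return importance / n_batches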
Step 3: Knowledge Distillation (KD) Recovery
When you aggressively rip out 60% of a model's attention heads, the accuracy craters. Knowledge Distillation (KD) is the rescue mission.
The Mathematics of KD
Instead of retraining our newly pruned "Student" model purely on hard one-hot labels (e.g., [0, 1, 0, 0] for Sports), we train it to mimic the full probability distributions produced by our highly accurate "Teacher" baseline model (e.g., [0.05, 0.85, 0.08, 0.02]).
We use a temperature scaling factor (T=4.0). Dividing the logits by 4 before applying the Softmax "softens" the probability distribution, exposing the hidden relationships between classes (e.g., showing the student that a 'Sci/Tech' article shares more vocabulary with 'Business' than it does with 'Sports'). The loss mixes Kullback-Leibler (KL) Divergence against the teacher's softened outputs (weighted at 70%) with standard Cross-Entropy against the ground-truth labels (weighted at 30%).
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temp=4.0, alpha=0.7):
    # KL Divergence aligns the student with the teacher's temperature-softened outputs
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / temp, dim=1),
        F.softmax(teacher_logits / temp, dim=1),
        reduction="batchmean",
    ) * (temp ** 2)  # scale by T^2 to keep gradient magnitudes comparable
    # Combined with standard Cross Entropy on the ground-truth labels
    ce_loss = F.cross_entropy(student_logits, labels)
    return alpha * kd_loss + (1 - alpha) * ce_loss
Note: During this phase, it is critical to re-apply the zero-masks to the weights after every optimizer step; otherwise the model will simply "un-prune" itself to lower the loss. A minimal sketch of one distillation step, including the mask re-application, is shown below.
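This sketch assumes the pruning masks were saved at pruning time as a dict (here called pruning_masks) mapping parameter names to 0/1 tensors; that name is illustrative, not from the original code.

import torch

def distillation_step(student, teacher, batch, optimizer, pruning_masks):
    teacher.eval()
    with torch.no_grad():
        # Teacher provides soft targets; no gradients needed
        teacher_logits = teacher(input_ids=batch["input_ids"],
                                 attention_mask=batch["attention_mask"]).logits

    student_logits = student(input_ids=batch["input_ids"],
                             attention_mask=batch["attention_mask"]).logits
    loss = distillation_loss(student_logits, teacher_logits, batch["labels"])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Re-apply the zero-masks so pruned weights stay exactly zero
    with torch.no_grad():
        for name, param in student.named_parameters():
            if name in pruning_masks:
                param.mul_(pruning_masks[name].to(param.device))
    return loss.item()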
Step 4: Dynamic INT8 Quantization
The final blow to the model's footprint is Quantization. Neural networks natively store their weights as 32-bit floating-point numbers (FP32). We dynamically cast all Linear layer weights down to 8-bit integers (INT8).
Because an INT8 value takes one-quarter of the memory of an FP32 value, the model size on disk shrinks dramatically. We use Dynamic Quantization, meaning the activation scales are calculated on the fly at runtime, avoiding the need for a separate calibration dataset. PyTorch's dynamic quantization targets CPU execution directly, improving memory-bandwidth efficiency and yielding much faster CPU inference.
# Applying dynamic quantization to target CPU execution
import torch.ao.quantization

model_quantized = torch.ao.quantization.quantize_dynamic(
    model_pruned.cpu(),
    {torch.nn.Linear},
    dtype=torch.qint8
)
The Final Architecture Benchmarks
For a fair comparison, the latency benchmarks for all three model variants were measured locally on the same CPU. Notice how the final quantized model is less than half the size of the baseline yet loses only about 3.5 percentage points of accuracy.
| Model Variant | Accuracy | File Size | Sparsity | Latency (CPU) |
|---|---|---|---|---|
| Baseline (DistilBERT) | 88.30% | 255.46 MB | 0.0% | 200.34 ms |
| + Pruned + KD | 87.20% | 223.93 MB | 31.1% | 107.16 ms |
| + Quantized (INT8) | 84.80% | 124.38 MB | N/A | 74.03 ms |
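For context, the size and latency figures above can be reproduced with measurements along these lines. This is a hedged sketch, not the exact benchmarking script used here; the helper names and the sample batch are illustrative.

import os
import time
import torch

def model_size_mb(model, path="tmp_model.pt"):
    # Serialize the state dict and read its size from disk
    torch.save(model.state_dict(), path)
    size_mb = os.path.getsize(path) / (1024 * 1024)
    os.remove(path)
    return size_mb

def cpu_latency_ms(model, sample_batch, n_runs=50):
    # Average wall-clock time of a CPU forward pass
    model.eval()
    with torch.no_grad():
        start = time.perf_counter()
        for _ in range(n_runs):
            model(**sample_batch)
    return (time.perf_counter() - start) / n_runs * 1000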
Limitations & Reality Check
While the results are excellent, developing this pipeline inside a constrained environment like Google Colab exposes some hardware realities. Computing head importance across every layer is extremely tensor-heavy; you have to manage the GPU cache explicitly to prevent Out-of-Memory (OOM) failures. Furthermore, PyTorch's dynamic quantization natively targets CPU execution, making direct GPU latency comparisons for the final INT8 model impossible without specialized edge hardware or TensorRT-style implementations.