Shrinking DistilBERT: A Deep Dive

Research Log // NLP Architecture & Optimization

88.3% base accuracy · 2.1× compression · 31.1% sparsity · 63% latency drop

The Core Objective

Transformers like BERT are mathematically brilliant, but their heavy reliance on dense matrix multiplications makes them slow, memory-intensive, and expensive to deploy in the real world. DistilBERT is already a distilled, compressed version of BERT, but can we push it further? The goal of this research is to apply a rigorous three-stage compression pipeline to distilbert-base-uncased for news topic classification (on the AG News dataset) and see how small and fast we can make the model before its accuracy collapses.

Step 1: Baseline Training

We begin by establishing a control. We fine-tune the base DistilBERT model on the AG News dataset (categorizing text into World, Sports, Business, or Sci/Tech). This fully trained baseline achieves an accuracy of 88.30%. Crucially, this robust model will act as the "Teacher" in our distillation phase later on.

# Standard fine-tuning of the baseline model
import torch
from transformers import DistilBertForSequenceClassification, TrainingArguments

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_baseline = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=4
).to(device)

# Trained with fp16 mixed precision for GPU acceleration
training_args = TrainingArguments(
    output_dir="./results_baseline",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    fp16=True,
)

Step 2: Dual-Stage Pruning

Pruning physically removes the least critical mathematical operations from the network. In a Transformer architecture, not all attention heads or feed-forward weights contribute equally to the final prediction. We apply two targeted pruning methods:

1. Attention Head Pruning

We pass 20 mini-batches of data through the model and compute an "importance score" for every single attention head by averaging their absolute attention weights. Once ranked, we physically delete the Query, Key, Value, and Output projection matrices for the bottom 60% of these heads.

2. Magnitude Pruning

Next, we target the feed-forward network (FFN) and the classification head. We use L1 unstructured pruning to zero out the smallest-magnitude weights across these layers. The result is a highly sparse weight configuration: a combined 31.1% sparsity across roughly 58 million parameters.
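The 31.1% figure is a global count of exactly-zero weights. It can be measured with a small helper like the following (the function name is mine, not from the original code):

```python
import torch
import torch.nn as nn

def global_sparsity(model: nn.Module) -> float:
    """Fraction of all parameters that are exactly zero."""
    zeros, total = 0, 0
    for p in model.parameters():
        zeros += (p == 0).sum().item()
        total += p.numel()
    return zeros / total
```

Run on the pruned model, this should report roughly 0.311 if the pruning masks were applied as described.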

1. Score: calculate the absolute mean attention weights of all heads.
2. Sever: delete the Q, K, V, and output projection matrices for the bottom 60% of heads.
3. Zero out: apply L1 unstructured pruning to the FFN, reaching 31.1% sparsity.

# Computing head importance via absolute mean attention weights
import torch
from torch.nn.utils import prune

def compute_head_importance(model, dataloader, n_batches=20):
    # ... forward n_batches with output_attentions=True, then average the
    # absolute attention weights of every head across batches and positions ...
    return head_importance

# Prune each FFN Linear `module` using PyTorch's built-in unstructured pruning
prune.l1_unstructured(module, name="weight", amount=0.6)
prune.remove(module, "weight")  # bake the zero mask into the weight tensor
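To make the elided steps concrete, here is one way to fill in the importance computation and select the globally least important 60% of heads. This is a sketch: `output_attentions=True` and `prune_heads` are standard Hugging Face Transformers APIs, but the helper names and the exact aggregation are my own choices.

```python
import torch

def compute_head_importance(model, dataloader, n_batches=20):
    """Average absolute attention weight per head over n_batches mini-batches."""
    importance = None
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= n_batches:
                break
            attentions = model(**batch, output_attentions=True).attentions
            # One (num_layers, num_heads) score matrix per batch
            scores = torch.stack([a.abs().mean(dim=(0, 2, 3)) for a in attentions])
            importance = scores if importance is None else importance + scores
    return importance / n_batches

def select_heads_to_prune(importance, fraction=0.6):
    """Map layer index -> list of head indices in the global bottom fraction."""
    n_layers, n_heads = importance.shape
    k = int(fraction * n_layers * n_heads)
    to_prune = {}
    for idx in importance.flatten().argsort()[:k].tolist():
        to_prune.setdefault(idx // n_heads, []).append(idx % n_heads)
    return to_prune

# model.prune_heads(select_heads_to_prune(importance))  # removes Q/K/V/out slices
```

Ranking heads globally (rather than per layer) means some layers can lose more heads than others, which matches the "bottom 60% overall" description above.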

Step 3: Knowledge Distillation (KD) Recovery

When you aggressively rip out 60% of a model's attention heads, the accuracy craters. Knowledge Distillation (KD) is the rescue mission.

The Mathematics of KD

Instead of retraining our newly pruned "Student" model purely on hard one-hot labels (e.g., [0, 1, 0, 0] for Sports), we train it to mimic the soft probability distributions of our highly accurate "Teacher" baseline model (e.g., [0.05, 0.85, 0.08, 0.02]).

We use a Temperature scaling factor (T=4.0). By dividing the logits by 4 before applying the Softmax function, we "flatten" the probability curve. This exposes the hidden relationships between classes (e.g., showing the student that a 'Sci/Tech' article shares more vocabulary with 'Business' than it does with 'Sports'). We calculate the loss using Kullback-Leibler (KL) Divergence (weighted at 70%) mixed with standard Cross-Entropy (weighted at 30%).
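The effect of T=4 is easy to see on a toy logit vector (the values here are illustrative, not taken from the trained model):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 6.0, 3.0, 1.0])
sharp = F.softmax(logits, dim=0)       # near one-hot: almost all mass on class 1
soft = F.softmax(logits / 4.0, dim=0)  # T=4: inter-class structure survives
```

The softened distribution keeps the same argmax but has a lower peak and higher entropy, which is exactly the "dark knowledge" the KL term lets the student learn from.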

1. Teacher model: the frozen baseline outputs dense, accurate logits.
2. Temperature (T=4): logits are divided by 4 to soften the softmax distribution.
3. KL divergence: the student updates its weights to mimic the teacher's curve.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temp=4.0, alpha=0.7):
    # KL divergence aligns the student with the teacher's softened outputs;
    # the temp**2 factor keeps gradient magnitudes stable across temperatures
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / temp, dim=1),
        F.softmax(teacher_logits / temp, dim=1),
        reduction="batchmean"
    ) * (temp ** 2)

    # Combined with standard cross-entropy on the ground-truth labels
    ce_loss = F.cross_entropy(student_logits, labels)
    return alpha * kd_loss + (1 - alpha) * ce_loss

Note: During this phase, it is critical to re-apply the zero masks to the pruned weights after every optimizer step; otherwise, gradient updates will "un-prune" the model to lower the loss.
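One minimal way to enforce this (the helper names are mine; this assumes the pruned weights are exactly zero at the moment the masks are captured):

```python
import torch
import torch.nn as nn

def capture_masks(model: nn.Module):
    """Snapshot which weight entries are currently zero (the pruned positions)."""
    return {n: (p != 0).float() for n, p in model.named_parameters() if "weight" in n}

def reapply_masks(model: nn.Module, masks):
    """Zero the pruned positions again; call right after every optimizer.step()."""
    with torch.no_grad():
        for n, p in model.named_parameters():
            if n in masks:
                p.mul_(masks[n])
```

Capturing the masks once before distillation starts, then calling `reapply_masks` each step, keeps the 31.1% sparsity intact throughout recovery training.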

Step 4: Dynamic INT8 Quantization

The final blow to the model's footprint is Quantization. Neural networks natively store their weights as 32-bit floating-point numbers (FP32). We dynamically cast all Linear layer weights down to 8-bit integers (INT8).

Because an INT8 value takes one-quarter of the memory of an FP32 value, the model size on disk shrinks drastically. We use Dynamic Quantization, meaning the activation scales are calculated on the fly at runtime, avoiding the need for a separate calibration dataset. Dynamic quantization targets CPU execution directly, improving memory-bandwidth efficiency and yielding much faster CPU inference.

# Applying dynamic quantization to target CPU execution
import torch.ao.quantization

model_quantized = torch.ao.quantization.quantize_dynamic(
    model_pruned.cpu(), 
    {torch.nn.Linear}, 
    dtype=torch.qint8
)
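The resulting on-disk sizes can be checked by serializing each variant's state dict (a sketch; the helper name is illustrative and measurement includes a small amount of serialization overhead):

```python
import os
import tempfile
import torch
import torch.nn as nn

def model_size_mb(model: nn.Module) -> float:
    """On-disk size of the serialized state_dict, in megabytes."""
    fd, path = tempfile.mkstemp(suffix=".pt")
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / 1e6
    finally:
        os.remove(path)
```

Applied to the baseline, pruned, and quantized models, this is how the file-size column in the benchmark table can be reproduced.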

The Final Architecture Benchmarks

For a fair comparison, the latency benchmarks for all three model variants were measured locally on the exact same CPU. The final quantized model is less than half the size of the baseline, yet gives up only about 3.5 points of accuracy.

| Model Variant         | Accuracy | File Size | Sparsity | Latency (CPU) |
|-----------------------|----------|-----------|----------|---------------|
| Baseline (DistilBERT) | 88.30%   | 255.46 MB | 0.0%     | 200.34 ms     |
| + Pruned + KD         | 87.20%   | 223.93 MB | 31.1%    | 107.16 ms     |
| + Quantized (INT8)    | 84.80%   | 124.38 MB | N/A      | 74.03 ms      |
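The per-variant latency can be reproduced with a simple timing loop like the one below (a sketch, not the exact benchmark harness used here; the forward pass should be wrapped in `torch.no_grad()`):

```python
import time

def mean_latency_ms(forward, n_warmup=5, n_runs=30):
    """Average wall-clock time of forward() in milliseconds."""
    for _ in range(n_warmup):
        forward()  # warm-up runs are excluded from the measurement
    start = time.perf_counter()
    for _ in range(n_runs):
        forward()
    return (time.perf_counter() - start) / n_runs * 1000.0

# e.g. mean_latency_ms(lambda: model_quantized(**encoded_batch))
```

Warm-up iterations matter here: the first few passes pay one-time costs (allocator warm-up, kernel selection) that would otherwise inflate the average.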

Limitations & Reality Check

While the results are strong, developing this pipeline inside a constrained environment like Google Colab exposes some hardware realities. Computing head importance across all layers is memory-intensive; you have to explicitly manage the GPU cache (e.g., with torch.cuda.empty_cache()) to prevent out-of-memory (OOM) failures. Furthermore, PyTorch's dynamic quantization natively targets CPU execution, so a direct GPU latency comparison for the final INT8 model is impossible without specialized edge hardware or a TensorRT implementation.