
Optimizing Transformer Inference: From 200ms to 15ms

December 1, 2024 · 8 min read
Transformers · Optimization · ONNX · Inference · Production ML

The Problem

You've fine-tuned a BERT model that achieves great accuracy on your task. But when you deploy it, inference takes 200ms per request — way too slow for a real-time API with a 50ms SLA. Sound familiar?

This post walks through the exact optimization pipeline I used to reduce transformer inference latency by 92% while retaining 98% of the original accuracy.

Baseline Measurement

Always start by profiling your baseline. Here's what our starting point looked like:

import time

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model.eval()  # disable dropout for inference

# Benchmark
inputs = tokenizer("Sample input text for benchmarking", return_tensors="pt")
times = []
for _ in range(100):
    start = time.perf_counter()
    with torch.no_grad():
        outputs = model(**inputs)
    times.append(time.perf_counter() - start)

print(f"Mean latency: {sum(times)/len(times)*1000:.1f}ms")
# Output: Mean latency: 198.3ms (CPU)
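One refinement worth making before comparing numbers: the loop above also times cold-start iterations, which inflates the mean. Here is a sketch of the same harness with warmup runs and a p95 percentile added; the `benchmark` helper and its parameters are my own, not from any library:

```python
import statistics
import time

def benchmark(fn, warmup=10, runs=100):
    """Time fn() over `runs` iterations, discarding `warmup` untimed calls first."""
    for _ in range(warmup):
        fn()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return {
        "mean_ms": statistics.mean(times) * 1000,
        "p95_ms": statistics.quantiles(times, n=20)[-1] * 1000,  # 95th percentile
    }

# e.g. benchmark(lambda: model(**inputs)) for the PyTorch baseline;
# a trivial workload stands in here so the sketch runs anywhere:
stats = benchmark(lambda: sum(range(1_000)))
```

For latency SLAs, p95 or p99 is usually the number that matters, not the mean.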

Step 1: ONNX Export

Converting to ONNX format alone provides significant speedup through graph optimizations:

from optimum.exporters.onnx import main_export

main_export(
    model_name_or_path="./fine-tuned-bert",
    output="./onnx-model",
    task="text-classification",
)

Result: 198ms → 85ms (57% reduction)

Step 2: Quantization

Dynamic INT8 quantization reduces model size and leverages integer arithmetic:

from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="./onnx-model/model.onnx",
    model_output="./onnx-model/model_quantized.onnx",
    weight_type=QuantType.QInt8,
)

Result: 85ms → 32ms (62% further reduction)

Step 3: Knowledge Distillation

For the final push, we distilled the model into a 4-layer DistilBERT-style student:

Model                    Params   Accuracy   Latency
BERT-base                110M     92.3%      198ms
ONNX + Quantized BERT    110M     91.8%      32ms
Distilled + Quantized    28M      90.5%      15ms
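The post doesn't show the distillation objective, but the standard recipe blends a temperature-softened KL term against the teacher's logits with ordinary cross-entropy on the hard labels. A generic sketch, not the exact training code used here; `T` and `alpha` are tuning knobs you'd sweep:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Blend soft-target KL (teacher -> student) with hard-label cross-entropy."""
    # Soften both distributions with temperature T; the T^2 factor keeps
    # gradient magnitudes comparable to the hard-label term.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard

# Smoke test on random logits
loss = distillation_loss(torch.randn(4, 3), torch.randn(4, 3), torch.tensor([0, 1, 2, 0]))
```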

Key Takeaways

  1. Profile before optimizing — know where your bottleneck actually is
  2. ONNX export is low-hanging fruit — always do this first
  3. Quantization is nearly free — minimal accuracy loss for big speedups
  4. Distillation is worth it if you need sub-20ms latency
  5. Batch your requests when possible — GPU utilization matters
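On the last point: batching only helps if variable-length requests are padded into a single tensor with an attention mask so the model ignores the padding. A stdlib-only sketch of that step (`pad_batch` is a hypothetical helper; in practice the tokenizer's `padding=True` does this for you):

```python
from typing import List, Tuple

def pad_batch(sequences: List[List[int]], pad_id: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad variable-length token id lists to uniform length and build an attention mask."""
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return input_ids, attention_mask

# Two requests of different lengths become one rectangular batch
ids, mask = pad_batch([[101, 7592, 102], [101, 102]])
```

The trade-off is tail latency: waiting to fill a batch adds queueing delay, so production servers typically cap both batch size and wait time.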