Adaptive Inference at the Edge: Speculative Decoding and KV-Cache Compression

The problem with running LLMs on edge devices isn't the model — it's the runtime. Most inference stacks are designed for data centers: 40GB of HBM, 400W power envelopes, NVLink between GPUs. When you're working with a Jetson Orin or a high-end laptop with 8GB of shared VRAM, those assumptions don't hold. The model fits. The naïve runtime doesn't.

air-runtime is a Python project that packages three techniques — smart routing, speculative decoding, and KV-cache compression — into a single coherent runtime. None of these are new ideas. What's new is combining them in a way that's practical for constrained hardware, without requiring custom CUDA kernels or a PhD to configure.

Technique 1: Smart Routing

Not all queries need the same model. A request like "what's 12 times 7?" doesn't need a 70B parameter model. A request like "refactor this Rust async executor to use a work-stealing scheduler" probably does. Smart routing classifies queries by complexity and dispatches them to the smallest model capable of handling them.

The router is a lightweight classifier — a fine-tuned DistilBERT or a simple logistic regression over TF-IDF features, depending on latency budget. It outputs a complexity score that maps to one of three tiers:

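import joblib

# ModelTier and _score_to_tier are defined elsewhere in the project.
class QueryRouter:
    def __init__(self, classifier_path: str, vectorizer_path: str):
        self.clf = joblib.load(classifier_path)
        # Load the fitted vectorizer that was used to train the classifier.
        # A freshly constructed TfidfVectorizer has no vocabulary, so calling
        # transform() on it would raise NotFittedError.
        self.vectorizer = joblib.load(vectorizer_path)

    def route(self, query: str) -> ModelTier:
        features = self.vectorizer.transform([query])
        # predict_proba returns one row of class probabilities per input
        score = self.clf.predict_proba(features)[0]
        # Returns Tier enum based on confidence thresholds
        return self._score_to_tier(score)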
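The `_score_to_tier` helper is not shown in the post. A minimal sketch of what such a mapping could look like, assuming the classifier output has been reduced to a single complexity score in [0, 1]. The enum member names and the 0.4 / 0.75 cutoffs are hypothetical and would need tuning on real traffic:

```python
from enum import IntEnum


class ModelTier(IntEnum):
    # Hypothetical member names; the post only refers to Tier 0, 1 and 2.
    TIER_0 = 0  # smallest model
    TIER_1 = 1  # mid-sized model
    TIER_2 = 2  # largest model


def score_to_tier(complexity: float, low: float = 0.4, high: float = 0.75) -> ModelTier:
    """Map a complexity score in [0, 1] to the smallest tier assumed to handle it."""
    if not 0.0 <= complexity <= 1.0:
        raise ValueError("complexity must be in [0, 1]")
    if complexity < low:
        return ModelTier.TIER_0
    if complexity < high:
        return ModelTier.TIER_1
    return ModelTier.TIER_2
```

Raising `low` sends more traffic to the smallest model at the risk of under-serving borderline queries; lowering it does the opposite.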

In practice, roughly 60% of queries route to Tier 0, 30% to Tier 1, and only 10% reach Tier 2. Since smaller models are 5–10× faster on the same hardware, and more than half of all queries never leave the smallest tier, this cuts median latency sharply without meaningfully degrading output quality on the queries that don't need the big model.
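To make the latency claim concrete, here is a back-of-the-envelope calculation. The traffic split is the one quoted above. The relative per-query costs (Tier 0 at 1/10 and Tier 1 at 1/5 of Tier 2) are assumptions picked from the 5–10× range, not measurements:

```python
# Traffic share per tier, from the routing split quoted in the text.
share = {"tier0": 0.60, "tier1": 0.30, "tier2": 0.10}
# Assumed relative cost per query, with Tier 2 (the largest model) as 1.0.
relative_cost = {"tier0": 0.1, "tier1": 0.2, "tier2": 1.0}


def mean_relative_latency(share, relative_cost):
    """Traffic-weighted average cost relative to always using Tier 2."""
    return sum(share[t] * relative_cost[t] for t in share)


def median_relative_latency(share, relative_cost):
    """Cost of the tier that contains the 50th-percentile query."""
    cumulative = 0.0
    # Walk the tiers from cheapest to most expensive.
    for tier in sorted(share, key=lambda t: relative_cost[t]):
        cumulative += share[tier]
        if cumulative >= 0.5:
            return relative_cost[tier]
    raise ValueError("shares must sum to at least 0.5")


mean = mean_relative_latency(share, relative_cost)
median = median_relative_latency(share, relative_cost)
print(f"mean: {mean:.2f}x, median: {median:.2f}x of always using Tier 2")
```

Under these assumptions the median query pays the Tier 0 cost, while the mean is pulled up by the 10% of traffic that still needs Tier 2.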

Technique 2: Speculative Decoding

Speculative decoding exploits an asymmetry in transformer inference: generating N tokens autoregressively takes N forward passes of the large model, but verifying N tokens takes only one forward pass, because the verifier can process all positions in parallel.

The protocol:

  1. A small "draft" model generates a sequence of K candidate tokens quickly.
  2. The large "verifier" model evaluates the entire draft in a single forward pass.
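  3. Tokens are accepted greedily from left to right until the first mismatch. The verifier corrects the first mismatched position and the cycle repeats.

async def speculative_decode(
    draft_model, verifier_model, prompt: str, max_tokens: int, K: int = 5
):
    # tokenize, decode and accept_tokens are helpers defined elsewhere.
    tokens = tokenize(prompt)
    while len(tokens) < max_tokens:
        # Draft K tokens with the small model
        draft = await draft_model.generate(tokens, K)
        # Verify the entire draft in one pass
        logits = await verifier_model.forward(tokens + draft)
        # logits[i] predicts token i + 1, so the row that scores draft[0]
        # sits at index len(tokens) - 1.
        accepted = accept_tokens(draft, logits[len(tokens) - 1:])
        # accept_tokens returns the matching prefix plus the verifier's own
        # token at the first mismatch, so every cycle adds at least one token
        # and the loop continues with a fresh draft instead of stopping.
        tokens += accepted
    # The last cycle can overshoot the budget, so trim before decoding.
    return decode(tokens[:max_tokens])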
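The `accept_tokens` helper carries step 3 of the protocol. A minimal sketch of greedy verification, assuming logits arrive as plain per-position score lists and that row `i` scores the position occupied by `draft[i]`. The function body and the bonus-token behaviour are illustrative assumptions, not air-runtime's actual implementation:

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the highest score; ties resolve to the lowest index."""
    return max(range(len(row)), key=row.__getitem__)


def accept_tokens(draft: List[int], logits: Sequence[Sequence[float]]) -> List[int]:
    """Greedy verification of a draft against verifier logits.

    logits[i] scores the position occupied by draft[i]. An optional extra
    row at index len(draft) scores the position just after the draft.
    """
    accepted: List[int] = []
    for i, token in enumerate(draft):
        target = argmax(logits[i])
        if token != target:
            # First mismatch: keep the verifier's token and stop.
            accepted.append(target)
            return accepted
        accepted.append(token)
    # Whole draft accepted: take one bonus token if the extra row exists.
    if len(logits) > len(draft):
        accepted.append(argmax(logits[len(draft)]))
    return accepted
```

This variant matches the verifier's greedy (argmax) output exactly. Sampling-based decoding needs a probabilistic accept/reject rule instead, which is not shown here.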

The speedup depends on the draft acceptance rate. If the draft model and verifier agree on most tokens (common for well-aligned model pairs), each verifier pass yields close to K tokens, so the speedup approaches K× as long as the draft model's cost is small by comparison. For typical English prose or code, acceptance rates of 70–85% are achievable, yielding 3–4× end-to-end throughput improvement.
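The relationship between acceptance rate and speedup can be estimated with a simple model. The sketch below assumes each draft token is accepted independently with probability `alpha`, that the verifier contributes one token per cycle (a correction or a bonus token), and that a verifier pass over the draft costs about the same as a single-token pass. `draft_cost_ratio` is a hypothetical parameter for the cost of one draft pass relative to one verifier pass:

```python
def expected_tokens_per_cycle(alpha: float, k: int) -> float:
    """Expected tokens emitted per draft-and-verify cycle.

    Geometric sum 1 + alpha + alpha^2 + ... + alpha^k, which counts the
    accepted prefix plus the one token supplied by the verifier.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        return float(k + 1)
    return (1.0 - alpha ** (k + 1)) / (1.0 - alpha)


def estimated_speedup(alpha: float, k: int, draft_cost_ratio: float) -> float:
    """Tokens per cycle divided by cycle cost, in units of one verifier pass."""
    cycle_cost = 1.0 + k * draft_cost_ratio
    return expected_tokens_per_cycle(alpha, k) / cycle_cost


for alpha in (0.70, 0.85):
    tokens = expected_tokens_per_cycle(alpha, 5)
    speedup = estimated_speedup(alpha, 5, draft_cost_ratio=0.05)
    print(f"alpha={alpha:.2f}: {tokens:.2f} tokens/cycle, ~{speedup:.2f}x speedup")
```

With K = 5 this model gives roughly 2.9 to 4.2 tokens per verifier pass across the 70–85% range, before the draft model's cost is subtracted.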

Model pair selection matters. The draft model should be in the same family as the verifier (e.g., Llama 3.2 1B drafting for Llama 3.1 8B) to maximise token distribution alignment. Mismatched families produce low acceptance rates and can be slower than baseline.

Technique 3: KV-Cache Compression

The KV cache is the transformer's memory. For each token in the context, every attention layer stores a key vector and a value vector. The cache grows linearly with context length: for a 7B model with 32 layers and a 4096-dim hidden state, a 4K context consumes roughly 2GB of GPU memory at FP16 precision (2 tensors × 32 layers × 4096 dims × 4096 tokens × 2 bytes). On an 8GB device, that's 25% of your total budget for context alone.
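The 2GB figure can be reproduced with a few lines of arithmetic. This sketch assumes FP16 storage (2 bytes per value) and standard multi-head attention, where keys and values each span the full hidden dimension. Models using grouped-query attention would cache less:

```python
def kv_cache_bytes(
    num_layers: int,
    hidden_dim: int,
    context_len: int,
    bytes_per_value: int = 2,  # FP16
    batch_size: int = 1,
) -> int:
    """Size of the KV cache in bytes for full multi-head attention."""
    # The leading 2 accounts for one key tensor and one value tensor per layer.
    return 2 * num_layers * hidden_dim * context_len * bytes_per_value * batch_size


size = kv_cache_bytes(num_layers=32, hidden_dim=4096, context_len=4096)
gib = size / 2**30
print(f"{gib:.1f} GiB, {gib / 8:.0%} of an 8 GiB device")
```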

Head pruning

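Not all attention heads contribute equally. Head-pruning studies have repeatedly found that 20–40% of heads can be zeroed out with less than 1% perplexity degradation, although the exact figure depends on the model and task. I implement a calibration pass on a small representative dataset to compute per-head importance scores, then prune the bottom quartile. Pruned heads no longer need their keys and values cached, which is where the memory saving comes from.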
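The selection step can be sketched as follows, assuming the calibration pass has already produced one importance score per head. How those scores are computed is not specified in the post, so the `importance` mapping and the function name here are hypothetical:

```python
from typing import Dict, List, Tuple

Head = Tuple[int, int]  # (layer index, head index)


def heads_to_prune(importance: Dict[Head, float], fraction: float = 0.25) -> List[Head]:
    """Return the lowest-scoring `fraction` of heads, least important first."""
    if not 0.0 <= fraction < 1.0:
        raise ValueError("fraction must be in [0, 1)")
    n_prune = int(len(importance) * fraction)
    # Sort by score, breaking ties by (layer, head) for a deterministic result.
    ranked = sorted(importance, key=lambda head: (importance[head], head))
    return ranked[:n_prune]


# Toy example: 2 layers with 4 heads each and made-up scores.
scores = {
    (0, 0): 0.90, (0, 1): 0.10, (0, 2): 0.50, (0, 3): 0.70,
    (1, 0): 0.05, (1, 1): 0.80, (1, 2): 0.60, (1, 3): 0.40,
}
print(heads_to_prune(scores))
```

Ranking globally, as here, lets some layers lose more heads than others. Ranking per layer is the alternative when every layer must keep the same head count.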

StreamingLLM attention sinks

For long-context inference, I adopt the attention sink approach from StreamingLLM: always retain the first 4 tokens of the sequence (the "sink" tokens that disproportionately accumulate attention weight) plus a sliding window of recent tokens. This bounds cache size to O(window_size) instead of O(sequence_length), so generation can continue indefinitely on fixed memory. The trade-off is that tokens which have slid out of the window are no longer attended to.

class SinkKVCache:
    def __init__(self, sink_size=4, window_size=512):
        self.sink_size = sink_size
        self.window_size = window_size
        self.cache = []  # one entry per cached token, oldest first

    def update(self, new_kv):
        self.cache.append(new_kv)
        # Number of entries beyond the sink + window budget
        overflow = len(self.cache) - (self.sink_size + self.window_size)
        if overflow > 0:
            # Evict the oldest non-sink entries; the first sink_size
            # entries are never removed.
            del self.cache[self.sink_size:self.sink_size + overflow]
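To see which token positions survive under this policy, the helper below computes the retained indices for a given sequence length. It is an illustrative standalone function (the name `retained_positions` is hypothetical) that mirrors the sink-plus-window rule described above:

```python
from typing import List


def retained_positions(seq_len: int, sink_size: int = 4, window_size: int = 512) -> List[int]:
    """Token positions still cached after seq_len tokens have been seen."""
    if seq_len <= sink_size + window_size:
        # Nothing has been evicted yet.
        return list(range(seq_len))
    sinks = list(range(sink_size))
    window = list(range(seq_len - window_size, seq_len))
    return sinks + window


print(retained_positions(10, sink_size=2, window_size=3))
print(len(retained_positions(100_000)))
```

However long the sequence grows, the cache holds at most `sink_size + window_size` entries, which is 516 with the defaults.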

Results on a Jetson Orin

Benchmarking on a Jetson AGX Orin (64GB unified memory, 2048-core Ampere GPU) with a 7B quantised model (INT4):

Combined, the runtime achieves approximately 2.8× end-to-end throughput improvement on a mixed workload vs. naïve autoregressive generation with the large model only.

The real win isn't just speed. It's the memory savings that allow you to run a 7B model alongside other processes on a shared-memory edge device without swapping. Latency matters, but stability over hours of continuous operation matters more.