Can you give an intuition as to why it's faster? I would have thought that regardless of how many you run in parallel, the successful check has to execute the full model to generate the full sequence, so you'd need exactly the same time? Or is it by process of elimination, so it terminates early once it eliminates the non-viable choices? (In which case, how do you guarantee the correct output was speculatively generated at all to be the last survivor?)
The big target model calculates
P(d1)
P(d2|d1)
P(d3|d1 d2)
in parallel. If we were just greedy decoding, it would be simple: stop at the first position where the draft model's token isn't the most likely token as judged by the target model. At that point, append the correct token from the target model and kick off both models again.
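In practice we aren't using greedy decoding. We are sampling, and we need to match the target model's distribution. To do this, we accept tokens from the draft model probabilistically, which is possible because we have the logits of both the draft model and the target at that point. Each draft token is accepted with probability min(1, p_target / p_draft), the ratio of the two models' softmax probabilities for that token. On a rejection, a replacement is sampled from the positive part of (p_target - p_draft), renormalized.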
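To make the accept/reject rule concrete, here's a rough sketch for a single draft token. It assumes you already have both models' softmax probabilities as plain lists. The function name and arguments are made up for illustration, not taken from any library.

```python
import random

def accept_or_resample(draft_token, p_target, q_draft, rng=random):
    """Decide the fate of one draft token (illustrative helper).

    p_target / q_draft: probability lists over the vocabulary from the
    target and draft models at this position (assumed inputs).
    Returns (token, accepted).
    """
    p = p_target[draft_token]
    q = q_draft[draft_token]  # > 0, since the draft sampled this token
    # Accept with probability min(1, p/q).
    if rng.random() < min(1.0, p / q):
        return draft_token, True
    # Rejected: resample from the positive part of (p_target - q_draft).
    # random.choices normalizes the weights for us.
    residual = [max(0.0, pt - qd) for pt, qd in zip(p_target, q_draft)]
    token = rng.choices(range(len(residual)), weights=residual, k=1)[0]
    return token, False
```

If you sample the draft token from q_draft and then run it through this, the token that comes out is distributed according to p_target, which is why the output still matches the big model.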
You are right that actually accepting tokens has to happen sequentially but that’s a heck of a lot faster than a forward pass.
So this is a case of putting idle compute capacity to work while it would otherwise be waiting on the bottleneck (memory access).
The reason it's designed this way is a bit subtle but it has the advantage during training that you can use a single block of 10 tokens to generate 9 training examples in parallel, so it's highly efficient. This efficiency is basically the main benefit of transformers - the algorithm parallelizes really well and that's what allowed the scale up to large language models as opposed to the previous reality of just language models.
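The "10 tokens, 9 examples" point is easy to see in a toy sketch. The function name here is made up, and a real training loop gets the same effect from one forward pass with a causal mask rather than by materializing the pairs.

```python
def next_token_examples(tokens):
    """Every proper prefix of the block predicts the token that follows it."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

block = list(range(10))            # stand-in for 10 token ids
examples = next_token_examples(block)
print(len(examples))               # 9 (prefix, next-token) pairs
```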
The blog post does discuss why MTP is faster but it's maybe a bit hard to understand if you haven't studied LLM internals. During inference the hardware has arithmetic units idling because they spend so much time waiting for the weight matrices to get moved closer to the processors. Because data movement and computation can be overlapped, if you can reuse the same loaded data for multiple calculations at once you're winning - it's free latency-wise because you're just exploiting previously idle resources (it's not free in terms of energy).
Speculative decoding and MTP exploit this to run the model in parallel on several tokens at once. Say your context window contains "The United". The KV cache has been populated by the main model for this set of tokens. The draft model is given "The United" and predicts " States of America" in one forward pass (predicting multiple tokens in a single pass is the MTP part). Then the main model is given the KV cache from last time along with " States of America". In its own forward pass it can then compute in parallel the completions of "The United", "The United States", "The United States of" and "The United States of America" (the last one might be an EOS token indicating it wants to stop talking). That's the speculative decoding part.
Now you decode the main model at each position (look at the token probabilities and pick one according to some decoding strategy). It's possible the main model didn't pick " States" at all, or picked " States", but then its prediction diverged e.g. if it wants to say "The United States is a country". So you just select the tokens that match and toss all the tokens starting from the one that didn't. Repeat.
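A minimal sketch of that matching step, assuming greedy decoding and that you already have the main model's pick at every position. The names are made up. One refinement beyond what's described above: at the first mismatch it keeps the main model's own token, and if everything matches it keeps the extra token the main model produced after the last draft token.

```python
def verify_greedy(draft_tokens, target_tokens):
    """Keep the matching prefix of the draft plus one token from the main model.

    target_tokens[i] is the main model's pick given the context plus
    draft_tokens[:i], so it has len(draft_tokens) + 1 entries.
    """
    kept = []
    for drafted, wanted in zip(draft_tokens, target_tokens):
        if drafted != wanted:
            kept.append(wanted)   # main model's correction; toss the rest
            return kept
        kept.append(drafted)
    # Whole draft matched: the last target token comes for free.
    kept.append(target_tokens[len(draft_tokens)])
    return kept

# Context is "The United"; the main model wanted "... States is a country".
draft = [" States", " of", " America"]
target = [" States", " is", " America", "<eos>"]
print(verify_greedy(draft, target))   # [' States', ' is']
```

Note that the main model's picks after the mismatch (" America", "<eos>") were conditioned on the rejected " of", so they have to be thrown away.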
The parallelism comes almost for free because the same weight matrices can be reused multiple times before they're swapped out for the next.
But I think the key is that in the standard autoregressive case we're memory-bandwidth bound, so there are tons of idle compute resources. And so checking multiple tokens is cheap because we can batch and thus reuse the read weights for multiple tokens.
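Here's a back-of-the-envelope version of that argument. All the hardware and model numbers are made up for illustration, and the model is crude: it ignores KV cache reads and activations, and assumes weight loading and compute overlap perfectly.

```python
def step_time_s(n_tokens, weight_bytes, flops_per_token, mem_bw, peak_flops):
    """Time for one forward pass over n_tokens: whichever of weight loading
    and arithmetic takes longer (a crude roofline-style estimate)."""
    load_s = weight_bytes / mem_bw
    compute_s = n_tokens * flops_per_token / peak_flops
    return max(load_s, compute_s)

# Hypothetical setup: 70B parameters at 2 bytes each, about 2 FLOPs per
# parameter per token, 2 TB/s memory bandwidth, 400 TFLOP/s peak compute.
model = dict(weight_bytes=140e9, flops_per_token=140e9,
             mem_bw=2e12, peak_flops=4e14)

one = step_time_s(1, **model)
four = step_time_s(4, **model)
many = step_time_s(400, **model)
print(one, four, many)
```

With these numbers, verifying 4 tokens costs the same as generating 1 because the weight load dominates, and the pass only becomes compute bound at roughly 200 tokens.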
The verification step is similar to a prefill with a small batch size. The difference is what we do with the generated logits.
Most of the complexity in implementing a simple toy version comes from having to get the KV cache back into a good state for the next cycle (e.g. if only the first half of your draft tokens were correct).
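A toy illustration of that cleanup, with a made-up cache class that stores one entry per position. Real caches are per-layer tensors, but the bookkeeping is the same idea: drop the entries written for rejected draft tokens.

```python
class ToyKVCache:
    """Hypothetical stand-in for a KV cache: one (key, value) per position."""

    def __init__(self):
        self.entries = []

    def append(self, key, value):
        self.entries.append((key, value))

    def __len__(self):
        return len(self.entries)

    def truncate(self, length):
        # Forget everything at position >= length.
        del self.entries[length:]

def rollback_after_verify(cache, len_before, n_accepted):
    """Keep only the entries for the draft tokens that were accepted."""
    cache.truncate(len_before + n_accepted)

cache = ToyKVCache()
for tok in ["The", " United"]:
    cache.append("k:" + tok, "v:" + tok)
len_before = len(cache)

# The verification pass writes entries for all three draft tokens...
for tok in [" States", " of", " America"]:
    cache.append("k:" + tok, "v:" + tok)

# ...but only the first one was accepted.
rollback_after_verify(cache, len_before, n_accepted=1)
print(len(cache))   # 3
```

Any token the main model substitutes at the mismatch has no cache entry yet; it gets one when it's fed in at the start of the next cycle.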
Right, this is the same way batching works. It's "free" until we exhaust available compute resources, at which point decode throughput becomes compute bound. (This is a good place to be, because scaling out compute is a lot easier than adding fast VRAM.) This is why MTP is mostly useful when you have one or few users, which means compute is abundant. When you're running large batches you're better off using that compute to grow your batch size.
Of course, batch size is usually limited by things like bulky KV caches. So perhaps MTP has some residual use in that setting. But if you're sharing cached context in a subagent swarm, or running a model like the recent DeepSeek V4 with its tiny KV cache, you can go a lot further in processing a larger batch.