The draft model quickly generates draft-token 1.
The main model then starts working on two tokens in parallel. It calculates token 1 based on the context, and token 2 based on the context + draft-token 1.
Once the two tokens have been generated, you can check whether the draft-token 1 from the draft model matches token 1 from the main model.
If they match, you have just calculated two tokens in the time it takes to generate one, because the calculation was done in parallel. If they do not match, discard token 2 and generate it again. Since the main model has already produced the correct token 1, you can use the context + token 1 (from the main model). This takes more time, but the result is always the same as it would be without the draft model.
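The steps above can be sketched in a few lines. This is only a toy illustration, assuming greedy decoding and treating each model as a plain function from a context to its next token; `speculative_step`, `plain_two_tokens` and the `Model` alias are names made up for this example, not part of any library.

```python
from typing import Callable, List

# Hypothetical stand-in for a model: maps a context to its greedy next token.
Model = Callable[[List[int]], int]


def speculative_step(main: Model, draft: Model, context: List[int]) -> List[int]:
    """Produce two tokens using a single draft token, as described above."""
    draft_1 = draft(context)
    # On real hardware these two main-model evaluations would share one
    # batched forward pass; here they are simply two function calls.
    token_1 = main(context)
    token_2 = main(context + [draft_1])
    if draft_1 != token_1:
        # token_2 was conditioned on a wrong token: discard and regenerate it.
        token_2 = main(context + [token_1])
    return [token_1, token_2]


def plain_two_tokens(main: Model, context: List[int]) -> List[int]:
    """Reference: two sequential main-model steps without any draft model."""
    token_1 = main(context)
    return [token_1, main(context + [token_1])]


def toy_main(context: List[int]) -> int:
    # Arbitrary deterministic toy rule, purely for demonstration.
    return (sum(context) * 3 + 1) % 7
```

With a perfect draft (`toy_main` itself) or a useless one (`lambda ctx: 0`), `speculative_step` returns the same two tokens as `plain_two_tokens`; only the amount of wasted work differs.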
The model outputs a probability for every possible next token; the inference parameters determine how a token is selected from those probabilities.
You can just select the top token all the time or you can do it probabilistically.
How you do that in both the speculative decoding and the main inference changes how likely you are to get the exact same tokens. And then you can choose to accept only if the token matches exactly, or you can choose to accept if it was reasonably likely to be chosen.
Let's say the main model picked the 2nd most likely token and the draft model picked the most likely. You can reject that, but you get less speedup. You can accept it and get more speedup, but you do change the output. You risk the distribution of your outputs not being what you hope.
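To make the two acceptance policies concrete, here is a minimal sketch. The function names and the `min_prob` threshold are assumptions invented for this example; they are not taken from any particular implementation.

```python
from typing import Sequence


def accept_exact(main_token: int, draft_token: int) -> bool:
    """Strict policy: keep the draft token only if the main model chose it too."""
    return main_token == draft_token


def accept_if_likely(main_probs: Sequence[float], draft_token: int,
                     min_prob: float = 0.25) -> bool:
    """Lenient policy (hypothetical threshold): keep the draft token if the
    main model assigned it at least `min_prob` probability."""
    return main_probs[draft_token] >= min_prob


# The scenario above: the main model's probabilities over three tokens.
main_probs = [0.5, 0.3, 0.2]
main_token = 1   # the main model sampled its 2nd most likely token
draft_token = 0  # the draft model picked the most likely token

strict = accept_exact(main_token, draft_token)         # False: less speedup
lenient = accept_if_likely(main_probs, draft_token)    # True: output changes
```

The lenient policy accepts more draft tokens, but the accepted token is no longer the one the main model would have emitted on its own.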
I am simplifying. I know in https://arxiv.org/pdf/2302.01318 they specify a probability that you reject a token.
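For reference, the rejection rule in that paper can be sketched roughly as follows: accept the draft token x with probability min(1, p_main(x) / p_draft(x)), and on rejection resample from the normalized positive part of (p_main - p_draft). The function names below are my own, and the code assumes both distributions are given as plain lists of probabilities, with the draft token having been sampled from `draft_probs` (so its draft probability is nonzero).

```python
import random
from typing import List


def accept_probability(p_main: float, p_draft: float) -> float:
    """Probability of keeping a draft token (assumes p_draft > 0)."""
    return min(1.0, p_main / p_draft)


def residual_distribution(main_probs: List[float],
                          draft_probs: List[float]) -> List[float]:
    """Normalized max(0, p_main - p_draft), used only after a rejection."""
    residual = [max(0.0, m - d) for m, d in zip(main_probs, draft_probs)]
    total = sum(residual)
    return [r / total for r in residual]


def verify_sampled_token(main_probs: List[float], draft_probs: List[float],
                         draft_token: int, rng: random.Random) -> int:
    """Return the token to emit at this position."""
    keep = accept_probability(main_probs[draft_token], draft_probs[draft_token])
    if rng.random() < keep:
        return draft_token
    # Rejected: resample from the residual so the overall output still
    # follows the main model's distribution.
    weights = residual_distribution(main_probs, draft_probs)
    return rng.choices(range(len(main_probs)), weights=weights)[0]
```

If the two distributions are identical, the acceptance probability is 1 for every token, so the residual is never needed.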
As far as I know, this is not used in practice. Currently popular implementations always match the main model output, and the draft model only affects the speed.
Matching the token that would've been picked without speculative decoding: that seems to be more or less agreed upon.
e.g. vLLM docs list tests they run to ensure that output doesn't change if spec. decoding is used: https://github.com/vllm-project/vllm/blob/main/docs/features...
But introducing some threshold to accept other high-probability tokens is an interesting idea.
The paper they link to in that first paragraph says you compare logits to accept or reject.
It's like branch prediction: the CPU predicts what branch you'll take and starts executing it. Later you find out exactly what branch you took. If the prediction was correct, the speculatively executed code is kept. If the prediction was wrong, it's thrown away, the pipeline is flushed, and execution resumes from the branch point.
The same with this thing: 3 tokens, A-B-C, were "predicted", and you start computing all 3 of them at the same time, hoping that the prediction checks out. Because of the mathematical structure of the transformer, it costs you almost the same to compute 3 tokens at a time as just one: you are limited by memory bandwidth, not compute. But CRITICALLY, each token depends on all the previous ones, so if you mispredicted one of the tokens, you need to discard all tokens predicted after it (flush the pipeline). This is why a prediction is required and why you can't always compute 3 tokens simultaneously: the serial dependency between consecutive tokens. If you were to start computing 3 tokens simultaneously without a prediction, for token C you would need to assume some exact values for tokens A and B, but those were not computed yet! If they were speculatively predicted, though, you can start and hope the prediction was correct.
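The "flush the pipeline" part can be shown with a small sketch. It assumes exact-match acceptance and that `verified[i]` is the main model's token for position i, computed as if `draft[:i]` were all correct; `keep_valid_prefix` is a made-up helper name, and letting `verified` hold one entry more than `draft` (the token after the last draft token) is an assumption of this example.

```python
from typing import List


def keep_valid_prefix(draft: List[str], verified: List[str]) -> List[str]:
    """Keep main-model tokens up to and including the first misprediction.

    verified[i] is trustworthy only if draft[:i] matched, so everything
    after the first mismatch was computed on a wrong assumption and is dropped.
    """
    kept = []
    for i, token in enumerate(verified):
        kept.append(token)  # valid: all earlier draft tokens matched
        if i >= len(draft) or draft[i] != token:
            break  # misprediction (or no more drafts): flush the rest
    return kept


all_correct = keep_valid_prefix(["A", "B", "C"], ["A", "B", "C", "D"])
wrong_at_b = keep_valid_prefix(["A", "B", "C"], ["A", "X", "Y", "Z"])
```

Here `all_correct` keeps all four tokens, while `wrong_at_b` keeps only `["A", "X"]`: the main model's own token at the mispredicted position is still valid, but "Y" and "Z" were computed assuming "B" and must be discarded.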