If you read the section "Richer attention mechanisms", you can see that, no, the mechanism is not generally usable (it requires significant modification to become differentiable). They later speculate:

    While we do not yet know whether exact softmax attention can be
    maintained with the same efficiency, it is easy to approximate it
    with k-sparse softmax attention: retrieve the top-k keys and
    perform the softmax only over those
but if you have played around with training models that use top-k or other hard-thresholding operations (in PyTorch, for example), or just think about how many gradients become exactly zero under such an operation, you know that these tend to work only in extremely limited, specific cases, and they make training even more finicky than it already is.
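
To make that concrete, here is a minimal PyTorch sketch of the k-sparse softmax attention they describe (shapes and names are my own illustration, not from the paper), along with a quick demonstration of the gradient problem: every key that falls outside every query's top-k receives exactly zero gradient.

    import torch
    import torch.nn.functional as F

    def topk_sparse_attention(q, k, v, topk=4):
        # q: (n_q, d), k: (n_k, d), v: (n_k, d); single-head, batchless for clarity
        scores = q @ k.T / k.shape[-1] ** 0.5   # (n_q, n_k)
        vals, idx = scores.topk(topk, dim=-1)   # hard selection: idx carries no gradient
        weights = F.softmax(vals, dim=-1)       # softmax over only the retained scores
        return torch.einsum("qk,qkd->qd", weights, v[idx])

    # Any key outside every query's top-k gets zero gradient, i.e. no learning signal.
    q = torch.randn(2, 8)
    k = torch.randn(16, 8, requires_grad=True)
    v = torch.randn(16, 8)
    topk_sparse_attention(q, k, v).sum().backward()
    print((k.grad.abs().sum(dim=-1) == 0).sum().item(), "of 16 keys got zero gradient")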
reply
I saw that, but the figure included nearby made it look like it might be plausible to replace the 1D line around their points with a fairly narrow 2D region. Couldn't that still be a somewhat effective filter?
reply
The problem is that they are talking about tricks for compiling VMs into transformer weights, which is basically unrelated to actually training transformers on data via gradient descent. Once you get into that messy practical reality, you run into non-trivial machinery like sparsemax and the Gumbel-Softmax trick, which buy you some desirable improvements over things like the plain softmax without all the gradient destruction of top-k approaches, but usually at serious other costs. Most approaches using Gumbel-Softmax that I have read essentially create a bi-level optimization problem that is claimed to be "solved" by some handwavy annealing schedule, but which is clearly highly unstable and hard to tune. I don't know whether things have improved here since I last read up on it.
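
For reference, the straight-through Gumbel-Softmax pattern looks roughly like this in PyTorch (a minimal illustrative sketch, not anyone's actual training code; logits and values are made-up names, and tau is exactly the annealing knob that is so hard to tune):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 16, requires_grad=True)  # e.g. scores over 16 candidate keys
    values = torch.randn(16, 8)

    # tau is the temperature: annealing it toward 0 makes the forward samples
    # harder but the gradients noisier; scheduling that is the finicky part.
    tau = 1.0
    # hard=True returns one-hot selections in the forward pass but routes the
    # backward pass through the soft relaxation (straight-through estimator),
    # so unlike plain top-k, every logit still receives a gradient.
    sample = F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)  # (4, 16), one-hot rows
    picked = sample @ values                                       # (4, 8)
    picked.pow(2).mean().backward()

    print(sample.sum(dim=-1))                 # each row sums to 1 (hard one-hot)
    print((logits.grad != 0).float().mean())  # gradient reaches essentially all logits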

So the issue isn't whether there are ways to approximate their approach effectively, from a strictly numerical standpoint; it's that other factors matter much more in the optimization when you are training on actual data.

reply