if you are also curious what a sparse attention mask is and how it works:
the problem with attention is long running context. every token needs to attend to every other token, however unrelated they are, so compute and memory grow quadratically with sequence length.
the obvious solution to this is that you mask out some of the tokens, just like the causal mask hides tokens after the current one in a forward pass.
which tokens do you mask though?
there are several ways to do it. the obvious one is to keep a sliding window of, say, 256 tokens: token 256 attends to tokens 1 through 256, token 257 attends to tokens 2 through 257, and so on.
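the sliding window above can be sketched as a boolean mask. this is a minimal illustration (the sequence length of 8 and window of 4 are made-up numbers, not from the text):

```python
import numpy as np

def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    """True where query position i is allowed to attend to key position j."""
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    causal = j <= i                  # no peeking at future tokens
    recent = j > i - window          # only the last `window` tokens (incl. self)
    return causal & recent

mask = sliding_window_mask(seq_len=8, window=4)
# row i of `mask` is True only for the 4 most recent positions up to i
```

in a real model this mask would be added (as -inf on the False entries) to the attention scores before the softmax.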
this produces an obvious problem.
there is no way of keeping global context. the model only knows about the other tokens inside its sliding window. that's exactly why some tokens are marked as global tokens: they take on the responsibility of attending to all tokens, ignoring the sliding window.
these global tokens are placed at regular intervals, e.g. positions 0, 512, 1024, ...
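combining the sliding window with periodic global tokens gives a mask like the sketch below. this assumes a causal setup, with illustrative sizes (16 tokens, window 4, a global token every 8 positions):

```python
import numpy as np

def sparse_mask(seq_len: int, window: int, global_every: int) -> np.ndarray:
    """Sliding-window causal mask plus periodic global tokens."""
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    causal = j <= i
    local = causal & (j > i - window)          # the sliding window
    is_global = np.arange(seq_len) % global_every == 0
    to_global = is_global[None, :] & causal    # every token sees earlier global tokens
    from_global = is_global[:, None] & causal  # global tokens see all earlier tokens
    return local | to_global | from_global

mask = sparse_mask(seq_len=16, window=4, global_every=8)
```

each row stays mostly sparse; only the global rows and columns are dense, which is exactly the overload problem described below.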
this gives distributed long-range connectivity but there are obvious pitfalls to this:
1. the global tokens are overloaded: a handful of positions must summarize the whole sequence
2. it's still an expensive operation: every token attends to all global tokens, and every global token attends to all tokens
3. the global tokens are chosen purely by position and hold no semantic meaning
so then what's the solution?
replace global tokens with many global semantic anchors (R₁…R_M)
these are not tokens. they are trainable vectors (like 64–256 of them) of the same dimension as tokens.
tokens now attend to these anchors instead of to global tokens. the anchors can learn semantic roles such as summary structure or paragraph coherence.
so then every token chooses its own top-K anchors, decided by the score
score(i, j) = Qᵢ · Rⱼ
where Rⱼ is the anchor embedding, which is trained jointly with the Q, K, V projections during pretraining.
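the top-K anchor selection can be sketched like this. the shapes (8 tokens, 6 anchors, dimension 16, K=2) and the random init are purely illustrative; in a real model R would be a learned parameter:

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_tokens, n_anchors, top_k = 16, 8, 6, 2

Q = rng.standard_normal((n_tokens, d))   # token queries Q_i
R = rng.standard_normal((n_anchors, d))  # trainable anchor vectors R_1..R_M

scores = Q @ R.T                         # score(i, j) = Q_i . R_j

# each token keeps only its top-K highest-scoring anchors
topk_idx = np.argsort(scores, axis=1)[:, -top_k:]
mask = np.zeros_like(scores, dtype=bool)
np.put_along_axis(mask, topk_idx, True, axis=1)

# softmax over the selected anchors only (-inf kills the rest)
masked = np.where(mask, scores, -np.inf)
weights = np.exp(masked - masked.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
```

each token now mixes information from just K anchors instead of attending to every global token, which is where the compute saving comes from.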
this reduces attention compute significantly with very little drop in quality.
super cool, right?