Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

September 2023

tl;dr: Replacing the softmax with a kernel feature map lets the attention matmuls be reassociated, reducing attention complexity from quadratic to linear with respect to sequence length.

Overall impression

Pain points of transformers: large memory footprint and heavy computation. In practice the computation itself is not that heavy; generation is slow mainly due to memory bandwidth. The large memory footprint raises the hardware barrier for serving, and the memory-bound access pattern slows down token generation.

Sparse attention or locality-sensitive hashing (LSH, as in Reformer) can reduce complexity from O(N^2) to O(N*sqrt(N)) or O(NlogN), but they do not speed up autoregressive inference.

When we evaluate a new transformer architecture or mechanism, we need to see how it improves training or inference performance. For linear transformers: linear complexity in computation and memory during training, and linear computation (constant per token) with constant memory during inference.

The order of matmul matters a lot! The softmax hinders applying the associative property of matrix multiplication. The linear transformer removes the need for the softmax and replaces it with an element-wise feature map (elu(x) + 1). --> This is the biggest contribution of this paper.
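A minimal NumPy sketch (my own, not the authors' code) of why the matmul order matters: with the positive feature map phi(x) = elu(x) + 1 in place of the softmax, computing phi(K)^T V first keeps everything at d x d and never materializes the N x N similarity matrix, yet produces exactly the same output.

```python
import numpy as np

def phi(x):
    # elu(x) + 1: positive element-wise feature map replacing exp(q.k)
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
N, d = 6, 4  # toy sequence length and head dimension
Q, K, V = rng.normal(size=(3, N, d))

# Quadratic order: materialize the N x N similarity matrix first.
A = phi(Q) @ phi(K).T                                # (N, N)
out_quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: associate phi(K)^T V first, a d x d matrix; the N x N
# matrix is never formed, so cost is O(N d^2) instead of O(N^2 d).
S = phi(K).T @ V                                     # (d, d)
z = phi(K).sum(axis=0)                               # (d,)
out_linear = (phi(Q) @ S) / (phi(Q) @ z)[:, None]

assert np.allclose(out_quadratic, out_linear)
```

Note that this is an exact reassociation given the feature map, not a numerical approximation; the softmax version has no such factorization because the row-wise exp/normalize couples all N^2 entries.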

The self-attention is expressed as a linear dot-product of kernel feature maps. Then the associativity of matrix products is used to reduce the complexity from quadratic to linear in sequence length. This yields up to a 4000x speedup on autoregressive prediction of images.
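The "Transformers are RNNs" framing can be sketched as follows (a toy illustration under my own naming, not the paper's code): for causal attention, the terms phi(k_j) v_j^T and phi(k_j) accumulate into a fixed-size state (S, z), so each autoregressive step is an O(d^2) state update regardless of how many tokens came before.

```python
import numpy as np

def phi(x):
    # elu(x) + 1: positive element-wise feature map (the paper's choice)
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(1)
d, steps = 4, 8

# Constant-size recurrent state; its size depends only on d,
# not on how many tokens have been generated so far.
S = np.zeros((d, d))   # running sum of outer(phi(k_j), v_j)
z = np.zeros(d)        # running sum of phi(k_j)

history, outputs = [], []
for _ in range(steps):
    q, k, v = rng.normal(size=(3, d))
    S += np.outer(phi(k), v)       # O(d^2) update per token
    z += phi(k)
    outputs.append((phi(q) @ S) / (phi(q) @ z))
    history.append((q, k, v))

# Sanity check: the last output matches full causal attention over history.
K = np.array([h[1] for h in history])
V = np.array([h[2] for h in history])
A = phi(history[-1][0]) @ phi(K).T
assert np.allclose(outputs[-1], (A @ V) / A.sum())
```

This is why inference memory is constant: the KV cache of a standard transformer grows with sequence length, while here the whole past is compressed into (S, z).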

Key ideas

Technical details