Self-Attention
How it works
For an input sequence X, three projections are computed: Q = X*W_Q, K = X*W_K, V = X*W_V. The output is Attention(Q,K,V) = softmax(QK^T / sqrt(d_k)) * V. Dividing by sqrt(d_k) keeps the dot products from growing too large (see Implementation below). In Multi-Head Attention, the same computation runs in parallel across h independent heads; the per-head results are concatenated and passed through an output projection W_O.
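A minimal NumPy sketch of the single-head formula above (the sizes n, d_model, d_k are illustrative, not taken from any particular model):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention.
    X: (n, d_model); W_Q, W_K: (d_model, d_k); W_V: (d_model, d_v)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)       # (n, n) scaled attention logits
    return softmax(scores, axis=-1) @ V   # (n, d_v) weighted sum of values

# Toy usage: 5 tokens, d_model=16, d_k=d_v=8
rng = np.random.default_rng(0)
n, d_model, d_k = 5, 16, 8
X = rng.normal(size=(n, d_model))
W_Q, W_K, W_V = (rng.normal(size=(d_model, d_k)) for _ in range(3))
out = self_attention(X, W_Q, W_K, W_V)    # shape (5, 8)
```

A multi-head version would apply `self_attention` h times with independent weight triples and concatenate the h outputs along the feature axis before the W_O projection.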
Problem solved
Recurrent neural networks (RNNs, LSTMs) process sequences step by step, which makes long-range dependencies hard to model and prevents training from being fully parallelized. Self-attention instead connects every pair of positions directly in a constant number of sequential operations, so the whole sequence can be processed in parallel.
Implementation
The n x n attention matrix requires O(n^2) memory, which becomes prohibitive for sequences beyond ~4k tokens unless the full matrix is avoided, either through sub-quadratic approximations or through exact tiled algorithms like FlashAttention.
Self-attention is permutation-equivariant; without explicit positional encoding, token order is invisible to the model.
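The standard remedy from the original Transformer is to add fixed sinusoidal positional encodings to the token embeddings before the first attention layer. A sketch, where `X_embedded` is a hypothetical (n, d_model) embedding matrix and d_model is assumed even:

```python
import numpy as np

def sinusoidal_positional_encoding(n, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    pos = np.arange(n)[:, None]               # (n, 1) token positions
    i = np.arange(0, d_model, 2)[None, :]     # (1, d_model/2) even dimensions
    angles = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((n, d_model))
    pe[:, 0::2] = np.sin(angles)              # even dimensions get sin
    pe[:, 1::2] = np.cos(angles)              # odd dimensions get cos
    return pe

# Inject order before attention:
# X_input = X_embedded + sinusoidal_positional_encoding(n, d_model)
```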
Without dividing by sqrt(d_k), dot products grow with dimension: for query/key components with unit variance, the dot product has variance d_k, so logit magnitudes scale like sqrt(d_k). Large logits saturate the softmax into a near one-hot distribution, where gradients are close to zero.
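A quick check of both the variance claim and the saturation it causes (dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 512
# Dot products of unit-variance vectors have variance d_k ...
q = rng.normal(size=(1000, d_k))
k = rng.normal(size=(1000, d_k))
dots = (q * k).sum(axis=1)
print(dots.std())                   # ~sqrt(512) = 22.6
print((dots / np.sqrt(d_k)).std())  # ~1.0 after scaling

# ... and logits at that scale make softmax effectively one-hot:
logits = rng.normal(size=8) * np.sqrt(d_k)  # typical unscaled magnitudes
p = np.exp(logits - logits.max())
p /= p.sum()
print(p.max())                      # ~1.0: saturated, gradients near zero
```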
Evolution
Bahdanau et al. (2014) introduced soft attention for neural machine translation, a precursor to self-attention.
Vaswani et al. (2017) proposed self-attention as the sole sequence-modeling mechanism in "Attention Is All You Need", replacing recurrence entirely.
Multiple works (e.g., Linformer, Performer, Longformer) proposed sub-quadratic approximations to full self-attention for long sequences.
Dao et al. (2022) introduced FlashAttention, achieving a 2-4x speedup via tiled, IO-aware computation without any approximation.
Technical details
Computational complexity
Time complexity: O(n^2 * d). Space complexity: O(n^2 + n*d).
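Back-of-envelope numbers for the quadratic term (sizes are illustrative):

```python
# Memory for the n x n attention score matrix in fp16
n, heads, bytes_per_el = 4096, 16, 2
per_head = n * n * bytes_per_el          # 32 MiB for one head
per_layer = heads * per_head             # 512 MiB for one layer's scores
print(per_head / 2**20, per_layer / 2**20)
# Doubling n to 8192 quadruples both numbers (O(n^2) scaling).
```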
Compute bottleneck
Computing and storing the n x n attention matrix is the primary bottleneck for long sequences.
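The full matrix never has to be materialized at once, though. Below is a didactic NumPy sketch of the online-softmax idea behind tiled kernels such as FlashAttention, processing one query row against K/V in blocks; the real implementation also tiles queries and keeps blocks in fast GPU SRAM:

```python
import numpy as np

def attention_row_streaming(q, K, V, block=128):
    """Attention output for one query q, never storing all n logits."""
    d_k = q.shape[0]
    m = -np.inf                     # running max of logits seen so far
    l = 0.0                         # running softmax denominator
    acc = np.zeros(V.shape[1])      # running weighted sum of value rows
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d_k)  # one block of logits
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)   # rescale old accumulators to the new max
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V[start:start + block]
        m = m_new
    return acc / l

# Sanity check against the dense computation
rng = np.random.default_rng(0)
n, d = 1000, 64
q, K, V = rng.normal(size=d), rng.normal(size=(n, d)), rng.normal(size=(n, d))
s = K @ q / np.sqrt(d)
w = np.exp(s - s.max()); w /= w.sum()
assert np.allclose(attention_row_streaming(q, K, V), w @ V)
```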
Execution paradigm
Fully parallel: all token positions are processed simultaneously, in contrast to the sequential steps of recurrent models.
Hardware requirements
Matrix multiplications for Q, K, V projections and attention score computation are highly optimized on GPU tensor cores.
TPUs are likewise optimized for the large matrix multiplications that dominate attention.
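For illustration, a short PyTorch (2.0+) snippet; `scaled_dot_product_attention` is a fused API that dispatches to an efficient backend (including FlashAttention-style kernels) when the hardware and dtype allow it. Shapes and sizes here are arbitrary:

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU lets the matmuls run on tensor cores
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, sequence length, head dim)
q = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v)  # (1, 8, 1024, 64)
```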