Motivation

For a while my knowledge of ML was limited to what I had learned in school: perceptrons, gradient descent, perhaps multiple perceptrons grouped into layers. Looking at the ML landscape from afar, I couldn’t follow how so many fundamentally new ideas came about. Conference papers are often written in a way that presents the idea, but not the intuition or the impetus for exploring that particular direction. Looking at the attention paper I was quite lost: why do we need all of $Q$, $K$ and $V$? What is their intuitive explanation? Why is this direction being explored at all?

Reading further did not make it simpler, with many new concepts introduced at once. Flash attention seemed like an indecipherable rewrite. Mamba was voodoo magic.

For a long while, I wanted a blogpost explaining the reasoning, the motivation and, perhaps more importantly, some of the mathematical foundations behind the recent advances in large language models, mainly from the architectural point of view. Perhaps it’s trite to write one on the subject, given there are so many, but I think there’s still some insight and intuition to offer. We’ll start with quite a detour, with the goal of building intuition for how new approaches were developed and what problems they are solving.

Starting Point

So what are we even doing? The general problem formulation in ML is usually to learn a function $f: X \to Y$.

Early ML problems started with classification, where $Y$ is a finite set of classes (e.g. $Y = \{0, 1\}$ for binary classification). For a linear $f$, the function can be fully captured by a matrix of weights $W$. From PAC theory (Probably Approximately Correct learning, Haussler et al. 1992) we know that such a matrix can be learned using polynomially many samples, subject to constraints on the sample distribution. To find the weights, we perform gradient descent, and (usually) given enough good data, we get a good classifier. The study of various optimizers is a subfield in itself, but for the purposes of this blogpost we’ll only note that we are doing first-order optimization, since second-order methods need the Hessian, whose size is quadratic in the number of parameters, which makes them intractable for all but the simplest cases.
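As a minimal sketch of the first-order optimization described above, here is full-batch gradient descent for a linear classifier $w^\top x + b$ in plain Python. The logistic loss, the learning rate, the step count and the tiny 2-D dataset are all illustrative assumptions, not anything prescribed by the text.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logit(w, b, x):
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def train_logistic(xs, ys, lr=0.5, steps=200):
    """Full-batch gradient descent on the mean logistic loss of a linear classifier."""
    w, b, n = [0.0] * len(xs[0]), 0.0, len(xs)
    for _ in range(steps):
        grad_w, grad_b = [0.0] * len(w), 0.0
        for x, y in zip(xs, ys):
            err = sigmoid(logit(w, b, x)) - y  # d(loss)/d(logit) for the logistic loss
            grad_w = [g + err * xi / n for g, xi in zip(grad_w, x)]
            grad_b += err / n
        # First-order step: only the gradient is used, no Hessian.
        w = [wi - lr * gi for wi, gi in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def predict(w, b, x):
    return 1 if logit(w, b, x) > 0 else 0

# Hypothetical toy data: two linearly separable classes in 2-D.
xs = [(-1.0, -1.0), (-2.0, -1.0), (1.0, 1.0), (2.0, 1.0)]
ys = [0, 0, 1, 1]
w, b = train_logistic(xs, ys)
print([predict(w, b, x) for x in xs])  # [0, 0, 1, 1]
```

Each step costs one pass over the data and memory linear in the number of parameters, which is what makes first-order methods scale.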

Deep Learning

Most functions are not linear, and we need some way to learn them, ideally still using matrices. Stacking multiple matrix multiplications would not work, as a composition of matmuls is still a matmul, so some non-linearity is needed. But what if we insert a simple non-linear function with few or zero parameters in between linear layers? That gives us the key idea of deep learning, and it turns out that’s enough to represent almost any reasonable function! E.g. from Cybenko’s 1989 “Approximation by Superpositions of a Sigmoidal Function” we know that even a single non-linear hidden layer gives us enough expressivity to approximate any continuous function on a compact domain to arbitrary accuracy. Stacking many such layers puts the “deep” in deep learning and allows us to solve complex problems such as image classification.
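Both claims in the paragraph above can be checked on a toy scale. The sketch below (plain Python, with arbitrary hand-picked weights rather than learned ones) shows that two stacked linear layers equal a single linear layer, and that inserting a ReLU lets a two-unit hidden layer compute XOR, which no single linear threshold unit can represent.

```python
def matvec(m, v):
    """Matrix (list of rows) times vector."""
    return [sum(mij * vj for mij, vj in zip(row, v)) for row in m]

def matmul(a, b):
    """Matrix product of two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(v):
    return [max(0.0, x) for x in v]

# Two stacked linear layers collapse into one: W2 (W1 x) == (W2 W1) x.
W1 = [[1.0, 2.0], [3.0, -1.0]]
W2 = [[0.5, -1.0], [2.0, 1.0]]
x = [1.0, -2.0]
print(matvec(W2, matvec(W1, x)), matvec(matmul(W2, W1), x))  # [-6.5, -1.0] [-6.5, -1.0]

# With a ReLU in between, hand-picked weights compute XOR.
def xor_net(x1, x2):
    h = relu([x1 + x2, x1 + x2 - 1.0])  # hidden layer: two ReLU units
    return h[0] - 2.0 * h[1]            # linear output layer

print([xor_net(a, b) for a, b in [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0)]])
# [0.0, 1.0, 1.0, 0.0]
```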

Handling Sequences

In the previous section we assumed that both the input and the output of our classifier are vectors of a fixed length, but that limits the set of problems we can tackle. In the 2014 paper “Sequence to Sequence Learning with Neural Networks”, Sutskever et al. considered a translation problem: our input is a sequence of words (logits, really: tokens embedded into a $d$-dimensional space, but here we use the terms “tokens” and “logits” interchangeably) of arbitrary length, and the output is also a sequence of words, of a possibly different length. How do we encode it? Their suggestion was to use RNNs (Recurrent Neural Networks) to first encode the input sequence into a hidden state (the encoder), and then use an RNN decoder to unroll the hidden state into the output sequence.

The idea of RNNs (e.g. Rumelhart, Hinton and Williams, 1986, “Learning Internal Representations by Error Propagation”) is simple: the RNN does not just consume an input and return an output, but instead consumes an input and a “hidden state”, and returns an output and a new “hidden state”. We can then compress the input sequence into a hidden state by successively “feeding” the net the inputs one at a time, each together with the hidden state produced by the previous step, forming the encoder:

Seq2Seq Encoder (only final state matters)
==========================================

Input:   x₁  →  x₂   →  x₃  →  x₄
         ↓      ↓       ↓      ↓
Hidden: [h₁] → [h₂]  → [h₃] → [h₄] ←  context vector
         ✗      ✗       ✗  
      (discard intermediate outputs)
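A minimal sketch of the encoder in the diagram, assuming a scalar hidden state and arbitrary fixed weights (`w_h`, `w_x`, `b` are illustrative, not learned): the loop consumes one input per step and only the final state survives as the context vector.

```python
import math

def rnn_step(h, x, w_h=0.5, w_x=1.0, b=0.0):
    """One step of a vanilla RNN cell with a scalar hidden state."""
    return math.tanh(w_h * h + w_x * x + b)

def encode(xs):
    """Seq2seq-style encoder: feed inputs one at a time, return only the final state."""
    h = 0.0                  # initial hidden state
    for x in xs:
        h = rnn_step(h, x)   # each intermediate state is overwritten, i.e. discarded
    return h                 # the "context vector"

context = encode([0.1, -0.3, 0.7, 0.2])
```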

Similarly, by repeatedly giving the “decoder” net the hidden state and its own previous output as input, we can “decode” the hidden state into an output sequence (known as “autoregressive decoding”):

Seq2Seq Decoder (generates all outputs)
========================================

Context:  h₄ (from encoder)
          ↓
         [d₁] → [d₂] → [d₃] → [d₄]
          ↓      ↓      ↓      ↓
Output:   y₁     y₂     y₃     y₄  ← generated sequence
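The decoder in the diagram can be sketched the same way. This is a toy with a scalar state and made-up weights: real decoders emit a distribution over tokens and feed back the chosen token's embedding, whereas here the scalar output itself is fed back.

```python
import math

def decode(context, steps, w_h=0.9, w_y=0.5, w_out=1.0):
    """Autoregressive decoding with a scalar hidden state (toy weights):
    each step consumes the previous hidden state and the model's own previous output."""
    h, y = context, 0.0  # start from the encoder's context; 0.0 stands in for a start token
    outputs = []
    for _ in range(steps):
        h = math.tanh(w_h * h + w_y * y)  # new state from (previous state, previous output)
        y = w_out * h                     # emit an output, fed back at the next step
        outputs.append(y)
    return outputs

print(decode(context=0.5, steps=4))
```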

The method above suffers from two issues:

  1. Scalability - Training is not parallelizable across sequence length, as sequences are processed one token at a time: step $t$ cannot start before step $t-1$ has produced its hidden state.
  2. Forgetting - We are compressing all of the input into a single hidden state, and the compression mechanism is not perfect. Over time, new memories displace old ones, and in the case of simple RNNs that happens very quickly.
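Both issues show up even in a scalar toy RNN (weights chosen arbitrarily for illustration). The loop cannot be parallelized because each step needs the previous state, and two sequences that differ only in their first element end up with nearly identical final states after 20 further steps.

```python
import math

def encode(xs, w_h=0.5, w_x=1.0):
    """Scalar vanilla RNN encoder: h_t = tanh(w_h * h_{t-1} + w_x * x_t), h_0 = 0."""
    h = 0.0
    for x in xs:  # inherently sequential: step t needs h_{t-1} first
        h = math.tanh(w_h * h + w_x * x)
    return h

# Two inputs that differ only in their first element, followed by 20 zeros.
first_pos = encode([5.0] + [0.0] * 20)
first_neg = encode([-5.0] + [0.0] * 20)
print(first_pos, first_neg)  # both very close to 0: the first element has been "forgotten"
```

With `w_h = 0.5` the state roughly halves at every zero input, so the memory of the first token decays geometrically.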

Long Short Term Memory

LSTMs (Long Short-Term Memory, Hochreiter et al., 1997) were initially proposed to solve the vanishing gradient problem in vanilla RNNs, but they additionally alleviate the forgetting problem by being more selective about the information saved, forgotten, and returned at each step. LSTMs add input, forget and output gates to regulate what is added from the input, what is removed from the hidden state, and what is passed to the output, respectively. These gates are computed from the current input $x_t$ and the previous hidden state $h_{t-1}$ using learned weights and biases:

$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f)$$

where $\sigma$ is the sigmoid, and “f” stands for “forget” (switching the subscript $f$ to $i$ and $o$ yields the formulas for the input and output gates respectively).
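A minimal sketch of one step of an LSTM cell with a scalar state, following the common formulation with a forget gate and a separate cell state $c_t$. The parameter dictionary `p` and its key names are hypothetical, and the values below are hand-set for illustration: all weights are zero except a large forget-gate bias, so the cell keeps most of its memory over 20 steps.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, p):
    """One step of a scalar LSTM cell. `p` is a hypothetical dict of learned scalars
    with keys like "W_f", "U_f", "b_f" (and likewise for gates i, o and candidate g)."""
    def gate(name, act):
        return act(p["W_" + name] * x + p["U_" + name] * h_prev + p["b_" + name])
    f = gate("f", sigmoid)    # forget gate: how much of the old cell state to keep
    i = gate("i", sigmoid)    # input gate: how much of the new candidate to write
    o = gate("o", sigmoid)    # output gate: how much of the cell state to expose
    g = gate("g", math.tanh)  # candidate values to write
    c = f * c_prev + i * g    # new cell state
    h = o * math.tanh(c)      # new hidden state / output
    return h, c

p = {f"{m}_{n}": 0.0 for m in ("W", "U", "b") for n in "fiog"}
p["b_f"] = 5.0
h, c = 0.0, 1.0
for _ in range(20):
    h, c = lstm_step(0.0, h, c, p)
print(c)  # about 0.87 (= sigmoid(5) ** 20): most of the memory is still there
```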

Attention

The attention operator bypasses the scalability issue of RNNs by processing all input tokens in the sequence at the same time, without serially encoding them into a hidden state, at the cost of computation time quadratic in the sequence length. The idea of “attention” was introduced in a 2014 paper by Bahdanau et al., and then “Attention Is All You Need” (Vaswani et al., 2017) built an architecture out of attention alone, dropping recurrence entirely. Here, however, we would like to give a different explanation, largely following the ideas of Tsai et al., 2019 and Bactra’s Notebook, which represent attention as a kernel smoother.

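The formulation of self-attention is well known by now: we have a sequence of input logits $x_1, \dots, x_n$, where $n$ is the sequence length, and each $x_i \in \mathbb{R}^d$, where $d$ is the embedding dimension; stacking them as rows gives a matrix $X \in \mathbb{R}^{n \times d}$. Using learned matrices $W_Q, W_K \in \mathbb{R}^{d \times d_k}$ and $W_V \in \mathbb{R}^{d \times d_v}$ (intuitively denoting query, key and value) we obtain matrices $Q$, $K$ and $V$ by calculating $Q = X W_Q$, $K = X W_K$ and $V = X W_V$ respectively. And of course the attention is given by

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

which allows each token to attend to all the other tokens (or, with a causal mask, all the previous ones) in parallel without lossy compression into a hidden state, fixing the scalability issues of previous RNN-based architectures. To visualize what’s going on in attention, we can start with the term $Q K^\top$.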

The result of this matrix multiplication is an $n \times n$ matrix whose entry $(i, j)$ is the inner product of query $i$ with key $j$: row $i$ scores how strongly token $i$ attends to every token in the sequence. Computing it therefore takes a number of operations quadratic in the length of the input sequence.

We then scale by $1/\sqrt{d_k}$ (the key dimension) and take a row-wise softmax over this matrix, so that each row becomes a probability distribution summing to one. Finally, multiplying by $V$ replaces each row with a weighted average of the value vectors, giving one output vector per token (again a stacked matrix of logits, one row per token).
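The three steps above (scores, row-wise softmax, weighted average of values) fit in a few lines of plain Python. This is an unmasked, single-sequence sketch; the toy input and the identity projections are illustrative assumptions standing in for learned $W_Q$, $W_K$, $W_V$.

```python
import math

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X, W_q, W_k, W_v):
    """softmax(Q K^T / sqrt(d_k)) V for one unmasked sequence X (n x d)."""
    Q, K, V = matmul(X, W_q), matmul(X, W_k), matmul(X, W_v)
    d_k = len(W_k[0])
    scores = matmul(Q, transpose(K))  # n x n: every query against every key
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V), weights  # n x d_v outputs, n x n attention weights

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # n = 3 tokens, d = 2 (toy values)
I = [[1.0, 0.0], [0.0, 1.0]]              # identity projections, for illustration only
out, weights = self_attention(X, I, I, I)
```

Since every output row is a convex combination of value rows, each output coordinate stays within the range of the corresponding value column.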

Kernel Smoothing Formulation

Now let’s take a bit of a detour to see the intuition behind this calculation. Consider the problem of “smoothing” a function $f$: as in the picture below, we have a set of points $(x_i, y_i)$, $i = 1, \dots, n$, for which the function output $y_i = f(x_i)$ is known, and we are trying to estimate $f(x)$ for a new point $x$.

A reasonable guess could be to take the nearest point and reuse its y-coordinate. That’s not very generalizable, so we could take the $k$ nearest neighbors and average their y-coordinates instead. That’s still not very general: what if we take all the points we have observed so far and weigh the contribution of each data point $x_i$ by its x-coordinate distance to $x$ (or, to paraphrase, by how much $x$ attends to $x_i$)? The 1964 results of Nadaraya and Watson propose the following kernel estimator, which estimates the value of a function at an arbitrary point given the values observed so far:

$$\hat{f}(x) = \sum_{i=1}^{n} \frac{\kappa(x, x_i)}{\sum_{j=1}^{n} \kappa(x, x_j)}\, y_i$$

where $\kappa$ is a “kernel function” (usually normalized to integrate to 1, although any constant factor cancels in the ratio above). Now we can see how the attention formula re-emerges: in the attention formulation, $\kappa(q, k_i) = \exp\left(q^\top k_i / \sqrt{d_k}\right)$ is the exponential of the (scaled) inner product of the query vector $q$ and the pre-existing keys $k_i$, the normalized weights are exactly the softmax, and the $y_i$ are the values $v_i$ the pre-existing keys are mapped to (of course, their actual values have to be learned).
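A minimal sketch of the Nadaraya-Watson estimator, first with a Gaussian kernel on toy 1-D data, then with the kernel $\exp(q \cdot k)$, where it computes the same thing as a softmax-weighted average of values. The data, the bandwidth and the query are made up; the $1/\sqrt{d_k}$ scaling is left out, as it could be folded into $q$.

```python
import math

def nadaraya_watson(x, xs, ys, kernel):
    """Kernel smoother: a weighted average of observed ys, weights = normalized kernel values."""
    weights = [kernel(x, xi) for xi in xs]
    total = sum(weights)
    return sum(w * y for w, y in zip(weights, ys)) / total

def gaussian(x, xi, bandwidth=0.5):
    """Gaussian kernel on scalars (its normalizing constant would cancel anyway)."""
    return math.exp(-((x - xi) ** 2) / (2 * bandwidth ** 2))

# 1-D smoothing: estimate f(1.5) from samples of f(x) = x^2.
estimate = nadaraya_watson(1.5, [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 4.0, 9.0], gaussian)

# Attention as a special case: kernel(q, k) = exp(q . k), keys play the role of x_i,
# values the role of y_i.
def exp_dot(q, k):
    return math.exp(sum(a * b for a, b in zip(q, k)))

q = [0.2, 1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [10.0, 20.0, 30.0]
smoothed = nadaraya_watson(q, keys, values, exp_dot)

scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
softmax_weights = [math.exp(s) / sum(math.exp(t) for t in scores) for s in scores]
attended = sum(w * v for w, v in zip(softmax_weights, values))
print(estimate, smoothed, attended)  # smoothed and attended agree up to rounding
```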

Multi-Head Attention

Attention computation can be seen as “enriching” the $d$-dimensional representation of each token with contributions from every other token, so that we can predict the next token from that representation alone. But a single attention computation might not capture all of that information well: one softmax-weighted average mixes tokens in only one way per position, so different relationships between tokens could interfere with one another. To counter that, Vaswani et al. propose multi-head attention: performing $h$ attention computations in parallel, each with its own learned query, key and value matrices (typically projecting to $d/h$ dimensions), then concatenating the results before projecting them to the final output with another learned matrix. In terms of kernel smoothing, it can be seen as running multiple kernel smoothers and taking a learned linear combination of their results.
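A plain-Python sketch of multi-head attention as described above. The weights are hypothetical and hand-set so the behaviour is easy to check: with $d = 4$ and two heads of size 2, head 1 only sees the first half of each embedding and head 2 the second half, and the output projection is the identity, so each half of the output equals single-head attention on that half alone.

```python
import math

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Single-head scaled dot-product attention (no mask)."""
    d_k = len(K[0])
    scores = matmul(Q, [list(col) for col in zip(*K)])
    return matmul([softmax([s / math.sqrt(d_k) for s in row]) for row in scores], V)

def multi_head_attention(X, heads, W_o):
    """`heads` is a list of (W_q, W_k, W_v) triples, one per head. Each head attends
    independently; per-token outputs are concatenated and projected by W_o."""
    outs = [attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v))
            for W_q, W_k, W_v in heads]
    concat = [[v for out in outs for v in out[i]] for i in range(len(X))]
    return matmul(concat, W_o)

P1 = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]  # head 1: first half of the embedding
P2 = [[0.0, 0.0], [0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # head 2: second half
I4 = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
X = [[1.0, 0.0, 0.0, 2.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.5, 0.5]]
Y = multi_head_attention(X, [(P1, P1, P1), (P2, P2, P2)], I4)
```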

Flash Attention

The 2022 result of Tri Dao et al. on rewriting attention as a single efficient CUDA kernel (FlashAttention) is well known, but we think extra intuition could be beneficial. With the context length $n$ usually being very large (tens or even hundreds of thousands of tokens in contemporary models), and the head dimension $d$ being relatively small (it could even be less than a hundred), materializing the entire $n \times n$ matrix in the GPU’s HBM (high-bandwidth memory) quickly becomes the bottleneck: it takes memory quadratic in $n$, and modern GPUs can perform far more flops than they can move bytes to and from HBM.

The usual answer for trading extra compute for less memory traffic is “kernel fusion” (the “kernel” here has nothing to do with the smoothing definition of a “kernel”): we launch a single GPU kernel for the outer matmul, and compute whatever intermediate results it needs on the fly in fast on-chip SRAM, never writing them back to HBM.

Matrix Multiplication Fusion on a GPU

To visualize that, consider first the matmul between the softmax result and $V$, and which regions are necessary to calculate a full tile of the output at coordinates $(i, j)$:

During the kernel execution, the shaded regions would be read into SRAM in tiles the size of the double-shaded region (SRAM has to be larger than all double-shaded regions combined).
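The tiling pattern above can be sketched on the CPU. In this illustrative version (plain Python, with lists standing in for SRAM), each output tile is accumulated by marching along the shared dimension and combining one tile of $A$'s row strip with one tile of $B$'s column strip, which is roughly the access pattern a GPU matmul kernel would use.

```python
def naive_matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def tiled_matmul(A, B, tile=2):
    """C = A @ B, one output tile at a time. For output tile (ti, tj) we march along
    the shared dimension in steps of `tile`, taking one tile of A's row strip and one
    tile of B's column strip (the pieces a GPU kernel would stage in SRAM) and
    accumulating their product into the output tile."""
    n, m, p = len(A), len(B), len(B[0])
    C = [[0.0] * p for _ in range(n)]
    for ti in range(0, n, tile):
        for tj in range(0, p, tile):
            for tk in range(0, m, tile):
                a_tile = [row[tk:tk + tile] for row in A[ti:ti + tile]]
                b_tile = [row[tj:tj + tile] for row in B[tk:tk + tile]]
                for i, a_row in enumerate(a_tile):
                    for j in range(len(b_tile[0])):
                        C[ti + i][tj + j] += sum(a * b_row[j] for a, b_row in zip(a_row, b_tile))
    return C

A = [[float(i * 5 + k) for k in range(5)] for i in range(3)]  # 3 x 5, toy integers
B = [[float(k - j) for j in range(3)] for k in range(5)]      # 5 x 3
print(tiled_matmul(A, B) == naive_matmul(A, B))  # True
```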

Now let’s ignore the softmax for simplicity, and say we would like to fuse the computation to avoid materializing the giant $n \times n$ matrix $Q K^\top$ (since HBM reads are expensive relative to flops on modern accelerators). Let’s see which regions we would need to read:
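Still ignoring the softmax, the fused computation of $(Q K^\top) V$ can be sketched as follows. This is an illustrative CPU version only: for each query it walks over key/value tiles, forms that tile's scores, multiplies them into $V$ immediately and accumulates, so at most `tile` scores are alive at once; a real kernel would also tile over queries and run tiles in parallel.

```python
def fused_qk_v(Q, K, V, tile=2):
    """(Q K^T) V without ever storing the n x n matrix Q K^T (softmax ignored)."""
    d_v = len(V[0])
    out = []
    for q in Q:
        acc = [0.0] * d_v
        for start in range(0, len(K), tile):
            # Scores for this tile of keys only.
            scores = [sum(a * b for a, b in zip(q, k)) for k in K[start:start + tile]]
            for s, v in zip(scores, V[start:start + tile]):
                for c in range(d_v):
                    acc[c] += s * v[c]
        out.append(acc)
    return out

def unfused_qk_v(Q, K, V):
    """Reference: materialize S = Q K^T first, then multiply by V."""
    S = [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]
    return [[sum(S[i][j] * V[j][c] for j in range(len(V))) for c in range(len(V[0]))]
            for i in range(len(Q))]

Q = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0]]
K = [[1.0, 0.0], [2.0, 1.0], [-1.0, 1.0], [0.0, 2.0], [1.0, 1.0]]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [2.0, 2.0, 0.0], [1.0, -1.0, 1.0], [0.0, 0.0, 3.0]]
print(fused_qk_v(Q, K, V) == unfused_qk_v(Q, K, V))  # True
```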

Given that $d$ is reasonably small, the intermediate matrix multiplication can be fused, but how do we deal with the softmax applied to the intermediate matrix? Recall that softmax (applied row-wise to a matrix) is a function which turns a vector into a probability distribution: larger values stay larger, all values become positive, and they sum to $1$. The mathematical formula is:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$

This works for real numbers, but quickly overflows in floating point, so an updated formula is used, which first subtracts the largest number in the sequence from every element (equivalently, divides both the numerator and the denominator by $e^{m}$):

$$\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j} e^{x_j - m}}, \qquad m = \max_j x_j$$
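The overflow problem and its fix are easy to reproduce in plain Python, where `math.exp` raises `OverflowError` for arguments above roughly 709.78. The inputs below are arbitrary.

```python
import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]  # raises OverflowError for x > ~709.78
    total = sum(exps)
    return [e / total for e in exps]

def stable_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # all exponents are <= 0, so nothing overflows
    total = sum(exps)
    return [e / total for e in exps]

print(stable_softmax([1000.0, 1000.0]))  # [0.5, 0.5]
try:
    naive_softmax([1000.0, 1000.0])
except OverflowError:
    print("naive softmax overflowed")
```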

This formula requires knowing the entire row in advance, which makes it difficult to apply in fusion directly. Tri Dao et al. use an online softmax rewrite which doesn’t require materializing the entire matrix: for each tile they exponentiate relative to a running maximum, keep the running maximum, a running normalizer and a running output on the side, and whenever the running maximum grows, rescale the previously accumulated normalizer and output by $e^{m_{\text{old}} - m_{\text{new}}}$. After the last tile, the output is divided by the normalizer. (Using the running maximum rather than waiting for the global one avoids overflow in the intermediate computation.) The online softmax trick actually predates that: it is explicitly described in “Online Normalizer Calculation for Softmax” by Milakov and Gimelshein, 2018, and the idea of an online rewrite goes back further still, to Welford’s 1962 method for computing variance in an online manner.
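A minimal sketch of the online rewrite for a single query row: it computes $\sum_j \mathrm{softmax}(s)_j\, v_j$ in one pass over tiles, keeping only a running maximum, a running normalizer and a running weighted sum. Values are scalars here for brevity (in attention they are vectors, rescaled the same way), and the scores are made up, including a very large one that would overflow a naive implementation.

```python
import math

def online_softmax_weighted_sum(scores, values, tile=2):
    """sum_j softmax(scores)_j * values[j], computed in one pass over tiles."""
    m = -math.inf  # running maximum of the scores seen so far
    l = 0.0        # running normalizer: sum of exp(s - m)
    acc = 0.0      # running numerator: sum of exp(s - m) * v
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)  # rescale old partial sums to the new maximum
        l = l * scale + sum(math.exp(s - m_new) for s in s_tile)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_tile, v_tile))
        m = m_new
    return acc / l  # divide by the normalizer only once, at the end

scores = [1.0, 3.0, 2.0, 1000.0, 999.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(online_softmax_weighted_sum(scores, values))
```

On the first tile `m` is `-inf`, so `scale` is `exp(-inf) = 0.0` and the (empty) previous sums are simply discarded.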

Mamba and State Space Models

In “Efficiently Modeling Long Sequences with Structured State Spaces” (ICLR 2022), Gu et al. represent the state as a recurrence relation motivated by linear ODEs (a continuous-time system $h'(t) = A h(t) + B x(t)$, discretized with a step size):

$$h_t = \bar{A} h_{t-1} + \bar{B} x_t, \qquad y_t = C h_t$$

where $x_t$ is input number $t$, and $y_t$ and $h_t$ are the corresponding output and hidden state respectively. This is an RNN, but with a crucial difference: the updates are now fully linear. Recall that the principal RNN limitation attention set out to solve was the lack of parallelizability: training required applying updates one by one, as earlier RNNs (e.g. Elman, 1990) applied a non-linear activation function $\sigma$:

$$h_t = \sigma(W h_{t-1} + U x_t + b)$$

The absence of a non-linear activation function makes the sequence of recursive updates linear with respect to the hidden state, and, starting from $h_0 = 0$, the hidden states and outputs unroll to

$$h_t = \sum_{k=0}^{t-1} \bar{A}^{k} \bar{B}\, x_{t-k}, \qquad y_t = \sum_{k=0}^{t-1} C \bar{A}^{k} \bar{B}\, x_{t-k}$$

i.e. a convolution of the input with the kernel $\bar{K} = (C\bar{B}, C\bar{A}\bar{B}, C\bar{A}^2\bar{B}, \dots)$, which can be computed in parallel using the FFT. The idea is further developed in “Mamba: Linear-Time Sequence Modeling with Selective State Spaces” (Gu and Dao, 2023), where $\bar{A}$, $\bar{B}$ and $C$ are no longer fixed but are themselves computed from the input, and therefore change from step to step. In that case the FFT rewrite no longer applies, but the recurrence is still linear, and composing linear updates is associative, which can still be exploited to perform the computation in parallel using a parallel scan.
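A scalar sketch of the three views above, with made-up coefficients (real S4/Mamba layers use vector states, a learned discretization, and a hardware-aware scan). For fixed coefficients, the step-by-step recurrence and the unrolled convolution give the same outputs; for time-varying coefficients, composing the affine updates $h \mapsto a_t h + u_t$ is associative, so a divide-and-conquer scan reproduces the sequential loop. This is a simple $O(n \log n)$-work scan, not the work-efficient Blelloch variant.

```python
def ssm_recurrent(xs, a, b, c):
    """Scalar linear SSM, step by step: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t, h_0 = 0."""
    h, ys = 0.0, []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys

def ssm_convolution(xs, a, b, c):
    """Same outputs via the unrolled form y_t = sum_k (c * a**k * b) * x_{t-k}.
    Computed directly in O(n^2) here; an FFT computes this convolution in O(n log n)."""
    kernel = [c * a ** k * b for k in range(len(xs))]
    return [sum(kernel[k] * xs[t - k] for k in range(t + 1)) for t in range(len(xs))]

def combine(first, second):
    """Compose two affine updates h -> a*h + u (first, then second). Associative."""
    a1, u1 = first
    a2, u2 = second
    return (a2 * a1, a2 * u1 + u2)

def prefix_scan(elems):
    """Inclusive scan by divide and conquer: the two recursive calls are independent,
    so a parallel implementation could run them at the same time."""
    if len(elems) == 1:
        return list(elems)
    mid = len(elems) // 2
    left, right = prefix_scan(elems[:mid]), prefix_scan(elems[mid:])
    return left + [combine(left[-1], r) for r in right]

# Selective (input-dependent) case: coefficients change at every step,
# h_t = a_t * h_{t-1} + u_t, so there is no single convolution kernel any more.
a_t = [0.9, 0.5, 1.1, 0.3, 0.7]
u_t = [1.0, -2.0, 0.5, 4.0, -1.0]
h_scan = [u for _, u in prefix_scan(list(zip(a_t, u_t)))]  # with h_0 = 0, h_t is the offset
```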

Diffusion

Diffusion models (Sohl-Dickstein et al., 2015) generate an entire image at once, by learning to reverse a gradual noising process, optionally conditioned on a given hidden state. For this blogpost we wanted to highlight the symmetry between how input is consumed and how output is produced: on the input side, one can consume a sequence token by token with an RNN encoder or attend to the entire input at once; on the output side, one can generate token by token with autoregressive decoding or generate the entire output at once with diffusion.

Conclusion

Attention allows parallel training and represents the whole computation history, but its quadratic cost imposes a fixed limit on sequence length. As the cost of computation grows quadratically with sequence length, full self-attention becomes intractable for large contexts: this can be seen as computation/reasoning existing only within a fixed window of input/output tokens and disappearing afterwards.

Many RNN-like approaches (linear attention from “Transformers are RNNs”, Katharopoulos et al., 2020, Mamba, and others) were proposed to overcome the quadratic behavior while preserving the parallelizability of training, usually relying on an optimized parallel associative scan. Ultimately they work by condensing the previous information into a “hidden state” passed on to the next computation.
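To make the “condensing into a hidden state” point concrete, here is a sketch of causal linear attention in the spirit of Katharopoulos et al. (2020), using their $\mathrm{elu}(x) + 1$ feature map; the toy inputs are made up. The same outputs are computed twice: once in the attention view, over all pairs of positions, and once in the RNN view, with a fixed-size state $S = \sum_j \phi(k_j) v_j^\top$, $z = \sum_j \phi(k_j)$ updated once per token.

```python
import math

def phi(v):
    """Positive feature map elu(x) + 1."""
    return [x + 1.0 if x > 0 else math.exp(x) for x in v]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def causal_linear_attention_quadratic(Q, K, V):
    """Attention view: output i averages v_0..v_i with weights phi(q_i) . phi(k_j),
    touching O(n^2) pairs."""
    out = []
    for i, q in enumerate(Q):
        w = [dot(phi(q), phi(K[j])) for j in range(i + 1)]
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) / sum(w) for c in range(len(V[0]))])
    return out

def causal_linear_attention_recurrent(Q, K, V):
    """RNN view: the same outputs from a fixed-size state (S, z), updated per token."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    z = [0.0] * d_k
    out = []
    for q, k, v in zip(Q, K, V):
        fk, fq = phi(k), phi(q)
        for a in range(d_k):  # constant-cost state update, independent of sequence length
            z[a] += fk[a]
            for c in range(d_v):
                S[a][c] += fk[a] * v[c]
        denom = dot(fq, z)
        out.append([sum(fq[a] * S[a][c] for a in range(d_k)) / denom for c in range(d_v)])
    return out

Q = [[0.5, -1.0], [1.0, 0.2], [-0.3, 0.4]]
K = [[0.1, 0.7], [-0.5, 1.0], [0.9, -0.2]]
V = [[1.0, 2.0], [3.0, -1.0], [0.0, 0.5]]
print(causal_linear_attention_recurrent(Q, K, V))
```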

We hope this exposition was useful for building some intuition behind methods which may seem magical. Three ideas were explored: attention as kernel smoothing, flash attention as matrix multiplication fusion, and the Mamba architecture as an RNN. In addition, we highlighted the symmetry between consuming the input (RNN encoder vs. attention) and producing the output (autoregressive decoding vs. diffusion).

References

Acknowledgements

Thank you to Mikhail Trofimov and Julia Proskurnia for giving feedback on earlier versions of this post.