We’ve designed a new architecture that replaces the hidden state of an RNN with a machine learning model. This model compresses context through actual gradient descent on the input tokens. We call our method “Test-Time Training (TTT) layers.” TTT layers directly replace attention and unlock linear-complexity architectures with expressive memory, allowing us to train LLMs with millions (someday billions) of tokens in context. Our instantiations, TTT-Linear and TTT-MLP, both match or beat the strongest Transformers and Mamba.
arXiv: https://arxiv.org/abs/2407.04620
---
Sequence models store historical context in a hidden state. RNN layers, like Mamba, compress that context into a state of fixed size across time. They’re efficient, but their performance is limited by how expressive that fixed-size state can be. Attention instead keeps a KV cache that grows over time: it doesn’t compress any historical context, but it becomes costly as context length increases. Why not compress context into the weights of a model, just like LLMs do with internet data? This “hidden state model” still has a fixed size over time, but far more expressivity.
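For a rough sense of the memory difference, here is a back-of-the-envelope sketch for a single head, ignoring layers, heads, and batch size. The head dimension of 64 and the function names are hypothetical; the only point is that a KV cache grows linearly with the number of tokens T, while a TTT-Linear hidden state (one d × d weight matrix) stays the same size.

```python
def kv_cache_floats(num_tokens: int, head_dim: int) -> int:
    """Attention stores one key and one value vector for every past token."""
    return 2 * num_tokens * head_dim


def ttt_linear_state_floats(head_dim: int) -> int:
    """A TTT-Linear hidden state is one head_dim x head_dim weight matrix W,
    no matter how many tokens it has been trained on."""
    return head_dim * head_dim


if __name__ == "__main__":
    d = 64  # hypothetical head dimension, chosen only for illustration
    for t in (1_000, 100_000, 1_000_000):
        print(f"T={t:>9,}: KV cache {kv_cache_floats(t, d):>13,} floats, "
              f"TTT-Linear state {ttt_linear_state_floats(d):,} floats")
```

With these assumed sizes the two are equal at T = 32 (2 · 32 · 64 = 64 · 64 = 4,096), and the cache is larger for every longer context.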
We use self-supervised learning to update the hidden state weights, taking one gradient descent step on each token. After a forward pass over a sequence, the state has been “trained” on the tokens in its context window.
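A minimal sketch of that per-token update for the linear case, in plain Python. It assumes the hidden state model is f(k) = W k, that the self-supervised loss is the reconstruction loss ||W (θ_K x) − θ_V x||², and that each token's output is read out with the freshly updated W. The function names and the fixed step size `lr` are hypothetical and are not the API of ttt-lm-jax or ttt-lm-pytorch.

```python
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def ttt_linear_forward(
    xs: List[Vector],
    theta_k: Matrix,
    theta_v: Matrix,
    theta_q: Matrix,
    lr: float = 0.1,
) -> Tuple[List[Vector], List[float]]:
    """Run a simplified TTT-Linear layer over a sequence of tokens.

    Assumes all three projections map to the same dimension d, so W is d x d.
    Per token: one gradient step on loss(W) = ||W (theta_k x) - theta_v x||^2,
    then emit z = W (theta_q x) using the updated W.
    Returns the outputs and the loss at each token before its update.
    """
    d = len(theta_k)
    w = [[0.0] * d for _ in range(d)]  # hidden state starts at W_0 = 0
    outputs, losses = [], []
    for x in xs:
        k, v, q = matvec(theta_k, x), matvec(theta_v, x), matvec(theta_q, x)
        err = [wk - vi for wk, vi in zip(matvec(w, k), v)]  # W k - v
        losses.append(sum(e * e for e in err))
        # d/dW ||W k - v||^2 = 2 (W k - v) k^T; take one step of size lr.
        w = [[w[i][j] - lr * 2.0 * err[i] * k[j] for j in range(d)]
             for i in range(d)]
        outputs.append(matvec(w, q))
    return outputs, losses


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    # Feed the same token repeatedly: the state "trains" on it, so the
    # reconstruction loss shrinks at every step.
    _, losses = ttt_linear_forward([[1.0, 0.5]] * 20, identity, identity, identity)
    print([round(loss, 4) for loss in losses[:5]])
```

This sketch leaves out refinements the paper’s layers use, such as mini-batch TTT and normalization; TTT-MLP replaces the matrix W with a small MLP.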
And remember, the hidden state lives in just one layer of the end-to-end architecture. The other components, like the QKV projection matrices, are learned during pre-training with the standard cross-entropy objective. So the end-to-end architecture is meta-learning the best way to compress context so that it helps next-token prediction. We are “Learning to (Learn at Test Time).”
If it’s hard to believe that this actually works, our paper explains the method in depth. We also cover some theoretical equivalences: this process of self-supervised learning is equivalent to self-attention when the hidden state model is a nonparametric kernel estimator, and to linear attention when it is a linear model trained with batch gradient descent.
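The simplest of these equivalences is the linear case, which can be checked numerically. Assume f(k) = W k, W_0 = 0, step size 1/2, and batch gradient descent (every gradient taken at W_0 rather than at W_{t-1}). The gradient of ||W k_t − v_t||² at W_0 = 0 is −2 v_t k_tᵀ, so each step adds v_t k_tᵀ, giving W_t = Σ_{s≤t} v_s k_sᵀ and z_t = W_t q_t = Σ_{s≤t} v_s (k_sᵀ q_t), which is unnormalized causal linear attention. The sketch below runs both sides; the function names are hypothetical, and the softmax (kernel estimator) case is not shown.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def ttt_linear_batch_gd(keys: List[Vector], values: List[Vector],
                        queries: List[Vector]) -> List[Vector]:
    """TTT with f(k) = W k, W_0 = 0, and batch gradient descent at step 1/2."""
    d_v, d_k = len(values[0]), len(keys[0])
    w0 = [[0.0] * d_k for _ in range(d_v)]
    w = [row[:] for row in w0]
    lr = 0.5
    outputs = []
    for k, v, q in zip(keys, values, queries):
        err = [a - b for a, b in zip(matvec(w0, k), v)]  # W_0 k - v
        # Gradient of ||W k - v||^2 evaluated at W_0 is 2 (W_0 k - v) k^T.
        w = [[w[i][j] - lr * 2.0 * err[i] * k[j] for j in range(d_k)]
             for i in range(d_v)]
        outputs.append(matvec(w, q))
    return outputs


def causal_linear_attention(keys: List[Vector], values: List[Vector],
                            queries: List[Vector]) -> List[Vector]:
    """Unnormalized causal linear attention: z_t = sum_{s<=t} v_s (k_s . q_t)."""
    outputs = []
    for t, q in enumerate(queries):
        z = [0.0] * len(values[0])
        for s in range(t + 1):
            score = sum(a * b for a, b in zip(keys[s], q))
            z = [zi + score * vi for zi, vi in zip(z, values[s])]
        outputs.append(z)
    return outputs


def max_abs_diff(a: List[Vector], b: List[Vector]) -> float:
    """Largest elementwise gap between two sequences of vectors."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


if __name__ == "__main__":
    keys = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0]]
    values = [[0.0, 1.0, 2.0], [1.0, -1.0, 0.5], [0.3, 0.2, 0.1]]
    queries = [[1.0, 1.0], [0.0, 2.0], [-1.0, 0.5]]
    ttt = ttt_linear_batch_gd(keys, values, queries)
    attn = causal_linear_attention(keys, values, queries)
    print("max difference:", max_abs_diff(ttt, attn))
```

Batch gradient descent is what makes the match exact here: with ordinary online gradient descent, each gradient depends on W_{t-1}, and the outputs no longer reduce to a sum of independent v_s (k_sᵀ q_t) terms.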
Our own instantiations, TTT-Linear and TTT-MLP, both match or beat the strongest Transformers and Mamba in perplexity. Plus, TTT-Linear is already faster than the fastest SSMs and scales well with both model size and context length.
The search space inside this framework is huge, and our paper has only taken a baby step. All our code, including training loops and datasets, is available in JAX (https://github.com/test-time-training/ttt-lm-jax) and PyTorch (https://github.com/test-time-training/ttt-lm-pytorch)
---
Twitter thread: https://x.com/karansdalal/status/1810338845659131940