Upshot of the paper -- right now LLMs keep a KV cache for every layer in the stack. Why not just the top layer? It would save memory.
Initial result -- those KV caches in lower layers matter, and output suffered.
Updated plan -- cull half the KV layers! This works 'nearly' as well as keeping all of them, with memory and compute savings.
Downside -- roughly triple the training time, and worse out-of-band / long-context performance.
This feels to me like a technique you'd use on a particular architecture deployed at the edge, where compute matters and you have a little room to give up on performance. Phi-3 on a Raspberry Pi, basically.
Interesting! As always, I wish papers showed actual prompt outputs, not just perplexity numbers. But here we are.
I wish papers could be accepted with negative results. There's a lot of value in not repeating the same mistakes, especially in a field like deep learning, which is not an exact science and is mostly driven by intuition.
From what I understand, not quite. It looks like the cost of training might be similar, but it's less parallelisable within a single token sequence. This is because they have to compute the KV of token T before they can use it for token T+1, whereas in regular training you can compute the KV at each layer for every position in the sequence at once. You're right that it took 2.7x longer to train the smallest model, but I wouldn't be surprised if the GPU utilisation was proportionally lower too.
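For intuition, here's a toy NumPy sketch of that dependency difference -- assumed shapes, single attention head, no MLP, and the "top-layer-KV-only" function is my simplified reading of the scheme, not the paper's code. In the standard pass each layer handles every position in one matmul; in the top-layer-KV variant, position t+1 can't start until the full stack has finished position t.

```python
# Toy sketch (assumed shapes, single head, no MLP) contrasting the two
# dependency structures -- an illustration, not the paper's algorithm.
import numpy as np

rng = np.random.default_rng(0)
seq_len, d, L = 8, 16, 4
Wq = rng.standard_normal((L, d, d)) * 0.1
Wk = rng.standard_normal((L, d, d)) * 0.1
Wv = rng.standard_normal((L, d, d)) * 0.1
x = rng.standard_normal((seq_len, d))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def standard_pass(x):
    # Regular training: each layer computes K/V for ALL positions in one matmul,
    # so the sequence dimension is fully parallel (the causal mask handles ordering).
    h = x
    mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)
    for l in range(L):
        q, k, v = h @ Wq[l], h @ Wk[l], h @ Wv[l]
        h = h + softmax((q @ k.T) / np.sqrt(d) + mask) @ v
    return h

def top_layer_kv_pass(x):
    # Top-layer-KV scheme (simplified): every layer attends to the TOP layer's
    # K/V of earlier tokens, so token t+1 can't start until the whole stack has
    # finished token t -- the loop over positions is inherently sequential.
    K_top, V_top = np.empty((0, d)), np.empty((0, d))
    out = []
    for t in range(seq_len):
        h = x[t:t + 1]
        for l in range(L):
            if len(K_top):
                q = h @ Wq[l]
                h = h + softmax((q @ K_top.T) / np.sqrt(d)) @ V_top
        # Only now is token t's reusable (top-layer) K/V available for t+1.
        K_top = np.vstack([K_top, h @ Wk[-1]])
        V_top = np.vstack([V_top, h @ Wv[-1]])
        out.append(h)
    return np.vstack(out)

print(standard_pass(x).shape, top_layer_kv_pass(x).shape)
```

The total flops end up in the same ballpark; what's lost is the ability to batch the sequence dimension during training, which is consistent with the longer wall-clock time mentioned above.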
It's not an outlier. You will see more than a 26x improvement if you try this on an even deeper LLM with more layers. The deepest model they have applied it to has 30 billion parameters.
Edit: I apologize. The table was cut off on mobile and I didn't see that they sneaked in GPU+CPU offloading for the 25x result.
Sure; but a 40% improvement is much less than a 26x improvement. If 40% is the realistic figure, cite that. Changing the title to include an outlier of 26x is clickbaity.
LLM inference optimization was key to the OpenAI GPT-4o presentation (2x faster, 50% cheaper), and it's driving lots of industry research because it translates directly into cost savings, but it's refreshing to see so many techniques published as papers (e.g. from Stanford, Berkeley…)
You have to run the entire model once for every token. To generate the first token you have to process the entire preceding context once. On the second token you don't need to recompute the context; you can just reuse what you calculated in the previous pass. For attention you take the current token and run a pairwise computation against all the cached keys, so you know which earlier tokens relate to the current one, and then use that to weight the cached values.
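As a concrete toy version of that loop -- single layer, single head, made-up weights, and a greedy argmax standing in for the real LM head and sampling:

```python
# Minimal single-layer, single-head sketch of KV-cached decoding.
# Weights and the "next token" step are invented for illustration only.
import numpy as np

rng = np.random.default_rng(0)
d, vocab = 16, 100
Wq, Wk, Wv = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))
Wout = rng.standard_normal((d, vocab)) * 0.1
embed = rng.standard_normal((vocab, d))

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def decode(prompt_ids, n_new):
    # First pass: process the whole prompt once and fill the KV cache.
    x = embed[prompt_ids]                       # [prompt_len, d]
    K_cache, V_cache = x @ Wk, x @ Wv           # one K/V row per past token
    last = x[-1:]
    generated = []
    for _ in range(n_new):
        # Each new token: one query against ALL cached keys (the pairwise step),
        # then a weighted sum over the cached values.
        q = last @ Wq
        attn = softmax((q @ K_cache.T) / np.sqrt(d)) @ V_cache
        next_id = int(np.argmax(attn @ Wout))   # stand-in for the real LM head
        generated.append(next_id)
        # Append the new token's K/V so the next step can reuse them.
        last = embed[next_id][None, :]
        K_cache = np.vstack([K_cache, last @ Wk])
        V_cache = np.vstack([V_cache, last @ Wv])
    return generated

print(decode([1, 2, 3, 4], n_new=5))
```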
This is unavoidable, by the way. Any method that avoids quadratic attention necessarily degrades accuracy, because sub-quadratic means looking at less than all of the tokens in a given pass. You're bound to miss at least some of them. When you consider how simple the classical attention mechanism is, you realize there is not much you can do: any pre-attention pass is necessarily going to be more complicated than the quadratic attention mechanism itself.
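To put a number on "quadratic": in the cached loop above, position t computes one attention score per cached token, so generating a full sequence of n tokens does

```latex
\sum_{t=1}^{n} t = \frac{n(n+1)}{2} = \Theta(n^2)
```

score computations in total. Anything asymptotically cheaper has to skip some of those pairs.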
What they do is get rid of the KV cache entirely after w layers. This saves memory, which lets you fit more context in the same amount of VRAM. In the paper they decided to run more batches in parallel, but they could just as well have advertised a massive increase in context length. 128k context needs 16GB on my machine for llama3 7b; their approach would allow at least 784k context in the same 16GB.
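Here's a back-of-the-envelope check of that kind of estimate; the layer/head/dim numbers below are an assumed Llama-3-8B-like configuration (GQA with 8 KV heads, fp16), not figures from the paper or from the comment above.

```python
# Rough KV cache size estimate; model shape is an assumed Llama-3-8B-like
# configuration, not numbers taken from the paper.
def kv_cache_bytes(seq_len, n_layers=32, n_kv_heads=8, head_dim=128, bytes_per_elem=2):
    # 2x for keys and values; fp16 -> 2 bytes per element
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len

full = kv_cache_bytes(128_000)                 # cache K/V in every layer
print(f"all 32 layers: {full / 2**30:.1f} GiB")

# If only a couple of layers still keep a KV cache, the same memory budget
# stretches over proportionally more context.
reduced = kv_cache_bytes(128_000, n_layers=2)
print(f"2 layers:      {reduced / 2**30:.1f} GiB")
print(f"context multiplier at equal cache memory: {full / reduced:.0f}x")
```

With those assumed shapes, 128k of context lands right around the 16GB figure mentioned above; the exact context multiplier then just scales with how many layers still keep a cache.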
The KV cache is just another tensor used in matmuls. Unlike the model weights, which are fixed, the KV cache is constructed anew for every input. Think of it as the model growing new weights at inference time to represent what it learns about the user's input, because not everything can be baked into the pretrained model.
You want to store your KV cache in the same processor that does the rest of your matmuls.