It's pretty damn simple: a linearized transformer is a "slow" neural net whose outputs determine the weights of another ("fast") neural net.
It's an NN that can use tools, subject to one major restriction: the tool must be another NN. The restriction is "the trick" that lets you backpropagate gradients through both NNs, so you can train the slow NN based on an error function of the fast NN's outputs.
The only difference between linearized transformers and the kind that OpenAI uses is adding a softmax operation in one place.
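A minimal sketch of that one-place difference, in plain numpy (single head, no masking or scaling tricks; the function names and the relu+1 feature map are illustrative choices on my part, not from any particular codebase): softmax attention versus a linearized variant, where dropping the softmax lets you regroup the matmuls into the "fast weight" form.

```python
import numpy as np

def softmax_attention(Q, K, V):
    # weights = softmax(Q K^T / sqrt(d)); each output row is a weighted blend of V
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

def linear_attention(Q, K, V, phi=lambda x: np.maximum(x, 0.0) + 1.0):
    # Drop the softmax, apply a positive feature map phi instead, and regroup:
    # (phi(Q) phi(K)^T) V  ==  phi(Q) (phi(K)^T V)  -- the right-hand grouping
    # is the "fast weight" matrix phi(K)^T V that the slow net programs.
    qp, kp = phi(Q), phi(K)
    num = qp @ (kp.T @ V)
    den = qp @ kp.sum(axis=0)          # normaliser: sum of similarities
    return num / den[:, None]

Q, K, V = (np.random.randn(5, 8) for _ in range(3))
print(softmax_attention(Q, K, V).shape, linear_attention(Q, K, V).shape)  # (5, 8) (5, 8)
```

The regrouping is what makes the linearized version cheap for long inputs: phi(K)^T V is a fixed-size matrix no matter how long the text is.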
Did you mean to paste a different image? That diagram shows a much older design than the transformer introduced in the 2017 paper. It doesn’t include token embeddings, doesn’t show multi-head attention, the part that looks like the attention mechanism doesn’t do self-attention and is missing the query/key/value weight matrices, and the part that looks like the fully-connected layer is not done in the way depicted and doesn’t hook into self-attention in that way. Position embeddings and the way the blocks are stacked are also absent.
The FT article at least describes most of those aspects more accurately, although I was disappointed to see they got the attention mask wrong.
Did you click the right link? The words "query", "key", and "value" are in the image! For the rest, you'll want to read the paper: https://arxiv.org/abs/2102.11174
Embeddings were around long before transformers.
The image only depicts a single attention head, of course.
Ah, I didn’t notice the picture came from Jürgen Schmidhuber. I understand his arguments, and his accomplishments are significant, but his 90s designs were not transformers, and they lacked substantial elements that make transformers so efficient to train. He does have a bit of a reputation for claiming that many recent discoveries should be attributed to, or give credit to, his early designs, which, while not completely unfounded, mostly stretches the truth. Schmidhuber’s 2021 paper is interesting, but it describes a different design, and that design is not how the GPT family (or Llama 2, etc.) was trained.
The transformer absolutely uses many things that were initially suggested in previous papers, but its specific implementation and combination is what makes it work well. As for the query/key/value system: if the fully-connected layer is supposed to be some combination of the key and value weight matrices, the dimensionality is off. The embedding typically has the same vector size as the value (or rather the combined size of the values across attention heads, but the image doesn’t have attention heads), so that each transformer block has the same input structure. The query weight matrix is missing, and while the dotted lines are not explained in the image, the way the weights are optimized doesn’t seem to match what is shown.
Token embeddings typically only attend to past token embeddings, not future ones.
The reason is to enable significant parallelism during training: a large chunk of text goes through the transformer in a single pass, and its weights are optimized to make its output look like its input shifted by one token (i.e. the transformer converts each input token into a prediction of the next token). However, if attention could look at future tokens, the model would simply use the next token it is given to predict that next token. So all future tokens are masked out.
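A small sketch of that training setup, assuming standard causal masking (the token ids and names here are made up for illustration): targets are the inputs shifted by one, and positions above the diagonal are blocked before the softmax.

```python
import numpy as np

tokens = np.array([17, 4, 92, 8, 51, 3])         # a chunk of text as token ids
inputs, targets = tokens[:-1], tokens[1:]         # predict the next token at every position

T = len(inputs)
causal_mask = np.tril(np.ones((T, T), dtype=bool))    # True where attending is allowed
scores = np.random.randn(T, T)                        # stand-in attention scores
masked = np.where(causal_mask, scores, -np.inf)       # future positions get -inf,
                                                      # so they vanish after the softmax
print(causal_mask.astype(int))
```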
Before I came across that image (and the paper Linear Transformers Are Secretly Fast Weight Programmers) I very much did not understand transformers. I had spent at least 20 hours with Attention Is All You Need, and probably another dozen hours with https://e2eml.school/transformers.html, and wasn't getting much of anywhere. I have no ML/AI background; just a typical undergraduate-CS-level familiarity with neural networks -- the basic stuff that hasn't changed since the 1990s. I do have some experience with linear algebra, but that isn't the hard part of any of this.
Frankly, most people who publish in this field go out of their way to obfuscate the key insights. Mediocre physicists (but not the truly brilliant ones) do the same thing. It's very annoying.
Different things click for different people; the image you linked in the initial post of this thread makes as little sense to me as the Time Cube image or the widely mocked Pepsi rebrand document from 2008.
"That it's objectively true" isn't enough of an advantage, sadly.
> Frankly, most people who publish in this field go out of their way to obfuscate the key insights. Mediocre physicists (but not the truly brilliant ones) do the same thing. It's very annoying.
I certainly sympathise, but I think the overlap here is "maths is hard to communicate", and the conflict between rigorously explaining vs. keeping everything simple enough to follow.
I might understand some of that, but a couple of extra questions if you don't mind answering: there are the key-value pairs in the middle, if I understand correctly, but what do the points on the last layer stand for? Is each key-value pair a vector, and is the last layer an actual token in a simple LLM, for example?
Personally I think the key/value/query terminology is kind of crappy, but at this point we're stuck with it.
Think of the "fast" network as being like a hashmap. It stores key-value pairs. You feed it a query, which is a key that you're looking for, and it gives you back the corresponding value. Unlike the hashmaps you're used to, however, this is sort of a "blurry hashmap". If you ask for a key you'll get sort of a blend of the values nearby it -- even if the key isn't in the hashmap.
> Are the key value pairs a vector and the last layer an actual token in a simple LLM, for example?
Almost.
Firstly, the image I linked depicts only one attention head in one layer of an LLM. So you have to imagine that replicated several times in parallel, and then repeated several times serially. The Fast Weight Programmer paper does a better job of explaining the fundamental unit; to see how the units are replicated and repeated, the block diagrams in Attention Is All You Need are easy to understand.
Secondly, getting back to the "hashmap" analogy, you might imagine that the hashmap key-type and value-type are LLM tokens, but that's not quite how it works. There are two other types: the embedding space and the key space, and the hashmap keys and values both live in the key space. Yeah, this is why the key/value terminology really sucks. There is a feed-forward layer that translates input tokens into the embedding space, another that translates embedding space to key space (before the hashmap lookup), and a third that translates hashmap outputs back from key space to embedding space.

I wouldn't get too hung up on all of these; it's sort of like if you learn a foreign language but still "think" in your native language... the embedding space is the LLM's native language that it invents for itself. The whole business with the key space is mostly a hack: in order to have multiple attention heads and train efficiently on current GPUs, you want the embedding space to be NUM_HEADS times larger than the key space. So the key space being separate from the embedding space is mostly a kludge to make this work well on GPUs. Unless you're planning on writing your own LLM from scratch, it's mostly safe to pretend that the key space and the embedding space are the same thing.

In practice people have tried taking apart small transformers, and the key space appears to work like "parts of speech": when it reads a sentence like "I walked the dog", the attention heads tend to create mappings vaguely similar to I=subject, walked=verb, dog=object, and then later on the slow network will produce queries like "what verb was applied to the dog?". This is an oversimplification, but it gives a general idea of what's going on.
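To put rough numbers on the embedding-space vs. key-space bookkeeping (the sizes here are just typical-looking values I picked, not any specific model's): each head's key/query/value space is d_model // num_heads wide, so concatenating all the heads lands you back in the embedding space.

```python
import numpy as np

d_model, num_heads = 512, 8
d_head = d_model // num_heads                  # the smaller per-head "key space"

rng = np.random.default_rng(0)
x = rng.standard_normal((10, d_model))         # 10 token embeddings

# one head's projections: embedding space -> key space
W_q = rng.standard_normal((d_model, d_head))
W_k = rng.standard_normal((d_model, d_head))
W_v = rng.standard_normal((d_model, d_head))
q, k, v = x @ W_q, x @ W_k, x @ W_v            # each is (10, d_head)

# concatenating all num_heads outputs gives (10, d_model) again, so the next
# transformer block sees the same input structure as this one
print(q.shape, num_heads * d_head == d_model)
```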
Also if you like maths, equations (4)-(19) from https://arxiv.org/abs/2102.11174 are a really spectacular example of using math when it's the right way to explain something, instead of using it to hide things.
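For anyone who wants the gist without opening the paper, here is my sketch of the fast-weight reading that those equations formalize (the same idea as causal linear attention; the relu+1 feature map is just a stand-in for the maps the paper actually studies, and the function names are mine):

```python
import numpy as np

def phi(x):                                   # a positive feature map (stand-in choice)
    return np.maximum(x, 0.0) + 1.0

def fast_weight_attention(Q, K, V):
    d_k, d_v = K.shape[1], V.shape[1]
    W = np.zeros((d_v, d_k))                  # the "fast" net's weights, initially empty
    z = np.zeros(d_k)                         # running normaliser
    ys = []
    for q, k, v in zip(Q, K, V):
        W += np.outer(v, phi(k))              # write: the slow net programs the fast net
        z += phi(k)
        ys.append(W @ phi(q) / (z @ phi(q)))  # read: query the fast net
    return np.array(ys)

Q, K = np.random.randn(6, 4), np.random.randn(6, 4)
V = np.random.randn(6, 3)
print(fast_weight_attention(Q, K, V).shape)   # (6, 3); causal by construction
```

Token by token, the slow net writes an outer product into W and reads it back with the query, which is exactly the "NN whose outputs determine another NN's weights" framing from the top of the thread.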
https://people.idsia.ch/~juergen/fastweights754x288.png