Mixture-of-Depths: Dynamically allocating compute in transformers (arxiv.org)
281 points by milliondreams 3 months ago | 83 comments



I think more complicated routing is absolutely going to become more common.

Specifically, I think at some point we are going to move to recursive routing, i.e. passing back through a set of experts again. In the future, 'chain-of-thought' will happen internally to the model, recursively.


We can name these hypothetical objects Recursive Neural Networks.


I know you're jesting, but RNNs are recurrent along the sequence length, whereas I am describing recursion along the depth.


Recursive NNs are not the same as Recurrent NNs:

https://en.wikipedia.org/wiki/Recursive_neural_network

Well ish. The article above explains that Recursive-NNs are hierarchical whereas RNNs are linear. I guess the distinction is a little on the fine side.

Anyway carry on. Pedantic moment over.


The recursive neural networks described there are a failed academic project from more than a decade ago, predating modern deep learning. Basically everyone using the phrase recursive NN nowadays is probably just misspeaking for RNN. RNNs also are not linear.


I don't know about "everybody nowadays", but I remember Recursive Neural Nets as an architecture introduced by Christopher Manning, with the argument that it was better suited to the hierarchical structure of language than existing architectures. I did find it a bit of a bad choice of name, given that it's so close to Recurrent Neural Nets. All this is from memory, though; I might check the internets later to see what I misremember.

RNNs are a large class of architectures of varying complexity, from Kalman filters to LSTMs. It's not clear to me exactly what the Wikipedia article means by "linear", but LSTMs, for example, treat their inputs as sequences and don't try to deconstruct them into parts like, e.g., Convolutional Neural Nets do. So maybe that's what's meant by "linear".


No opinion on the specifics of this distinction, but it's worth noting that in research, an awful lot of successful projects have their origins in failed projects of decades ago...


My experience working in machine learning academia is an overfocus on failed projects from the 90s and early 00s that really only stopped in 2020+.

We can often trace back successful projects to failed precursors, but often the people behind the successful project are not even familiar with the failed precursor and the 'connection to the past' only really occurs in retrospect. See the 'adjoint state method' and connections with backprop.


This is sometimes true, sure. And often the older work has entered the general consciousness more than it gets chased down by searching specific cites. On the other hand, very little is truly new, and recency bias can lead you into all sorts of back-eddies.

Once the dust has settled, there are often much clearer through lines than it looked like at the time. It's hard to see when you are on the moving front, though.


Depthwise RNN?


Like decode the next token, then adjust what you're paying attention to, then decode it again?


Isn't it the only way to, say, understand a pun?


That is exactly how LLM inference is performed, so I'm being cheeky (I'm 99% sure anyone proposing anything in this thread is someone handwaving based on limited understanding)


You would be wrong, but that is fine. Been working with attention since 2018.

Why assume I know little and leave snarky comments (and basically a repetition of the prior joke at that, subbing RNN for transformer)?


To playfully invite for you to participate in conversation further, so that I may humbly learn from you. "I don't know what you're talking about" seemed too spartan and austere and aggressive, and you reciprocated politely, if again sparsely, when the other person playfully invited you to elaborate.


Well, you've now made your original intent specific, but in case you didn't draw the requisite lesson I'll make that explicit.

Because text has less bandwidth than almost any other medium, certain forms of humor are much more likely to be misunderstood (in this case, your "gentle playfulness" was taken to be snark, sarcasm, and point scoring).

If you insist on using this and similar forms of humor that ordinarily depend quite strongly on intonation to convey intent, you'll have to be much more explicit to avoid being misunderstood. You are going to have to actually state your intent explicitly as part of your communication. This need not entirely destroy the humor; for example, you might try something like this:

And so I say to you (playfully, sir, playfully): etc.

Or this:

Yadda yadda yadda. (I kid, I kid!)

The Internet-native forms of this are the humble ;-) or the newer j/k, but I find that it is all too easy to overlook a 3-character sequence, particularly if the passage being so marked is even as long as a single paragraph, but they can serve their purpose when used for the commonplace one-liner.


"blah, blah, blah" can be an expression of scornful boredom or the utterance of a vampire.


You are painfully boring




What you describe here sounds a little like the line of work centered around Universal Transformers, which basically process the input embeddings through a single transformer block multiple times with a separate module deciding when the embeddings have been cooked enough and can be pulled out of the oven so to speak.

Even more in line with the idea of "experts" there's a paper from last year on Sparse Universal Transformers in which they combine a universal transformer with sparse mixture of experts, so it's up to the gating mechanism to decide which transformer blocks and in which order are to be used in shaping the embeddings.

This really isn't my specialty but from what I gathered these are tricky to train properly, and require more overall compute during inference to reach comparable results to their vanilla transformer counterparts. It's an interesting direction nonetheless, having an upper bound on the number of computation steps per token is, in my opinion, one of the major downsides of the classical transformer architecture.
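
For what it's worth, here is a minimal PyTorch-style sketch of the idea being described: one shared transformer block applied repeatedly, with a small learned "halting" module deciding per token when the embedding has been cooked enough. The class name, the sigmoid-threshold halting rule, and all hyperparameters are illustrative assumptions, not the actual Universal Transformer or ACT implementation.

    import torch
    import torch.nn as nn

    class AdaptiveDepthEncoder(nn.Module):
        """One shared block applied up to max_steps times, with learned halting."""
        def __init__(self, d_model=512, n_heads=8, max_steps=8, halt_threshold=0.99):
            super().__init__()
            self.block = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            self.halt = nn.Linear(d_model, 1)      # per-token halting score
            self.max_steps = max_steps
            self.halt_threshold = halt_threshold

        def forward(self, x):                      # x: (batch, seq, d_model)
            cumulative = torch.zeros(x.shape[:2], device=x.device)
            for _ in range(self.max_steps):
                still_running = cumulative < self.halt_threshold
                if not still_running.any():        # every token has halted
                    break
                h = self.block(x)
                # Only tokens that have not yet halted get updated this round.
                x = torch.where(still_running.unsqueeze(-1), h, x)
                cumulative = cumulative + torch.sigmoid(self.halt(x)).squeeze(-1)
            return x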


I think the reason this hasn't been done is you have no way to decide how many recursions are necessary at train time.

And if you pick a random number/try many different levels of recursion, you 'blur' the output. I.e., the output of a layer doesn't know if it should be outputting info important for the final result, or the output that is the best possible input to another round of recursion.


Yes, I think training this model would be hard. Perhaps something akin to how MoEs are trained where you impose some sort of loss distribution to encourage equitable routing, but for recursion.


Look at the human brain for useful analogies?

The default mode network does recursive/looping processing in the absence of external stimuli and world interaction. Multiple separate modules outside of the network are responsible for stopping and regulating this activity.


You could just learn the right estimated number of recursions, also passing 'backtracking'/'state' information to the next nested level. Kind of like how state space models encode extractable information via a basis function representation, you could encode extractable recursion-state information into the embedding. See also transformers that can learn to recognize n-deep balanced parentheses (Dyck-n languages).


I have been thinking about this topic for some time. It might be done using the energy of the token. If it's still higher than an energy limit, then process it again, and increase the energy limit. The energy could be computed using log-sum-exp: https://openreview.net/pdf?id=Hkxzx0NtDB
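
As a rough sketch of that suggestion (assuming a hypothetical `block` and unembedding `head`, and a made-up threshold schedule), the energy would just be the negative log-sum-exp of the token's logits, and you keep re-processing while it stays above the current limit:

    import torch

    def token_energy(logits):                      # logits: (..., vocab)
        return -torch.logsumexp(logits, dim=-1)    # lower energy ~ more confident

    def refine_until_confident(x, block, head, energy_limit=-5.0, step=1.0, max_iters=4):
        for _ in range(max_iters):
            x = block(x)                           # one more pass through the block
            if token_energy(head(x)).max() <= energy_limit:
                break                              # confident enough, stop recursing
            energy_limit += step                   # otherwise raise the limit and go again
        return x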


This is actually how EfficientNet trains, using random truncation of the network during training. It does just fine... The game is that each layer needs to get as close as it can to good output, improving on the previous activation's quality.


Attention is basically routing; these other routing schemes give the model a less fine-grained choice, which potentially makes it easier to train.


How is attention basically routing?


It routes values based on linear combinations taken from the attention map.


But all of those values are created using an MLP with the same parameters, so there is no routing to different parameters.


You have to look at it as a sequence of time steps which can interact. You can implement this interaction in many ways, such as transformer, mamba, rwkv or mlp-mixer. But the purpose is always to allow communication across time.

You use three distinct linear projections, one for queries, one for keys and one for values. From Q and K you compute the attention matrix A, and using A you construct linear combinations from V. But depending on A, for example for a token V_i there might be input from two other tokens, V_j or V_k, so information is moved between the tokens.
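
A bare-bones sketch of that point (single head, no masking, shapes simplified): the attention matrix A is effectively a per-token routing table deciding how much of every other token's value vector gets mixed in, even though the Q/K/V projections themselves are shared across all tokens.

    import torch
    import torch.nn.functional as F

    def attention(x, Wq, Wk, Wv):                  # x: (seq, d_model)
        Q, K, V = x @ Wq, x @ Wk, x @ Wv
        A = F.softmax(Q @ K.T / K.shape[-1] ** 0.5, dim=-1)   # (seq, seq) routing weights
        return A @ V                               # each row is a data-dependent mix of values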


Think of it like an edge flow matrix


That doesn't clarify it for me. The same parameters are being used for every layer for every token. Yes, there is this differentiable lookup in attention, like in MoE - but routing is about more than just differentiable lookup; it is about selecting among parameters, not state.


The trendline is definitely toward increasing dynamic routing, but I suspect it's more so that MoE/MoD/MoDE enable models to embed additional facts with less superposition within their weights than enable deeper reasoning. Instead I expect deeper reasoning will come through token-wise dynamism rather than layer-wise -- e.g., this recent Quiet-STaR paper in which the model outputs throwaway rationale tokens: https://arxiv.org/abs/2403.09629


There are already some implementations out there which attempt to accomplish this!

Here's an example: https://github.com/silphendio/sliced_llama

A gist pertaining to said example: https://gist.github.com/silphendio/535cd9c1821aa1290aa10d587...

Here's a discussion about integrating this capability with ExLlama: https://github.com/turboderp/exllamav2/pull/275

And same as above but for llama.cpp: https://github.com/ggerganov/llama.cpp/issues/4718#issuecomm...


See, this is where my understanding of LLMs breaks down. I can understand one token going through the model, but I can't understand a model that has different "experts" internally.

Do you have any resources or links to help explain that concept?


The "mixture of experts" goal is to add more parameters to the model to make it more powerful, without requiring any more compute. The way this is done is by having sections of the model ("experts") that are in parallel with each other, and each token only going through one of them. Think of it like a multi-lane highway with a toll booth on each lane - each car only drives on one lane rather than using them all, so only pays one toll.

The name "experts" is a bit misleading, since each expert ("highway lane") is not really specialized in any obviously meaningful way. There is a routing/gating component in front of the experts that chooses on a token by token basis (not sentence by sentence!) which "expert" to route the token to, with the goal of roughly load balancing between the experts so that they all see the same number of tokens, and the parameters in each expert are therefore all equally utilized.

The fact that the tokens in a sentence will be somewhat arbitrarily sent through different "experts" makes it an odd kind of expertise - not directly related to the sentence as a whole! There has been experimentation with a whole bunch of routing (expert selection) schemes.
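
Concretely, a toy top-1 mixture-of-experts layer along these lines might look like the sketch below (load-balancing losses and capacity limits omitted; all names and sizes are illustrative):

    import torch
    import torch.nn as nn

    class Top1MoE(nn.Module):
        def __init__(self, d_model=512, d_ff=2048, n_experts=8):
            super().__init__()
            self.router = nn.Linear(d_model, n_experts)        # the "toll booth"
            self.experts = nn.ModuleList(
                nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
                for _ in range(n_experts)
            )

        def forward(self, x):                                  # x: (n_tokens, d_model)
            gate = torch.softmax(self.router(x), dim=-1)
            weight, choice = gate.max(dim=-1)                  # top-1 expert per token
            out = torch.zeros_like(x)
            for i, expert in enumerate(self.experts):
                mask = choice == i                             # tokens routed to this "lane"
                if mask.any():
                    out[mask] = weight[mask, None] * expert(x[mask])
            return out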


It is still just one token going through the model.

I actually think mixture-of-expert is a bit of a misnomer, the 'experts' do not really necessarily have super distinct expertise. Think of it more as how neurons activate in the brain - your entire brain doesn't light up for every query, now in neural networks the same thing happens (it doesn't fully light up for every query).

Don't really know a resource besides the seminal Noam Shazeer paper, sorry - I'm sure others have higher-level ones.


Most of the original MoE implementations around LLMs were in fact recursive


Could you please elaborate?


The original MoE research done by Google around LLMs involved nested transformers to scale them. It was a layered approach where at each layer you would have a set of experts, generally routed to by simple heuristics; then each of those models would call into its own series of experts and combine the data in various ways.

These models were SOTA for their time


Interesting, but that isn't recursive, as the sub-model cannot invoke a model higher up in the invocation graph/tree.


Most important paper of 2024.

The idea that we want models not to have to use the same amount of compute for every token has been around for a while. This is the first compelling mechanism I've seen for doing it.

> Equipped with these new methods, we can sample autoregressively by choosing to route tokens to or around a block based on the router’s output, which does not depend on any information from future tokens. We provide empirical evidence that this is a relatively easy auxiliary task that quickly achieves 99% accuracy.

Does anyone else find this is a bit surprising?
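
For concreteness, a hedged sketch of what that auxiliary task could look like (an illustration of the idea, not the paper's code): a tiny classifier is trained to predict, from each token's hidden state alone, whether the router would have placed it in the top-k, and at sampling time that causal prediction replaces the non-causal top-k.

    import torch
    import torch.nn as nn

    class CausalRoutePredictor(nn.Module):
        def __init__(self, d_model=512):
            super().__init__()
            self.clf = nn.Linear(d_model, 1)

        def aux_loss(self, x, router_scores, k):   # x: (batch, seq, d_model), scores: (batch, seq)
            # Training target: was this token actually among the top-k for its sequence?
            target = torch.zeros_like(router_scores)
            target.scatter_(-1, router_scores.topk(k, dim=-1).indices, 1.0)
            logits = self.clf(x).squeeze(-1)
            return nn.functional.binary_cross_entropy_with_logits(logits, target)

        def route(self, x):                        # inference: per-token, no future tokens needed
            return torch.sigmoid(self.clf(x)).squeeze(-1) > 0.5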


Sparse Universal Transformer is older and already did routing-based early termination...


Most important? The idea that not every token needs the full context window should be an obvious optimization.


that’s not the idea here


Simplified Intro Version:

Imagine you have a smart assistant that can understand and process the words you say to it. Usually, this assistant pays equal attention to every word you say, no matter how important or unimportant each word is to the overall meaning of your message.

Now, imagine that we found a way to teach the assistant to be smarter about how it uses its "brain power." Instead of giving equal attention to every word, the assistant learns to focus more on the words that are most important for understanding what you mean. It can even adjust this focus on the fly, paying more attention to different words depending on the context of your message.

To make sure the assistant doesn't get overwhelmed, we also set a limit on how much total "brain power" it can use at any given time. It's like giving the assistant a budget and saying, "You can only spend your brain power on a certain number of words at a time." The assistant then has to decide which words are most important to focus on.

Even with this limit, the assistant is still flexible in how it uses its brain power. It might spend more on certain words and less on others, depending on what you're saying. This means that while we always know the total amount of brain power the assistant is using, it can adapt to different situations and prioritize what's most important.

When we teach the assistant using this method, it not only learns to focus its attention intelligently but also does so very efficiently. It can understand you just as well as an assistant that pays equal attention to every word, but it uses less brain power overall. This makes the assistant much faster at responding to you and processing new information.
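
Mapping the analogy back to the mechanism, a simplified sketch of a mixture-of-depths block might look like this (the score-weighting detail and other specifics from the paper are simplified; `block` stands for any ordinary transformer block):

    import torch
    import torch.nn as nn

    class MoDBlock(nn.Module):
        def __init__(self, block, d_model=512, capacity=0.125):
            super().__init__()
            self.block = block                       # the expensive computation
            self.router = nn.Linear(d_model, 1)      # scores each token's "importance"
            self.capacity = capacity                 # fraction of the budget spent per block

        def forward(self, x):                        # x: (batch, seq, d_model)
            scores = self.router(x).squeeze(-1)      # (batch, seq)
            k = max(1, int(self.capacity * x.shape[1]))
            idx = scores.topk(k, dim=-1).indices     # tokens worth spending compute on
            gather_idx = idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
            picked = torch.gather(x, 1, gather_idx)
            processed = self.block(picked)           # heavy compute only for the top-k tokens
            out = x.clone()                          # everyone else rides the residual stream
            out.scatter_(1, gather_idx, processed)
            return out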


I understand this is ELI5, but doesn’t attention already do this, in the way you described? It pays specific focus to the most contextual words in the prior sequence.


Not from a computational perspective. To calculate the attention scores you have to calculate every token against every other token. That is quadratic. Every article like "one", "the", "a", etc. will have to be calculated against every other word even though they are only relevant within a short distance of the word they are attached to.


Isn't that factorial, and much more costly than quadratic?


N choose 2 = N! / (2! (N-2)!) = N(N-1)/2, so it's quadratic, not factorial.


The way I understood it is that for each token, the attention mechanism itself consumes a fixed amount of processor time.

The innovation here is to prioritize tokens so that some tokens have more or less processor time.


I wrote up a bit about it here, from what I could piece together:

https://lifeinthesingularity.com/p/googles-breakthroughs-in-...


Nice writing. Reminds me of New Scientist style (I like NS, so that is a compliment). I like the "explain as you go along but be brief" style, which is nice for getting a feel for the space.


hey thank you!

i try to operate at zero-basis and quickly scaffold to a simplified model so ANYONE can grab the "why this makes nerds say wow" factor haha


It’s very similar to Mixture of Experts. But instead of routing tokens to multiple experts, you "deploy to a single expert which can be dynamically skipped"


Mixing these would be pretty cool. Further reduced compute for MoE while keeping the performance.


In the paper they already show a mixing of these two with Mixture-of-Depths-and-Experts (MoDE).


"This is more computationally efficient than performing a full content-based lookup across an entire memory buffer for each step in the future, and could be one step towards drastically increasing the context-length available for making a prediction."

Is this how they get a context window of 10 million tokens? Or are they referring to even longer context windows in the future?


After trying to understand and implement some algorithms in RASP [1, 2], my takeaway was that certain functions need a certain number of transformer layers to operate. Following this logic, it should become apparent that the functions learned by transformers can be spread over multiple heads. Repeating these functions might be very valuable for understanding and solving a problem, but current inference does not allow (a set of subsequent) heads to be repeated. This paper indeed seems a promising direction.

[1] https://arxiv.org/pdf/2106.06981.pdf

[2] https://www.youtube.com/watch?v=t5LjgczaS80


Maybe the only downside to how fast LLMs are moving is papers come out faster than anyone (not at Google) can train and test the improvements.

I got into deep learning around when ReLU and dropout were hot, and on my consumer 1080 I was able to change one or two lines of code and test the improvements in a few hours, whereas now I guess I'll need to wait a few weeks for Mistral et al. to try it out.


Welcome to the GPU poor!

I'm focusing on quantization approaches and testing on my obsolete last-gen GPUs.


The funny thing is, I have 8 3090s, which last epoch would have put me in like the top 1% of compute. Now, still a lot of compute, but it pales in comparison to the 100x H100 GPU clusters we're seeing today.


hu-po does in-depth live-stream reviews of AI papers.

highly recommended, here is his take on the mixture-of-depths paper discussed. https://www.youtube.com/watch?v=Teru_qIdB8Y


for just a moment, i thought you were referring to the huffington post


The abstract and the rest of the paper don't really match imo. It's not really allocating more to some sequences, but just introducing ~dropout. Might be different sides to the same coin, but was still a weird read.


We spent a fair bit of effort ensuring we were accurate with the language and claims, so we're happy to take any feedback and make updates in subsequent versions. However, I don't see where we claim that MoD allocates more to some sequences and not others (specifically, the abstract says "transformers can instead learn to dynamically allocate FLOPs (or compute) to specific positions in a sequence").

That said, it's a pretty simple change to make the approach work in the way you describe (allocating more to some sequences and not others) by changing the group across which the top-k works. In the paper we use the time (sequence) dimension, but one could also use the batch * time dimension, which would result in asymmetric allocation across sequences
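
For readers following along, the difference amounts to where the top-k is taken. A tiny illustration (toy shapes, not the paper's code):

    import torch

    scores = torch.randn(4, 128)                     # (batch, seq) router scores
    k = 16

    # Top-k per sequence: every sequence gets the same compute budget.
    per_sequence = scores.topk(k, dim=-1).indices

    # Top-k over the flattened batch*time dimension: one shared budget,
    # so some sequences can receive more compute than others.
    shared = scores.flatten().topk(k * scores.shape[0]).indices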


Dropout is at train time; this is at inference time. Dropout is random; this is deterministic. Can't compare them.


Essentially the second law of thermodynamics for neural networks.

Neat!


Can you elaborate on this analogy?


Fixed net resource allocation but non-uniform distribution.


It's a start but it's disappointing that half the layers still have to process every token. It seems like we ought to be able to get to 90% or even 99% savings when these models currently allocate the same compute for outputting "the" as they do for outputting the first digit of the answer of a complicated math problem.


Speculative decoding does this to an extent - using a smaller model to generate its own predictions and putting them in the batch of the bigger model until they diverge

https://huggingface.co/blog/whisper-speculative-decoding
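
A hedged sketch of that loop (greedy acceptance only; real implementations use a rejection-sampling rule to preserve the big model's distribution, and `draft_model`/`big_model` here are assumed to be callables returning logits of shape (1, seq, vocab)):

    import torch

    def speculative_step(draft_model, big_model, prefix, n_draft=4):
        draft = list(prefix)
        for _ in range(n_draft):                       # cheap autoregressive drafting
            logits = draft_model(torch.tensor([draft]))
            draft.append(int(logits[0, -1].argmax()))

        big_logits = big_model(torch.tensor([draft]))  # one batched pass verifies all drafts
        accepted = list(prefix)
        for pos in range(len(prefix), len(draft)):
            verified = int(big_logits[0, pos - 1].argmax())
            accepted.append(verified)                  # the big model's token is always kept
            if verified != draft[pos]:                 # stop at the first divergence
                break
        return accepted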


It doesn’t. It simply trades compute for latency by transposing matrix multiplications into “the future.” It doesn’t actually save FLOPs (uses more) and doesn’t work at large batch size.


>doesn’t actually save FLOPs (uses more)

Does anyone even care? Really, who cares? The truth is nobody cares. Saving FLOPs does nothing if you have to load the entire model anyway. Going from two FLOPs per parameter to 0.5 or whatever might sound cool on paper, but you're loading those parameters anyway and gaining nothing.


companies that run these things care - they run at huge batch size and are compute bound in the limit


Are we going to hit bullseye?


This only cuts compute by “up to” 50%, and only during inference. Quadratic dependence on context size remains, as do the enormous memory requirements. For something to be considered a bullseye in this space it has to offer nonlinear improvements on both of these axes, and/or be much faster to train. Until that happens, people, including Google, will continue to train bog-standard MoE and dense transformers. Radical experimentation at scale is too expensive even for megacorps at this point.


Makes opportunities for smaller companies to innovate/experiment to offer solutions / acquisition targets where tighter inference compute requirements make or break the experience but larger training cost is less of a concern (such as embedded or local runtime use cases).


Before those opportunities are available to you, someone would need to spend a few million dollars and train a competitive model with this, and then release it under a license that allows commercial use. This is out of reach for the vast majority of smaller companies. These models only excel at large parameter counts, even for narrow problems. This is especially true in the case of MoE, which is a way to push the overall parameter count even larger without lighting up the whole thing for every token.


Yeah, all attempts at reducing complexity from quadratic to linear have failed; only Mamba still has a chance, but it's not tested on large models and only provides a speedup for 2000+ tokens. That was to be expected, as small sequences have very small memory requirements for transformers, but recurrent architectures use the same hidden size. So when the recurrent hidden size > sequence length, the old transformer is faster.


It's more subtle than that, IMO. They haven't necessarily "failed" - they just don't have the "superpowers" that the metrics used to evaluate such systems are aimed at. E.g. no linear method devised so far (that I know of, at least) is able to do very high-recall point retrieval in long context _and_ effective in-context learning simultaneously. You get one or the other, but not both.

But as far as the metrics go, high-recall retrieval in long context is easier for the researcher to demonstrate and for the observer to comprehend - a typical needle/haystack setting is trivial to put together. It is also something that (unlike in-context learning) humans are usually very bad at, so it's perceived as a "superpower" or "magic". In this case, e.g., Mamba being more human-like due to its selective forgetfulness is currently playing against it.

But whether it's "better" per se will depend on the task. It's just that we do not know how to evaluate most of the tasks yet, so people keep trying to find the proverbial keys under the lamp post, measure what they can to make progress, and thereby keep their efforts lavishly funded.



