Nice! For what it is worth, a colleague and I made a library a while ago that factors out most shared model code, with which many models can be implemented in about 100 lines (excluding Python import ceremony and comments). E.g.:
Nice. I will definitely be taking a look at this. Have you looked at the xformers library? They are looking at the same problem as you, but their focus is more on providing performant transformer modules using Triton. Using specific components from the library, though, is not as simple; I kept running into runtime errors, so I've set it aside for now. I am building something based on the BERT architecture, so I will give this a look. Thanks for all the work!
I would've loved to look at xFormers, but I avoided looking at other implementations to make sure that ours is a clean room implementation.
Curated Transformers started as a very small library just for spaCy (spaCy 3.7 transformer pipelines use Curated Transformers) with just the older encoder models (BERT, RoBERTa, etc.). spaCy previously used Hugging Face Transformers for the provided transformer models, but we wanted something where we could easily hook into different parts of the model (e.g. for distillation).
After the functionality needed for spaCy was done, Matt @ Explosion encouraged us to extend it into a more general PyTorch library that would also support decoder architectures, generation, etc.
Fortran! If you don't mind me asking, why Fortran?
I know it underpins a lot of time-tested scientific code, often wrapped by libraries like PyTorch and Numpy, but Fortran isn't exactly a popular language nowadays. What's your rationale for using it?
Tl;dr: Fortran is low-level-ish and compiled, but otherwise almost identical to NumPy syntax-wise.
It supports all the common array and matrix operations and it doesn't need memory and pointer management the way C does. But it still compiles down to something very fast, you can link in BLAS and GPU libraries, supports easy parallelism...
When I compare with e.g. Karpathy's llama2.c, I think Fortran is easy to work with for implementing basic transformer inference because of how it handles arrays.
The downside is that while there are efforts to modernize it, I find it more cumbersome for non-numerical stuff, particularly strings. But I think for the actual linear algebra implementation, it can't be beat.
I should add, I know it's a bit of an uphill battle, I expect fewer people will use code that I write in Fortran vs basically anything else. But I'm hoping to pull some people in and get a critical mass of interest because I think it has a lot of promise. That's actually one of the reasons I wanted to get a Mamba implementation quickly (though now that there's a basic python one I think I'll have lost some potential users to it :)
things I'd like a non-ML-researcher explanation of about Mamba:
1. what is the overall insight of state space models beyond transformers? (i know this is somewhat covered in the paper but still a bit inaccessible)
2. what was the incremental innovation/result that is making Mamba more successful/interesting than its predecessors? (S4, H3, Monarch etc)
3. what are the implications beyond subquadratic scaling of context? say if i don't really care about context length > 100k tokens. what other benefits are there - for example, is Mamba potentially more compute-efficient to train for a similar size of model/dataset?
just offering 3 prompts for knowledgeable people to drop some alpha
My IQ is orders of magnitude lower than the authors of the paper, but I did my best to work through it anyway. I studied CE and have the basic control theory background and undergrad level discrete time systems intuition. It would take much additional studying to understand state space models enough to really parse this paper. But I tried anyway. Take my comment here with a big grain of salt.
The overall insight of Mamba is to solve a longstanding problem with state space models: they are good at compressing the input context, but compressing the input into a hidden state erases information needed to make use of the context as effectively as Transformers do.
Their solution to this problem is what they call a selection mechanism. The mechanism is input-dependent, allowing the model to adjust its behaviour at each step as the input changes. They do this by making a few of the state space parameters input-dependent instead of input-invariant: they attach linear layers and such that project the input onto those parameters at each time step. The linear layers (etc.) are trained, of course, so that they learn to transform the input appropriately and the model spits out useful output.

But making the state space parameters input-dependent creates a problem in terms of computation overhead. They fix that with a hardware-aware algorithm that makes the most of the modern GPU memory hierarchy, avoiding moving things in and out of HBM as much as possible.
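If it helps to make the selection part concrete, here is a toy, non-optimized sketch of how I understand it (the layer names and sizes below are my own simplification, not the paper's or the official code's): Δ, B and C come from linear projections of the input at each step, while A stays input-independent, and the state is advanced with a plain loop rather than the hardware-aware scan.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy selective-SSM step (my own simplification, not the official code):
# Delta, B and C are produced *from the input* at each time step (that is the
# "selection"), while A stays input-independent. The plain Python loop below
# is exactly the sequential work the paper's hardware-aware scan accelerates.
d_inner, d_state, seq_len, batch = 16, 4, 32, 2

to_delta = nn.Linear(d_inner, d_inner)   # hypothetical projection layers
to_B = nn.Linear(d_inner, d_state)
to_C = nn.Linear(d_inner, d_state)
A = -torch.rand(d_inner, d_state)        # negative, for a stable recurrence

x = torch.randn(batch, seq_len, d_inner)
h = torch.zeros(batch, d_inner, d_state)
ys = []
for t in range(seq_len):
    xt = x[:, t]                                     # (batch, d_inner)
    delta = F.softplus(to_delta(xt))                 # input-dependent step size
    B = to_B(xt)                                     # input-dependent (batch, d_state)
    C = to_C(xt)                                     # input-dependent (batch, d_state)
    # simplified zero-order-hold style discretization
    A_bar = torch.exp(delta.unsqueeze(-1) * A)       # (batch, d_inner, d_state)
    B_bar = delta.unsqueeze(-1) * B.unsqueeze(1)     # (batch, d_inner, d_state)
    h = A_bar * h + B_bar * xt.unsqueeze(-1)         # state update
    ys.append((h * C.unsqueeze(1)).sum(-1))          # readout, (batch, d_inner)
y = torch.stack(ys, dim=1)                           # (batch, seq_len, d_inner)
print(y.shape)
```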
Tri Dao came up with Flash Attention, which is basically a way to use hardware more efficiently in a Transformer. So this is his jam 100%.
I know this doesn’t add much to understanding the paper, but hopefully it’s better than nothing.
1. Attention is quadratic in context length; RNNs with gating (LSTM, GRU, etc.) are linear, as are all these new architectures. Early RNNs used gating to avoid exploding gradients; these new ideas use theory from dynamical systems that guarantees stability, so the gating can focus on memory rather than solving two problems at once.
2. The models released in the couple of weeks running up to NeurIPS 2023 (Mamba and Based) included multi-query associative recall (MQAR) and data-dependence in the gating/selection, inspired by multi-headed attention. It turned out these were the main missing ingredients compared to earlier state-space architectures (Hyena and earlier) and made these new models as good as attention on associative-recall tasks, and potentially even slightly better than attention on other, non-lookup tasks. Of course, the huge detail in Mamba is the efficient CUDA implementation; without it the architecture may not make much sense for tasks where transformers are already appropriate.
3. If one does not have to worry too much about context length, a lot of new domains open up: DNA-sequence analysis is a linear task with long dependence; think of analyzing images, videos, or higher dimensional info in terms of streams of tokens (scan the pixels in the way of an old CRT monitor). The early dreams of AI included a continuously evolving single learning trajectory of an agent interacting with an environment continuously, so maybe such dreams will be easier to realize with these infinite-context-length models.
bonus: you didn't ask for it, but as of today the downstream applications of these models for important/practical tasks are largely untested/untuned compared to the rather mature applications of attention, so there may be a little delay before people figure out all the tricks for how to use large pre-trained models of these types. The analogy to the old RNN helps to a degree, but people had super specialized to attention and transformers the last 5 years, so there is a lot of momentum in favor of transformers.
> is Mamba potentially more compute-efficient to train for a similar size of model/dataset?
I would like to understand it too as well ...
Here is a quote from the original paper:
"Computation. After the parameters have been transformed from (∆, A, B, C) ↦ (A, B, C), the model can be computed in two ways, either as a linear recurrence (2) or a global convolution (3). Commonly, the model uses the convolutional mode (3) for efficient parallelizable training (where the whole input sequence is seen ahead of time), and switched into recurrent mode (2) for efficient autoregressive inference (where the inputs are seen one timestep at a time)."
So the training is parallelizable, like in RetNet with parallel forward mode.
By default, inference is done in the recurrent mode to allow the longest possible context. No chunking is available, so it is difficult for me to say how much RAM and VRAM it will consume during inference ...
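For intuition on those two modes, here is a toy scalar example (my own sketch, for the time-invariant case before Mamba's selection makes the parameters input-dependent): the recurrent loop and a convolution whose kernel is built from powers of the transition give the same outputs.

```python
import torch

# Time-invariant (non-selective) scalar SSM:
#   h_t = a * h_{t-1} + b * x_t,    y_t = c * h_t
# Recurrent mode (2): step through time, one token at a time.
# Convolutional mode (3): y = conv(x, K) with K = (c*b, c*a*b, c*a^2*b, ...),
# which can be computed for the whole sequence in parallel. Both agree.
torch.manual_seed(0)
L = 8
a, b, c = 0.9, 0.5, 1.2
x = torch.randn(L)

# recurrent mode, as used for autoregressive inference
h, y_rec = 0.0, []
for t in range(L):
    h = a * h + b * x[t]
    y_rec.append(c * h)
y_rec = torch.stack(y_rec)

# convolutional mode, as used for parallel training
K = torch.tensor([c * (a ** k) * b for k in range(L)])
y_conv = torch.stack([(K[: t + 1].flip(0) * x[: t + 1]).sum() for t in range(L)])

print(torch.allclose(y_rec, y_conv, atol=1e-5))  # True
```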
I did some minimal testing; during inference, Mamba uses about 60% of the VRAM of RetNet (in parallel forward mode) with a model of the same size and a vocabulary of the same size.
AFAIK Mamba is a continuation of the SSM research, which is basically something called long convolution.

Instead of doing quadratic attention (computing how much each token attends to every other token), you just "somehow" compute a long convolution kernel (the same length as the input) and then apply the conv1d.

Again, from my limited understanding, it's a bit related to applying an FFT, doing some matmuls, and then an IFFT back. We know that this works, but it's slow. There are many ways to compute an FFT, though, and one of them uses something called butterfly matrices. I think it's just an approximation, but it's good enough and it's very fast/efficient on current hardware.
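For the "long convolution" part, a minimal sketch of the FFT trick (my own toy example of the general technique, not the actual S4/Hyena kernel machinery): pad to twice the length, multiply in the frequency domain, and you get the causal convolution in O(L log L).

```python
import torch

# Causal "long convolution": the kernel k is as long as the input x.
# Done directly this is O(L^2); with FFTs it is O(L log L). Zero-padding to
# 2L avoids circular wrap-around, so the result is an ordinary causal conv.
torch.manual_seed(0)
L = 1024
x = torch.randn(L)
k = torch.randn(L)   # stand-in for a learned long kernel

n = 2 * L
y_fft = torch.fft.irfft(torch.fft.rfft(x, n) * torch.fft.rfft(k, n), n)[:L]

# direct (slow) reference for comparison
y_ref = torch.stack([(k[: t + 1].flip(0) * x[: t + 1]).sum() for t in range(L)])
print(torch.allclose(y_fft, y_ref, atol=1e-3))  # True
```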
To put this in context: quadratic sounds bad, but in practice subquadratic algorithms are often slower because of hardware limitations. So while there was a lot of excitement about SSMs, it's not so easy to say that Llama is over now. Also, we don't know if Mamba will scale up, and the only way to find out is to actually pay a few million for training. But I am optimistic.
Another interesting model from subquadratic family is RWKV. Worth checking, but I think you had a podcast about it :)
BTW: I am self-taught and only skimmed the paper some time ago, so I might be very wrong.

BTW2: Another thing with attention is that there's usually a KV cache, which helps a lot with performance, and I think you cannot do that with Mamba.
Re 3) Even if you don't care about long context length, Mamba is much cheaper per token of auto-regressive output. Each token only has to compute the next step of a linear RNN, whereas the transformer has to attend back over all previous outputs, which rapidly grows in cost and memory.
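A toy illustration of that asymmetry (my own sketch with made-up sizes): the recurrent state stays the same size at every step, while a transformer's KV cache, and therefore its attention work, grows with every generated token.

```python
import torch

# Toy illustration (made-up sizes): a recurrent model updates a fixed-size
# state each step, while a transformer attends over a KV cache that grows
# with every generated token, so its per-token cost keeps increasing.
d_model, d_state, steps = 512, 16, 4

state = torch.zeros(d_model, d_state)   # fixed-size recurrent state
kv_cache = []                           # one (key, value) pair per past token

for t in range(steps):
    x_t = torch.randn(d_model)

    # recurrent step: O(d_model * d_state) work, independent of t
    state = 0.9 * state + 0.1 * x_t.unsqueeze(-1)

    # transformer step: attend over the whole cache, O(t * d_model) work
    kv_cache.append((torch.randn(d_model), torch.randn(d_model)))
    keys = torch.stack([k for k, _ in kv_cache])      # (t+1, d_model)
    values = torch.stack([v for _, v in kv_cache])    # (t+1, d_model)
    attn = torch.softmax(keys @ x_t / d_model ** 0.5, dim=0)
    ctx = attn @ values                               # (d_model,)

    print(f"token {t}: state {tuple(state.shape)}, cache entries {len(kv_cache)}")
```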
"Mamba is the world's longest venomous snake with an estimated length of over 150 m"
Had a laugh at that. Really great stuff though; it was nice to have references to the arXiv paper, so someone like me, who generally consumes these things instead of translating them from papers, could sort of peek behind the curtains.
This is a dumb question, but how hard is it to train the Mamba models that are on Hugging Face? It looks like the largest one is 2.8B; how many GPUs, and for how long, do you need to train that up using a dataset like The Pile?
That's a great question and I would like to know too. It looks like the answer is that it trains substantially faster than an equally sized Transformer, and the end result scores better than a Transformer on basically every benchmark. It also does inference 3-5x faster in half the RAM.
Thanks for this. I took a stab at unraveling the official CUDA version and never really got around to it after my initial attempt failed. This seems a lot nicer.
Oh my gosh, another one-file PyTorch implementation. This is fantastic. I'd like to hope that some of my previous work (hlb-CIFAR10 and related projects, along with other influences before it like minGPT, DawnBench, etc.) has been able to help push the 'simple, single-file, reduced-complexity' format forward a bit. I personally think that this kind of work is critical to efficient ML research, and that is possibly one of the most important things that we can do for the field today.
Research progresses at the speed of innovation, which progresses with the inverse of experiment runtime, which is definitely and absolutely related to the underlying Kolmogorov Complexity of the code w.r.t. a research/simple-hackery-focused objective.
I really cannot stress enough how important tools like this are to research and how much they've sped up the knowledge discovery process for me personally. Being able to quickly sketch out ideas, often in minutes, and get immediate, high-SNR results back has become an indispensable part of my research progress. While we seem to be really good at some of the specifics and details of the research itself, and somehow have extremely information-efficient training processes, we have not applied the same logic to the research field as a whole!
Knowledge distillation and/or MDL (https://en.wikipedia.org/wiki/Minimum_description_length) are, I think, extremely important to reversing a lot of the constant fluff, cruft, and overly dense thrash-and-hope-you-don't-get-scooped-by-other-researchers-on-marginal-value-topics trend that I think has largely been encouraged by the current paper submission/review/etc. process.
I've been wanting to get around this and move towards a slightly better-scaling solution recently. One of those things is that I've started distributing my code as one-file, self-contained, short, rough gists ('code sketches'), which shortens dev time and gets rough, unpolished, working code for a concept into people's hands. It seems to work pretty well so far; I hope to continue doing it! <3 :'))))
In any case, this is extremely exciting stuff, and everyone -- please! More code like this! We're researchers on learning data in a largely-scaled way, let's be data-efficient in how we disseminate information as well! It's a dream come true to see a lot more of this stuff coming down the pipeline, fantastic work and keep it coming! <3 :')))) Woop woop woop!!!!
It’s been an exciting year in 2023, in no small part because of watching AI research unfold at these crazy speeds. Like you’ve said, enablers like arXiv, PyTorch, GitHub, Hugging Face, and terse, open-source Python code are dramatically accelerating the development of this new field.
It’s probably the fastest the human race has ever developed anything of substantial complexity!
The only other place I see this kind of velocity is SpaceX, which also launched two cutting-edge rockets this year.
Minor potential performance benefit: it looks like you might be able to fuse the x_proj and dt_proj weights here, since x_proj has no bias. This is possibly doable simply at runtime if there are any weight-fiddling requirements; I'm guessing the single kernel + bias will still run faster in the end (not sure though! <3 :')))) )
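In case it's useful, here's the general shape of that kind of fusion (a toy sketch under my own assumptions about the layer shapes, not the repo's actual code): two consecutive linear layers with no bias in between collapse into a single matmul whose weight can be precomputed.

```python
import torch
import torch.nn as nn

# Sketch of the fusion idea (illustrative shapes, not the repo's exact code):
# two consecutive linear maps with no bias in between collapse into a single
# matmul whose weight can be precomputed once.
d_inner, dt_rank = 64, 4

x_proj_delta = nn.Linear(d_inner, dt_rank, bias=False)  # the delta slice of x_proj
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

x = torch.randn(8, d_inner)
y_two_step = dt_proj(x_proj_delta(x))

# fused: precompute W_dt @ W_x, apply the bias at the end
W_fused = dt_proj.weight @ x_proj_delta.weight           # (d_inner, d_inner)
y_fused = x @ W_fused.T + dt_proj.bias

print(torch.allclose(y_two_step, y_fused, atol=1e-5))    # True
```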
Is there an original paper discussion? I seem to have missed it. It's quite interesting. I didn't catch on to this part:
"We note that full results on context length 8k are missing for the RWKV and RetNet baselines, prior strong recurrent models that can also be interpreted as SSMs, due to a lack of efficient implementation leading to out-of-memory or unrealistic computation requirements."
RetNet doesn't really consume much memory, and the chunkwise forward implementation restricts VRAM usage to the chunk size. That's the mode I'd use to test long context lengths.
Has anyone done some tests on the original Mamba model? How fast is the training on this one in comparison with RetNet in parallel forward mode?
Very cool. I've read this line of papers originating from HiPPO, S4, Hyena, Mamba, etc., but can someone please explain how this isn't just an RNN/LSTM variant?
Its latent-space transition is linear instead of nonlinear, so there's a more parallelizable algorithm for advancing time in it. That makes it much more efficient to train and run inference with on GPUs.

The way it keeps all the representational power of LSTMs is by having the transition vary with the input (while still being linear).
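To make the "linear means parallelizable" point concrete, here's a tiny 1-D sketch (my own, not Mamba's actual kernel): each step h_t = a_t·h_{t-1} + b_t is an affine map, affine maps compose associatively, so the whole prefix can be computed with a log-depth scan instead of a strictly sequential loop.

```python
import torch

# Why a *linear* recurrence parallelizes: each step h_t = a_t * h_{t-1} + b_t
# is an affine map, and affine maps compose associatively:
#   apply (a1, b1) then (a2, b2)  ==  apply (a2 * a1, a2 * b1 + b2)
# so the full prefix can be computed with a log-depth scan rather than a
# strictly sequential loop. Toy 1-D version below.
torch.manual_seed(0)
L = 16
a = torch.rand(L) * 0.9
b = torch.randn(L)

# sequential reference, with h_0 = 0
h, ref = torch.zeros(()), []
for t in range(L):
    h = a[t] * h + b[t]
    ref.append(h)
ref = torch.stack(ref)

# Hillis-Steele style scan over (A, B) pairs; afterwards B[t] equals h_t
A, B = a.clone(), b.clone()
offset = 1
while offset < L:
    A_prev = torch.cat([torch.ones(offset), A[:-offset]])   # identity padding
    B_prev = torch.cat([torch.zeros(offset), B[:-offset]])
    A, B = A * A_prev, A * B_prev + B
    offset *= 2

print(torch.allclose(ref, B, atol=1e-5))  # True
```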
Thanks, that's helpful. One place where the parallelizability of this method falls short of the transformer is not being able to pack multiple varying-length examples into the same array during training with a block-diagonal attention pattern. If I understand correctly, that's not possible with this architecture, and it's an important practical concern in large-scale transformer training.
How long does it generally take between model architectures like Mamba being proposed and the use of these architectures in SotA mega models like GPT or Gemini? IIUC Mamba basically eliminates restrictions on context length which would be awesome to see in the super-mega high performance models.
I re-implemented Mamba myself, and this was the first time I had ever worked with einops/einsum. I'm 50/50 on them after this. I found them relatively easy to look at and understand the intent of (possibly more so than other representations), but it took extra time to translate them into other primitives (loops, multiplication, etc.). I believe torch.einsum is generally well optimized compared to naive looping, too. All said, I don't know if I'd use them myself working from scratch, but it's interesting to know them, and if I were working in Python I might try comparing the speed of einops/einsum against other approaches.
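For anyone who hasn't seen einsum before, here's the kind of translation I mean (made-up shapes, just illustrating the notation):

```python
import torch

# The same contraction written with torch.einsum and with plain broadcasting.
# Shapes are made up for illustration (b=batch, l=length, d=channels, n=state).
b, l, d, n = 2, 5, 3, 4
x = torch.randn(b, l, d)
A = torch.randn(d, n)

# einsum: for each (b, l, d, n), multiply x[b,l,d] by A[d,n]; no axis is summed
out_einsum = torch.einsum('bld,dn->bldn', x, A)

# equivalent explicit broadcasting
out_broadcast = x.unsqueeze(-1) * A          # (b, l, d, 1) * (d, n) -> (b, l, d, n)

print(torch.allclose(out_einsum, out_broadcast))  # True
```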
I love one file implementations. I hate all these implementations with preprocess_utils.py that imports stuff from model.py that imports stuff again from preprocess_utils.py that imports stuff from ...
I really struggle with the dozens and dozens of vocabulary terms used in the field of machine learning and especially AI. I'm not a beginner at all, but I wonder if there is a comprehensive guide to all those terms, one that doesn't necessarily explain the technology behind them in detail but shows their position and relation to each other, like some kind of landscape.
"everyone" seems to know Mamba. I never heard of Mamba. There are constantly new kind of llm popping up, talking about stuff that seems to be obvious.
So, is there some kind of resource like that, not aiming at beginners, but experienced users, coming from other fields of IT?
In fast evolving fields it’s always all about sociology, not canon or pedagogy. Meaning in new fields is created in community (constructionism).
You need to plug into the community and overhear what people are talking about (HN is such a community). You’ll also get a sense of the linguistic subculture (acronyms, lingo etc) much like you learn to talk hip hop if you’re into the hip hop subculture. Much of it will be noise but overall you’ll get a sense of what the community cares about, which helps you narrow what you need to focus on. The subreddit r/localllama is the watering hole for hobbyists right now.
In this particular case, I find it helpful to do syntopical reading (per Mortimer Adler) around LLMs, not AI in general. Mamba is interesting to me because I have a background in optimal control, state space models are my bread and butter, and it’s fascinating to see them applied in this way.
Side: I’m in my 40s and this isn’t my first rodeo. There will always be new fields and trends emerging — I’ve been through several waves of this (cloud, big data, ML, data science etc) where posts like yours are commonplace. But there is no need to be frustrated. Overhearing conversations is one way to make sense of them instead of feeling lost and waiting for someone to summarize and explain everything to you.
The same applies to academic fields.
Ps also consider you might not need to be on the cutting edge. If you’re not trying to build leading edge stuff, it’s good to wait for the dust to settle — you’ll waste less time following dead ends while the community is figuring out what’s good.
Perhaps the community at r/localllama could train an LLM that knows about the latest developments and explains jargon and papers, updated weekly. Free idea for karma.
1. Select abstract or select all text then copy/paste.
2. Save the PDF and upload with ChatGPT’s document feature.
3. Or just ask for it: “what’s that well-known LLM paper about context and getting lost in the middle?” It will web search as needed.
You can also do more than summarize. Ask about equations, ask it to make analogies, challenge the key findings as devil’s advocate to learn from different angles. Propose your own ideas.
Use voice to digest topics during your commute and ask tons of questions until you understand.
>"everyone" seems to know Mamba. I never heard of Mamba
Only the "everybody who knows what mamba is" are the ones upvoting and commenting. Think of all the people who ignore it. For me, Mamba is the faster version of Conda [1], and that's why I clicked on the article.
It's extremely common to manage Python environments with conda (although it can do much more). If you are unaware of conda, it is unlikely you work with Python, and therefore unlikely to be doing much with ML (and LLMs) anyway; it's even part of the "getting started" documentation for PyTorch.
Conda has been around for a decade and it used to be the primary package manager for everything related to numpy/scipy. Most ML and data science people have heard of it even if they haven't used it.
Mamba is a proof of concept of the latest SSM architecture for LLMs, named S6, and is a dense counterpart to Transformers, trained on 300B tokens of The Pile in sizes up to 2.8B. Mamba shows that S6 LLMs train faster, run faster, use less VRAM, and reach lower perplexity and better benchmark scores with the exact same training data.
That is actually accurate but probably sounds just as outlandish.
The approachable version: Mamba is a proof-of-concept language model showcasing a new LLM architecture called S6, a competitor to the Transformer architecture (the 'T' in ChatGPT), and it is better in every measurable way.
It is a very fad driven field. Everyone brands everything. It isn't enough to give things boring titles like, stacked open linear dynamical system with selective observations and learned timestep.
That's half of it; the other half is pure social linguistics.

Try talking about a "stacked open linear dynamical system" more than three times and you're bound to settle on a token that conveys the same thing but is quicker to produce.
It's turtles all the way down with LLMs, and with your comment: people are just trying to maximize the information per token in their conversations.
It's a new LLM type: instead of transformers it uses state space models, which are orders of magnitude faster.

It's currently very new and less coherent than GPT-2.
I didn't know Mamba but the bottom of the page lists comprehensive references.
If you mean the "branding" that is common in ML, which is often criticized, I much prefer it over the jargon used in other fields, e.g. Mathematics. It is nice to have distinguished words to talk about different concepts.
The people that are constantly up to date on this stuff tend to be AI/ML researchers and engineers. In academia, industry research groups, or startups.
They literally get paid to read papers, and implement models on a day-to-day basis.
I wouldn't worry too much not being up to date or things sounding a bit foreign. The names themselves are just that, names, the models themselves tend to be incremental versions of some previous model.
Most of the startups I've chatted with seem to prioritize finding people who build products. The complaint/regret I've heard from 3-5 organizations was hiring researchers.
Researchers are more for highly funded organizations. Startups can get by with off-the-shelf models.
> in the field of machine learning and especially AI
Sorry for getting semantical here, but isn't ML a subfield of AI? In other words, I would have expected "... in the field of machine learning and AI in general"
"AI" has recently often been used to mean generative AI specifically, which is a subfield of machine learning, which is a subfield of AI in the broader sense.
But I did notice the "References" section in the bottom of the README, which does explain what Mamba is by linking to the original paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" https://arxiv.org/abs/2312.00752
Heavily agree. I've been following this space quite closely, like most people, only for the past year. But it still seems to be in its experimental phase, which in turn attracts academics and researchers who tend toward this type of language.
Everybody doesn't know Mamba. You can't stay on top of everything in ML, so stop trying. Since you asked: Mamba is a neural architecture based on structured state space models (SSMs) that aims to replace Transformers. For me, right now, just knowing that counts as staying on top of things. If I need to know more than that, I can have the computer summarize it for me.
I think the glossary is defining variable names as given in the paper. I found this confusing when I originally read the paper, since the authors assume the reader knows what B, L, D and N stand for. I had to use Explainpaper to figure it out.
Yes, it is here. This is an implementation designed for education: the main purpose here is to understand the model architecture in a practical sense.
So lines of code and number of files are both meaningful. This is 1 short Python file, which makes it a lot easier to understand than a full optimized implementation.
BERT:
https://github.com/explosion/curated-transformers/blob/main/...
Llama 1/2:
https://github.com/explosion/curated-transformers/blob/main/...
MPT:
https://github.com/explosion/curated-transformers/blob/main/...
With various stuff enabled, including support for TorchScript JIT, PyTorch flash attention, etc.