Mamba: The Easy Way
Oxford, UK — February 23, 2024
Shared on Hacker News and X
Today, basically any language model you can name is a Transformer model.
OpenAI’s ChatGPT, Google’s Gemini, and GitHub’s Copilot are all powered by Transformers, to name a few.
However, Transformers suffer from a fundamental flaw: they are powered by Attention, which scales quadratically with sequence length.
Simply put, for quick exchanges (asking ChatGPT to tell a joke), this is fine.
But for queries that require lots of words (asking ChatGPT to summarize a 100-page document), Transformers can become prohibitively slow.1
Many models have attempted to solve this problem, but few have done as well as Mamba.
Published two months ago by Albert Gu and Tri Dao, Mamba appears to outperform similarly-sized Transformers while scaling linearly with sequence length.
If you’re looking for an in-depth technical explanation of Mamba, paired with a full Triton implementation, you’re in the wrong place.
Mamba: The Hard Way has already been written by the legend himself, Sasha Rush.
If you haven’t heard of Mamba (or Triton), or you’re looking for a higher-level overview of Mamba’s big ideas, I have just the post for you.
The prospect of an accurate linear-time language model has gotten many people excited about the future of language model architectures (especially Sasha, who has money on the line).
In this blogpost, I’ll try to explain how Mamba works in a way that should be fairly straightforward, especially if you’ve studied a little computer science before.
Let’s get started!
Background: S4
Mamba’s architecture is based primarily on S4, a recent state space model (SSM) architecture.
I’ll summarize the important parts here, but if you want to understand S4 in more detail, I would highly recommend reading another one of Sasha’s blogposts, The Annotated S4.
At a high level, S4 learns how to map an input $x(t)$ to an output $y(t)$ through an intermediate state $h(t)$:

$$h'(t) = \mathbf{A}h(t) + \mathbf{B}x(t)$$

$$y(t) = \mathbf{C}h(t)$$

Here, $\mathbf{A}$, $\mathbf{B}$, and $\mathbf{C}$ are the parameters the model learns.
In practice, we always deal with discrete data, such as text.
This requires us to discretize the SSM, transforming our continuous parameters $\mathbf{A}$, $\mathbf{B}$, and $\mathbf{C}$ into discrete parameters $\bar{\mathbf{A}}$ and $\bar{\mathbf{B}}$ by using a special fourth parameter $\Delta$:

$$h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t$$

$$y_t = \mathbf{C}h_t$$
These equations form a recurrence, similar to what you would see in a recurrent neural network (RNN).
At each step $t$, we combine the previous hidden state $h_{t-1}$ with the current input $x_t$ to produce the new hidden state $h_t$, which we then use to compute the output $y_t$.
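To make the recurrence concrete, here is a minimal NumPy sketch of this "RNN mode." The shapes, the random initialization, and the diagonal $\mathbf{A}$ are simplifications I've chosen for illustration (S4 parameterizes and initializes $\mathbf{A}$ much more carefully); the discretization uses the standard zero-order-hold formulas.

```python
import numpy as np

N, L = 4, 8                          # state size, sequence length
rng = np.random.default_rng(0)

# Continuous-time parameters (toy values; a diagonal, negative A keeps the state stable)
A = -np.abs(rng.normal(size=N))      # diagonal of the state matrix
B = rng.normal(size=N)               # input projection
C = rng.normal(size=N)               # output projection
delta = 0.1                          # step size Δ

# Discretize (zero-order hold): Ā = exp(ΔA), B̄ = (ΔA)⁻¹(exp(ΔA) − 1)·ΔB
A_bar = np.exp(delta * A)
B_bar = (A_bar - 1.0) / A * B

x = rng.normal(size=L)               # a toy one-dimensional input sequence

# RNN mode: update the hidden state one token at a time
h = np.zeros(N)
ys = []
for t in range(L):
    h = A_bar * h + B_bar * x[t]     # h_t = Ā h_{t-1} + B̄ x_t
    ys.append(C @ h)                 # y_t = C h_t
print(np.round(ys, 4))
```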
In this way, we can essentially use S4 as an RNN to generate one token at a time.
However, what makes S4 really cool is that you can actually also use it as a convolutional neural network (CNN).
As an example, let's see what happens when we expand the discrete equations from earlier to try to calculate $h_3$:

$$h_3 = \bar{\mathbf{A}}h_2 + \bar{\mathbf{B}}x_3 = \bar{\mathbf{A}}(\bar{\mathbf{A}}h_1 + \bar{\mathbf{B}}x_2) + \bar{\mathbf{B}}x_3 = \bar{\mathbf{A}}^3\bar{\mathbf{B}}x_0 + \bar{\mathbf{A}}^2\bar{\mathbf{B}}x_1 + \bar{\mathbf{A}}\bar{\mathbf{B}}x_2 + \bar{\mathbf{B}}x_3$$

With $h_3$ written out this way, we can substitute it into the equation for $y_3$:

$$y_3 = \mathbf{C}h_3 = \mathbf{C}\bar{\mathbf{A}}^3\bar{\mathbf{B}}x_0 + \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}}x_1 + \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}}x_2 + \mathbf{C}\bar{\mathbf{B}}x_3$$

Now, notice that $y_3$ can be computed as a dot product between a vector of coefficients and our input $x$:

$$y_3 = \begin{pmatrix} \mathbf{C}\bar{\mathbf{A}}^3\bar{\mathbf{B}} & \mathbf{C}\bar{\mathbf{A}}^2\bar{\mathbf{B}} & \mathbf{C}\bar{\mathbf{A}}\bar{\mathbf{B}} & \mathbf{C}\bar{\mathbf{B}} \end{pmatrix} \cdot \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \end{pmatrix}$$

Since $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ are all constant, we can precompute the left-hand vector once and save it as a convolutional kernel $\bar{\mathbf{K}}$. This gives us an easy way to compute every output $y$ with a convolution over the input sequence.
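Reusing the toy parameters from the sketch above, we can build this kernel explicitly and check that a convolution-style dot product produces exactly the same outputs as the recurrent loop:

```python
# CNN mode: precompute the kernel K̄ = (C·B̄, C·ĀB̄, C·Ā²B̄, ..., C·Ā^(L-1)·B̄).
# Because A is diagonal in this sketch, Ā^k·B̄ is just an elementwise power.
K = np.array([C @ (A_bar**k * B_bar) for k in range(L)])

# Each output y_t is a dot product between the (reversed) kernel and the inputs so far
ys_conv = [K[: t + 1][::-1] @ x[: t + 1] for t in range(L)]

assert np.allclose(ys, ys_conv)      # CNN mode and RNN mode agree
```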
Importantly, these recurrent and convolutional forms, which I like to call “RNN mode” and “CNN mode,” are mathematically equivalent.
This allows S4 to shape-shift depending on what you need it to do, with no difference in its outputs.
We can compare these “modes” using Table 1 from the S4 paper, which summarizes how each form behaves during training and inference (bold denotes the best result for each metric).3

| | Convolution | Recurrence | S4 |
|---|---|---|---|
| Training | **Fast (parallel over the sequence)** | Slow (sequential) | **Fast (parallel over the sequence)** |
| Parallel | **Yes** | No | **Yes** |
| Inference | Slow (recomputes the whole sequence) | **Fast (constant work per step)** | **Fast (constant work per step)** |
Notice that CNN mode is better for training, while RNN mode is better for inference.
In CNN mode, we can take advantage of parallelism to train across many examples, all at once.
In RNN mode, although we can only calculate one step at a time, each step requires exactly the same amount of work.
Because S4 can use both modes, it essentially gets the best of both worlds: fast training, and even faster inference.
Idea #1: Selectivity
Now we can move on to the first major idea introduced by Mamba: selectivity.
Let’s recall the two equations that define the discrete form of S4:

$$h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t$$

$$y_t = \mathbf{C}h_t$$

Note that in S4, our discrete parameters $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ are constant: they are computed once and then reused, unchanged, for every token in the sequence. Mamba instead makes $\mathbf{B}$, $\mathbf{C}$, and the step size $\Delta$ functions of the current input $x_t$ (while $\mathbf{A}$ itself stays fixed), so the parameters that decide what gets written into and read out of the hidden state can change from token to token.
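Here's a rough sketch of what that looks like in code. Everything below is a single-channel caricature I've made up for illustration: in the real model, $\mathbf{B}_t$, $\mathbf{C}_t$, and $\Delta_t$ come from learned linear projections of the full token embedding, and the whole loop runs over many channels inside one fused GPU kernel. The general recipe (input-dependent $\mathbf{B}$, $\mathbf{C}$, and a softplus-activated $\Delta$, with a fixed $\mathbf{A}$) does follow the Mamba paper.

```python
import numpy as np

N, L = 4, 8                                # state size, sequence length
rng = np.random.default_rng(0)

A = -np.abs(rng.normal(size=N))            # A stays input-independent
W_B = rng.normal(size=N) * 0.5             # toy "learned" weights that turn the
W_C = rng.normal(size=N) * 0.5             #   current input into B_t, C_t, Δ_t
w_d = 0.5

softplus = lambda z: np.log1p(np.exp(z))   # keeps the step size positive

x = rng.normal(size=L)                     # a toy one-dimensional input sequence

h, ys = np.zeros(N), []
for t in range(L):
    B_t = W_B * x[t]                       # input-dependent B
    C_t = W_C * x[t]                       # input-dependent C
    delta_t = softplus(w_d * x[t])         # input-dependent step size Δ_t
    A_bar = np.exp(delta_t * A)            # re-discretize with this token's Δ_t
    B_bar = (A_bar - 1.0) / A * B_t
    h = A_bar * h + B_bar * x[t]           # the recurrence now changes every step
    ys.append(C_t @ h)
```

Notice that the discretization now happens inside the loop: every token gets its own $\bar{\mathbf{A}}$ and $\bar{\mathbf{B}}$.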
The authors argue that selectivity, or input-dependence, is important for a number of tasks.
Here’s how I like to think about it: because S4 does not have selectivity, it is forced to treat all parts of the input exactly the same.
However, when you’re reading a sentence, some words inevitably matter more than others.
Imagine we have a model that classifies sentences based on intent, and we give it the sentence: “I want to order a hamburger.”
Without selectivity, S4 spends the same amount of “effort” processing each word.
Click on the buttons below to see what happens as the sentence is processed, one word at a time.
*[Interactive figure: the hidden state updates as each word of "I want to order a hamburger" is processed, with the same amount of effort spent on every word.]*

(This is an oversimplification, but it should give you a rough sense of what's happening.)
But if you were a model trying to classify the intent of this sentence, you would probably want to “focus” more on some words than others.
How much value do the words “want” and “to” really contribute to the underlying meaning of this sentence?
In reality, it would be great if we could spend more of our limited mental energy on words like “order,” to know what the user wants to do, and “hamburger,” to know what the user is ordering.
By making model parameters a function of the input, Mamba makes it possible to “focus” on the parts of the input that are more important for the task at hand.
*[Interactive figure: the same sentence again, but now more effort is spent updating the hidden state for important words like "order" and "hamburger."]*

(Also an oversimplification.)
However, selectivity presents us with a problem.
Let’s think back to the convolutional kernel $\bar{\mathbf{K}}$ that we computed earlier. In S4, we could precompute this kernel, save it, and multiply it with the input $x$, because $\bar{\mathbf{A}}$, $\bar{\mathbf{B}}$, and $\mathbf{C}$ were constant. But in Mamba, these matrices change depending on the input, so there is no fixed kernel to precompute: we can't use CNN mode, and if we want selectivity, we have to train in RNN mode.
This posed a problem for Mamba’s authors: training in RNN mode is really slow.
Imagine we’re training our model on a sequence with 1,000 tokens.
A CNN would essentially compute a dot product between its kernel and the input vector, and it can do these computations in parallel.
By comparison, an RNN would need to update its hidden state 1,000 times in sequence.
This slow training time of RNNs is more or less what has prevented them from ever really taking off, and it led Mamba’s authors to their second big idea.
Idea #2: Fast training without convolutions
The second major idea of Mamba involves training in RNN mode very, very quickly.
At some point, Gu and Dao realized that their recurrence was very similar to a scan algorithm, also known as a prefix sum.
To compute a prefix sum, we need to take an input array $[x_1, x_2, \ldots, x_n]$ and return an output array in which each element is the sum of that element and all of the elements that came before it.
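Just to pin the definition down, here is the sequential version, which matches what `np.cumsum` computes:

```python
import numpy as np

x = [3, 1, 4, 1, 5]

# Sequential prefix sum: each output element is the running total so far
out, total = [], 0
for value in x:
    total += value
    out.append(total)

print(out)              # [3, 4, 8, 9, 14]
print(np.cumsum(x))     # same result: [ 3  4  8  9 14]
```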
Now let’s draw out the process for updating Mamba’s hidden state in RNN mode.
Wait a minute…
Let’s think about this.
If we had to formalize a prefix sum, we could write it out as the following equation:

$$h_t = h_{t-1} + x_t$$

This equation forms a recurrence: at each step, we compute the new value by adding the previous stored value to the current input. Now, let's look again at the recurrence for updating Mamba's hidden state:

$$h_t = \bar{\mathbf{A}}h_{t-1} + \bar{\mathbf{B}}x_t$$
These are really, really similar!5
And here’s the cool part: while computing a prefix sum may seem inherently sequential in nature, we actually have efficient parallel algorithms for this task!
In the diagram below, we can see a parallel prefix sum algorithm in action, where each vertical line represents one item in our array.
Take a second to convince yourself that this algorithm works: choose any vertical line, start at the top, and work your way down, tracing each addition back to the array’s first few items.
By the time you reach the bottom, you should have the sum of all items to the left of your line.
For example, you can see that the array’s third element receives the added value of the second element at the end, after the first element is added to the second element at the beginning.
As a result, the third element contains the sum of the first, second, and third elements by the time the parallel scan is finished.
If we were running this algorithm in a single thread, with no parallelism, it would take longer than if we were just adding the values together in sequence.
But GPUs have lots of processors, allowing for highly parallel computation.
As a result, we can compute this prefix sum (or scan) operation in roughly $O(\log n)$ time, rather than the $O(n)$ time it takes to add the values up one after another.
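Two things are worth seeing in code here. First, the parallel algorithm itself: the sketch below uses a Hillis–Steele-style scan, where each of the $\log_2 n$ passes could run in parallel across the array. Second, why this helps Mamba at all, given that its update is not a plain sum: the recurrence $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ can be carried through a scan as pairs $(\bar{A}_t, \bar{B}_t x_t)$ with an associative combine rule. This is only a toy sketch of that idea; the real Mamba kernel is a chunked, hardware-aware scan, not this code.

```python
import numpy as np

def scan(items, combine):
    """Inclusive scan in log2(n) passes; within a pass, every update is independent."""
    items = list(items)
    n, step = len(items), 1
    while step < n:
        new = list(items)
        for i in range(step, n):                   # these could all run in parallel
            new[i] = combine(items[i - step], items[i])
        items, step = new, step * 2
    return items

# 1) A plain prefix sum: combine = addition
print(scan([3, 1, 4, 1, 5], lambda a, b: a + b))   # [3, 4, 8, 9, 14]

# 2) A Mamba-style linear recurrence h_t = a_t * h_{t-1} + b_t, carried as (a, b)
#    pairs. Composing (a1, b1) then (a2, b2) gives (a1*a2, a2*b1 + b2), which is
#    associative, so the same parallel scan applies.
def combine_ssm(left, right):
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2

a = [0.9, 0.8, 0.7, 0.6]                           # toy per-step Ā values
b = [1.0, 2.0, 3.0, 4.0]                           # toy per-step B̄·x values
hs_scan = [pair[1] for pair in scan(list(zip(a, b)), combine_ssm)]

# Check against the sequential RNN-style loop
h, hs_loop = 0.0, []
for a_t, b_t in zip(a, b):
    h = a_t * h + b_t
    hs_loop.append(h)
assert np.allclose(hs_scan, hs_loop)
```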
So Mamba’s authors realized that if they wanted to train efficiently in RNN mode, they could probably use a parallel scan.
Since PyTorch does not currently have a scan implementation, Mamba’s authors wrote one themselves, and the results weren’t great.

信用:顾和道,2023
In the figure above, you can see that their PyTorch-based scan implementation (green) is always slower than FlashAttention-2 (blue), the fastest available “exact Attention” implementation.6
At a sequence length of 128,000 tokens, where the scan almost seems to catch up in runtime, it runs out of memory.
In order for Mamba to be practical, it needed to be faster.
This brought Mamba’s authors to Dao’s prior work on FlashAttention.
Review: FlashAttention
FlashAttention is a very fast implementation of Attention.
When published, FlashAttention trained BERT-large 15% faster than the previous fastest training time, and it was 3 times faster than the widely-used HuggingFace implementation of GPT-2.
In a nutshell, FlashAttention’s key insight has to do with the speeds at which different operations run on your GPU.
Its authors realized that some GPU operations are compute-bound, meaning they are limited by the speed at which your GPU performs computations.
However, other operations are memory-bound, meaning they are limited by the speed at which your GPU is able to transfer data.
Imagine you and a friend are playing a game: your friend has to run 50 meters to deliver two numbers to you, which you then need to multiply by hand.
A timer starts when your friend begins running, and ends when you get the answer.
Let’s say the numbers you need to multiply are 439,145,208 and 142,426,265.
It would take you a while to multiply these by hand.
Your friend might take 5 seconds to deliver the numbers, but you might take 60 seconds to perform the multiplication.
As a result, you are both compute-bound, since most of your time is spent on computation.
Now, imagine the numbers you need to multiply are 4 and 3.
While your friend still takes 5 seconds to run 50 meters, you can compute this result instantly.
Now, you are both memory-bound, since most of your time is spent transferring data.
In this analogy, your GPU is essentially racing to move data into the right places to perform its computations.
For example, let’s consider a masking operation.
To compute a masked vector, your GPU simply needs to erase data values whenever the mask is equal to zero (and keep them the same whenever it is equal to one).
If we used the mask $[1, 1, 1, 0, 0]$ on a vector of five values, for example, the output would simply be the first three values followed by two zeros.
Since this is extremely easy to compute, your GPU ends up spending most of its time transferring memory, to move the data and mask matrices into the right places for computation.
This means that masking is memory-bound.
On the other hand, matrix multiplication involves lots and lots of additions and multiplications.
Because so much more time is spent on computation than memory transfers, matrix multiplication is compute-bound.
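A rough way to make this concrete is to count the arithmetic per element of data moved. The sketch below is just illustrative NumPy, not a GPU profile:

```python
import numpy as np

n = 4096
x = np.random.rand(n, n)
mask = (np.random.rand(n, n) > 0.5).astype(x.dtype)
w = np.random.rand(n, n)

# Masking: about one multiply per element, while ~3n² elements are read/written.
# Almost all of the time goes into moving data around (memory-bound).
masked = x * mask

# Matrix multiplication: ~2n³ multiply-adds for the same ~3n² elements of traffic,
# i.e. thousands of operations per element moved (compute-bound).
y = x @ w
```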
With this in mind, let’s look at a breakdown of the computations performed during Attention (matmul = matrix multiplication).

信用:Dao 等人,2022 年
It turns out that dropout, softmax, and masking, which make up the bulk of Attention’s runtime, are all memory-bound.
This means that most of the time we spend computing Attention is simply spent waiting for your GPU to move around data.
With this in mind, I assume FlashAttention’s authors wondered, how can we speed up operations that are bounded by the speed of memory transfers?
This led FlashAttention’s authors to another key realization: GPU memory has two major regions.
One of these, high-bandwidth memory (HBM), is really big, but really slow.
The other one, static random-access memory (SRAM), is really small, but really fast.
Let’s break down the differences between these regions on an A100 GPU:

信用:Dao 等人,2022 年
FlashAttention’s authors realized that you can compute memory-bound operations more efficiently if you’re extra careful about how you use these regions of GPU memory.
They use an approach called tiling, in which small portions of your data are moved from HBM (slower) to SRAM (faster), computed in SRAM, and then moved back from SRAM to HBM.
This makes FlashAttention really, really fast, while still being numerically equivalent to Attention.

信用:Dao 等人,2022 年
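As a cartoon of the tiling idea (this is not FlashAttention's actual algorithm, which also has to carefully handle the softmax normalization across tiles), the pattern looks something like this: pull one small block of data into fast memory, do all the work you can on it there, and write the finished result back once.

```python
import numpy as np

def tiled_masked_scale(x, mask, scale, tile=256):
    """Process x one tile at a time, mimicking the HBM -> SRAM -> HBM movement.

    On a GPU, each block would be loaded from slow HBM into fast SRAM once,
    every operation that touches it would run there, and only the finished
    tile would be written back, instead of a full round trip per operation.
    """
    out = np.empty_like(x)
    rows, cols = x.shape
    for i in range(0, rows, tile):
        for j in range(0, cols, tile):
            block = x[i:i + tile, j:j + tile]       # "load the tile into SRAM"
            m = mask[i:i + tile, j:j + tile]
            block = block * m                       # fuse several cheap,
            block = block * scale                   #   memory-bound ops in SRAM
            out[i:i + tile, j:j + tile] = block     # "write the tile back to HBM"
    return out

x = np.random.rand(1024, 1024)
mask = (np.random.rand(1024, 1024) > 0.5).astype(x.dtype)
assert np.allclose(tiled_masked_scale(x, mask, scale=0.125), x * mask * 0.125)
```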
The details of how this works are fascinating, and I encourage you to check out the FlashAttention paper to learn more.
However, for the purpose of understanding Mamba, this is basically all you need to know.
Back to Mamba
Remember that before we started this tangent on FlashAttention, we were trying to speed up our parallel scan implementation.
Here is the same graph from earlier, where we can see that the scan implementation in PyTorch (green) is always slower than FlashAttention, the fastest “exact” Transformer (blue).7

信用:顾和道,2023
It turns out that if you take this same memory-aware tiling approach when computing a scan, you can speed things up a lot.
With this optimization in place, Mamba (red) is now faster than FlashAttention-2 (blue) at all sequence lengths.

信用:顾和道,2023
These results show that, as far as speed goes, Mamba is practical: it runs faster than the fastest exact Transformer implementations.
But is it any good at language modeling?
Results
Gu and Dao evaluate Mamba on a number of sequence modeling tasks involving language, genomics, and audio.
I’m not as familiar with the latter two domains, but the results look cool: Mamba establishes state-of-the-art performance when modeling DNA from the Human Genome project, and audio from a piano music dataset.
However, it’s the language results that have gotten many people excited.
A lot of the online discourse about Mamba has focused on Figure 4, which I’ve included below.

信用:顾和道,2023
In this graph, model size increases to the right, and language modeling performance improves as you go further down.8
This means that the best models should be down and to the left: small (and therefore fast), and also very good at modeling language.
Since Gu and Dao are academics, they don’t have thousands of GPUs available to train a GPT-4-sized model, so they made this comparison by training a bunch of smaller models, around 125M to 1.3B parameters.
As the graph above shows, the results look really promising.
When compared to other models of similar sizes, Mamba appears to be the best at modeling language.
What next?
I really enjoyed writing this blogpost, as I think Mamba innovates on language modeling in a pretty unique and interesting way!
Unfortunately, a few reviewers didn’t agree: Gu and Dao planned to present Mamba at ICLR in May, but their paper was rejected a couple weeks ago, causing some bewildered reactions online.
I would guess Gu and Dao are working now on the next version of the paper, and I would also imagine some companies with more GPUs than they know what to do with are currently trying to figure out whether Mamba’s performance holds up at larger model sizes.
As we continue to want models that can process more and more tokens at once, linear-time models such as Mamba might someday provide an answer if they can demonstrate good performance.
Until then, we can keep hacking away on our lame, old-school Transformers.