Dec 5, 2020 · 13 min read
HiPPO: Recurrent Memory with Optimal Polynomial Projections
Albert Gu*, Tri Dao*, Stefano Ermon, Atri Rudra, and Chris Ré
Many areas of machine learning require processing sequential data in an online fashion.
For example, a time series may be observed in real time where the future needs to be continuously predicted, or an agent in a partially-observed environment must learn how to encode its cumulative experience into a state in order to navigate and make decisions.
The fundamental problem in modeling long-term and complex temporal dependencies is memory: storing and incorporating information from previous time steps.
However, popular machine learning models suffer from forgetting: they are either built with fixed-size context windows (e.g. attention), or heuristic mechanisms that empirically suffer from a limited memory horizon (e.g., because of "vanishing gradients").
This post describes our method for addressing the fundamental problem of incrementally maintaining a memory representation of sequences from first principles. We will:
- find a technical formulation of this problem that we can analyze mathematically, and derive a closed-form solution with the HiPPO framework
- see that our method can be easily integrated into end-to-end models such as RNNs, where our framework both generalizes previous models, including the popular LSTM and GRU, and improves on them, achieving state of the art on permuted MNIST, a popular benchmark for long-range memory
- show how insights from the framework reveal methods with distinct theoretical properties -- we highlight a particular model, HiPPO-LegS, which is computationally efficient, provably alleviates vanishing gradients, and is the first known method to display "timescale robustness"!
Our paper was accepted to NeurIPS 2020 as a Spotlight, and our code is publicly available with PyTorch and TensorFlow implementations.
Online Function Approximation: A Formalism for Incremental Memory Representations
Our first insight is to move from discrete time to the continuous-time setting, which is often easier to analyze theoretically. We ask the following very natural question: given a continuous function (in one dimension) $f(t)$, can we maintain a fixed-size representation $c(t) \in \mathbb{R}^N$ at all times $t$ such that $c(t)$ optimally captures the history of $f$ from times $0$ to $t$?
However, this problem is not fully well-defined yet -- we need to specify:
- Quality of approximation: What is the "optimal approximation" of the function's history? We need to specify a measure (or weight function) that tells us how much we care about every time in the past.
- Basis: How can we compress a continuous function into a fixed-length vector? We can project the function onto a subspace of dimension $N$ and store the $N$ coefficients of its expansion in any basis. For simplicity, we will assume that we are working with the polynomial basis throughout this post.
Intuitively, we can think of the memory representation $c(t)$ as the vector of coefficients of the optimal polynomial approximation of the history of $f$.
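To make the projection idea concrete, here is a small NumPy sketch (our own illustrative code, not from the paper; `legendre_coeffs` and `reconstruct` are hypothetical helper names). It projects a function on $[0, 1]$ onto the first $N$ normalized Legendre polynomials under the uniform measure, and evaluates the resulting polynomial approximation:

```python
import numpy as np
from numpy.polynomial.legendre import Legendre

def legendre_coeffs(f, N, num=1000):
    """Project f on [0, 1] onto the first N normalized Legendre polynomials
    g_n(s) = sqrt(2n+1) * P_n(2s - 1), under the uniform measure."""
    s = np.linspace(0, 1, num)
    fs = f(s)
    c = np.empty(N)
    for n in range(N):
        g = np.sqrt(2 * n + 1) * Legendre.basis(n)(2 * s - 1)
        y = fs * g
        c[n] = np.sum((y[1:] + y[:-1]) * np.diff(s)) / 2  # trapezoid rule for <f, g_n>
    return c

def reconstruct(c, s):
    """Evaluate the degree-(N-1) approximation sum_n c_n * g_n(s)."""
    return sum(cn * np.sqrt(2 * n + 1) * Legendre.basis(n)(2 * s - 1)
               for n, cn in enumerate(c))
```

For example, projecting $f(s) = s$ gives $c_0 = 1/2$ and $c_1 = \sqrt{3}/6$, and the reconstruction recovers the function; this fixed-length coefficient vector is exactly the kind of memory representation discussed above.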
The HiPPO Framework (High-Order Polynomial Projection Operator)
Notice that given a measure (and assuming the polynomial basis), the online function approximation problem is now fully specified![1] That is, given any input function $f(t)$, the desired coefficient vectors $c(t)$, which are our desired memory representation, are completely defined. The question remains -- how do we calculate them?
The HiPPO framework formalizes this problem and provides machinery to compute the solution.
Although the desired coefficients $c(t)$ are rather abstractly defined as the implicit solution to an approximation problem, there is amazingly a closed-form solution that's easy to compute.
We'll leave the technical details to the full paper, but we'll note that they leverage classic tools from approximation theory such as orthogonal polynomials.[2]
In the end, the solution takes on the form of a simple linear differential equation, which is called the HiPPO operator:

$$\frac{d}{dt} c(t) = A(t)\, c(t) + B(t)\, f(t)$$
In short, the HiPPO framework takes a family of measures, and gives an ODE with closed-form transition matrices $A(t), B(t)$. These matrices depend on the measure, and following these dynamics finds the coefficients $c(t)$ that optimally approximate the history of $f$ according to the measure.
Instantiations of HiPPO
Figure 2 shows some concrete examples of HiPPO. We show two of the simplest families of measures, based on uniform measures. The translated Legendre measure on the left uses a fixed-length sliding window; in other words, it cares about recent history. On the other hand, the scaled Legendre measure uniformly weights the entire history up to the current time.
In both cases, the HiPPO framework produces closed-form formulas for the corresponding ODEs, which are shown for completeness (the transition matrices are actually quite simple!).
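For reference, the scaled Legendre (LegS) instantiation takes the following form, with the closed-form transition matrices given in the paper (here the time-varying operator matrices are $A(t) = -A/t$ and $B(t) = B/t$):

$$\frac{d}{dt} c(t) = -\frac{1}{t} A\, c(t) + \frac{1}{t} B\, f(t),$$

$$A_{nk} = \begin{cases} \sqrt{2n+1}\,\sqrt{2k+1} & n > k \\ n+1 & n = k \\ 0 & n < k \end{cases}, \qquad B_n = \sqrt{2n+1}.$$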
From continuous-time to discrete-time
There is one more detail called discretization. By using standard techniques for approximating the evolution of dynamical systems, the continuous-time HiPPO ODE can be converted to a discrete-time linear recurrence. Additionally, this step allows extensions of HiPPO to flexibly handle irregularly-sampled or missing data: simply evolve the system according to the given timestamps.
For the practitioner: To construct a memory representation $c_k$ of an input sequence $(f_k)$, HiPPO is implemented as the simple linear recurrence $c_{k+1} = A_k c_k + B_k f_k$, where the transition matrices $A_k, B_k$ have closed-form formulas. That's it!
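As a concrete sketch (not the paper's optimized implementation), here is the HiPPO-LegS recurrence in NumPy, using the closed-form matrices from the paper and a forward-Euler discretization, so that $A_k = I - A/(k+1)$ and $B_k = B/(k+1)$:

```python
import numpy as np

def legs_matrices(N):
    # Closed-form HiPPO-LegS matrices from the paper:
    # A[n,k] = sqrt(2n+1)*sqrt(2k+1) if n > k, n+1 if n == k, 0 if n < k
    # B[n]   = sqrt(2n+1)
    q = np.sqrt(2 * np.arange(N) + 1)
    A = np.tril(np.outer(q, q), -1) + np.diag(np.arange(N) + 1)
    return A, q

def hippo_legs(f, N=16):
    """Run c_{k+1} = (I - A/(k+1)) c_k + B f_k / (k+1), starting from c_0 = 0."""
    A, B = legs_matrices(N)
    I = np.eye(N)
    c = np.zeros(N)
    for k, fk in enumerate(f):
        c = (I - A / (k + 1)) @ c + B * fk / (k + 1)
    return c
```

Sanity check: for a constant input $f \equiv 5$, the optimal polynomial approximation of the history is the constant itself, so the coefficients converge to $(5, 0, 0, \dots)$.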
Hippos in the wild: integration into ML models
At its core, HiPPO is a simple linear recurrence that can be integrated into end-to-end models in many ways. We focus on recurrent neural networks (RNNs) due to their connection to dynamical systems involving a state evolving over time, just as in HiPPO. The HiPPO-RNN is the simplest way to perform this integration:
- Start with a standard RNN recurrence $h_k = \tau(h_{k-1}, c_{k-1}, x_k)$ that evolves a hidden state $h_k$ by any nonlinear function $\tau$ given the input $x_k$
- Project the state down to a lower-dimensional feature $f_k$
- Use the HiPPO recurrence to create a representation $c_k$ of the history of $f$, which is also fed back into $\tau$
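The three steps above can be sketched in NumPy as follows. This is an illustrative toy cell, not the paper's exact architecture: the weight names, scales, tanh update, and scalar feature are our own assumptions; the memory update uses the closed-form LegS matrices.

```python
import numpy as np

def legs_matrices(N):
    # Closed-form HiPPO-LegS transition matrices
    q = np.sqrt(2 * np.arange(N) + 1)
    A = np.tril(np.outer(q, q), -1) + np.diag(np.arange(N) + 1)
    return A, q

class HiPPORNNCell:
    """Toy HiPPO-RNN cell: a tanh RNN whose state history is memorized by HiPPO-LegS."""
    def __init__(self, d_in, d_h, N, seed=0):
        rng = np.random.default_rng(seed)
        self.Wh = rng.normal(0, 0.1, (d_h, d_h))   # hidden-to-hidden weights
        self.Wx = rng.normal(0, 0.1, (d_h, d_in))  # input-to-hidden weights
        self.Wc = rng.normal(0, 0.1, (d_h, N))     # memory fed back into the RNN
        self.wf = rng.normal(0, 0.1, d_h)          # projects h down to a scalar feature
        self.A, self.B = legs_matrices(N)
        self.I = np.eye(N)

    def step(self, h, c, x, k):
        # (1) nonlinear RNN update of the hidden state
        h = np.tanh(self.Wh @ h + self.Wx @ x + self.Wc @ c)
        # (2) project the state down to a low-dimensional feature f_k
        f = self.wf @ h
        # (3) HiPPO-LegS recurrence memorizes the history of f
        c = (self.I - self.A / (k + 1)) @ c + self.B * f / (k + 1)
        return h, c
```

Note the design choice: the memory $c$ is updated by a fixed linear recurrence (no learned parameters), while all learning happens in the nonlinear RNN around it.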
Special cases of the HiPPO-RNN
Those familiar with RNNs may notice this looks very similar to cell diagrams for other models such as LSTMs. In fact, several common models are closely related:
- The most popular RNN models are the LSTM and GRU, which rely on a gating mechanism. In particular, the cell state of an LSTM performs the recurrence $c_k = g_f \odot c_{k-1} + g_i \odot \tilde{c}_k$, where $g_f, g_i$ are known as the "forget" and "input" gates. Notice the similarity to the HiPPO recurrence $c_{k+1} = A_k c_k + B_k f_k$. In fact, these gated RNNs can be viewed as a special case of HiPPO with low-order ($N=1$) approximations and input-dependent discretization! So HiPPO sheds light on these popular models and shows how the gating mechanism, which was originally introduced as a heuristic, could have been derived.
- The HiPPO-LegT model, which is the instantiation of HiPPO for the translated Legendre measure, is exactly equivalent to a recent model called the Legendre Memory Unit (LMU).[3] Our proof is also much shorter, and just involves following the steps of the HiPPO framework!
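To sketch the gating connection (a hedged reading of this argument, with $\sigma$ a sigmoid and $w$ a learned vector as illustrative notation): take the $N=1$ case, where $A = B = 1$, and discretize with an input-dependent step size $\Delta_k$. Forward Euler gives

$$c_{k+1} = c_k + \Delta_k \left(-A c_k + B f_k\right) = (1 - \Delta_k)\, c_k + \Delta_k\, f_k, \qquad \Delta_k = \sigma(w^\top x_k) \in (0, 1),$$

which is exactly a gated update with tied forget gate $1 - \Delta_k$ and input gate $\Delta_k$, as in the GRU; the LSTM's independently parameterized gates are a further relaxation.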
~~Elephants~~ Hippos never forget
Let's take a look at how these models perform on benchmarks.
First, we test if HiPPO solves the problem it was designed to -- online function approximation.
Figure 4 shows that it can approximate a sequence of a million time steps with good fidelity. Keep in mind that this works while processing the function online with a limited budget of hidden units; it could have reconstructed the partial function at any point in time.
Second, we test on the standard Permuted MNIST benchmark,
where models must process the input image one pixel at a time and output a classification after consuming the entire sequence.
This is a classic benchmark for testing long-term dependencies in sequence models, since they must remember inputs from almost 1000 time steps ago.
Multiple instantiations of our HiPPO framework, including the HiPPO-LegS and HiPPO-LegT models described above, outperform other recurrent models by a significant margin and set a new state of the art, achieving 98.3% test accuracy compared to the previous best of 97.15%.
In fact, they even outperform non-recurrent sequence models that use global context such as dilated convolutions and transformers.
Full results can be found in Tables 4+5.
Timescale Robustness of HiPPO-LegS
Lastly, we'll explore some theoretical properties of our most interesting model, corresponding to the Scaled Legendre (LegS) measure.
As motivation, the discerning reader may be wondering by this point: What's the difference between different instantiations of HiPPO? How does the measure influence the model?
Here are some examples of how the intuitive interpretation of the measure translates into theoretical properties of the downstream HiPPO model:
- Gradient bounds: Since this measure says that we care about the entire past, information should propagate well through time. Indeed, we show that gradient norms of the model decay polynomially in time, instead of exponentially (i.e. the vanishing gradient problem for vanilla RNNs).
- Computational efficiency: The transition matrices $A, B$ actually have special structure, and the recurrence can be computed in linear instead of quadratic time. We hypothesize that these efficiency properties hold in general (i.e., for all measures), and are related more broadly to the efficiency of orthogonal polynomials and their associated computations (e.g. [link to SODA paper]).
- Timescale robustness: Most interestingly, the scaled measure is agnostic to how fast the input function evolves; Figure 5 illustrates how HiPPO-LegS is intuitively dilation equivariant.
The table shows results on a trajectory classification dataset with distribution shift between the training and test sequences (i.e., arising from time series being sampled at different rates at deployment); HiPPO is the only method that can generalize to new timescales!
| Generalization | LSTM | GRU-D | ODE-RNN | NCDE | LMU | HiPPO-LegS |
|---|---|---|---|---|---|---|
| 100Hz -> 200Hz | 25.4 | 23.1 | 41.8 | 44.7 | 6.0 | 88.8 |
| 200Hz -> 100Hz | 64.6 | 25.5 | 31.5 | 11.3 | 13.1 | 90.1 |
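Timescale robustness can also be checked numerically with a small self-contained sketch of the discrete LegS recurrence (built from the paper's closed-form matrices; the function, sampling rates, and tolerance are our own illustrative choices). Since no step size appears in the update, feeding the same underlying function sampled at two different rates should produce approximately the same coefficients:

```python
import numpy as np

def hippo_legs(f, N=8):
    """Discrete HiPPO-LegS: c_{k+1} = (I - A/(k+1)) c_k + B f_k / (k+1)."""
    q = np.sqrt(2 * np.arange(N) + 1)
    A = np.tril(np.outer(q, q), -1) + np.diag(np.arange(N) + 1)
    c = np.zeros(N)
    for k, fk in enumerate(f):
        c = (np.eye(N) - A / (k + 1)) @ c + q * fk / (k + 1)
    return c

# The same function sin(2*pi*s) on [0, 1), sampled at 512 vs 1024 points:
# the recurrence is rate-agnostic, so the final coefficients should match closely.
slow = np.sin(2 * np.pi * np.linspace(0, 1, 512, endpoint=False))
fast = np.sin(2 * np.pi * np.linspace(0, 1, 1024, endpoint=False))
c_slow, c_fast = hippo_legs(slow), hippo_legs(fast)
```

Contrast this with a recurrence discretized at a fixed step size (as in the LMU or a gated RNN), whose state would change substantially if the sampling rate doubled.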
Conclusion
- The problem of maintaining memory representations of sequential data can be tackled by posing and solving continuous-time formalisms.
- The HiPPO framework explains several previous sequence models as well as produces new models with cool properties.
- This is just the tip of the iceberg - there are many technical extensions of HiPPO, rich connections to other sequence models, and potential applications waiting to be explored!
Try it out
PyTorch and TensorFlow code for HiPPO is available on GitHub, where the HiPPO-RNNs can be used as a drop-in replacement for most RNN-based models. Closed-form formulas and implementations are given for the HiPPO instantiations mentioned here, and several more. For more details, see the full paper.
Footnotes
1. A measure induces a Hilbert space structure on the space of functions, so that there is a unique optimal approximation: the projection onto the desired subspace. ↩
2. Examples of famous orthogonal polynomials include the Chebyshev polynomials and Legendre polynomials. The names of our methods, such as LegS (scaled Legendre), are based on the orthogonal polynomial family corresponding to its measure. ↩
3. The way we integrate the HiPPO recurrence into an RNN is slightly different, so the full RNN versions of the HiPPO-LegT and Legendre Memory Unit (LMU) are slightly different, but the core linear recurrence is the same. ↩