这是用户在 2024-9-8 13:30 为 https://ar5iv.labs.arxiv.org/html/2102.12092?_immersive_translate_auto_translate=1 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?

Zero-Shot Text-to-Image Generation
零样本文本到图像生成

Aditya Ramesh  阿迪提亚·拉梅什    Mikhail Pavlov  米哈伊尔·帕夫洛夫    Gabriel Goh  高嘉乐    Scott Gray  斯科特·格雷    Chelsea Voss  切尔西·沃斯    Alec Radford  亚历克·拉德福德    Mark Chen  马克·陈    Ilya Sutskever  伊利亚·苏茨克维尔
Abstract 摘要

Text-to-image generation has traditionally focused on finding better modeling assumptions for training on a fixed dataset. These assumptions might involve complex architectures, auxiliary losses, or side information such as object part labels or segmentation masks supplied during training. We describe a simple approach for this task based on a transformer that autoregressively models the text and image tokens as a single stream of data. With sufficient data and scale, our approach is competitive with previous domain-specific models when evaluated in a zero-shot fashion.
文本生成图像的传统方法主要集中在为固定数据集训练寻找更好的建模假设。这些假设可能涉及复杂的架构、辅助损失或在训练期间提供的对象部分标签或分割掩码等侧信息。我们描述了一种基于变压器的简单方法,该方法自回归地将文本和图像标记建模为一条数据流。在足够的数据和规模下,我们的方法在零样本评估中与之前的特定领域模型具有竞争力。

Machine Learning, ICML  机器学习,ICML

1 Introduction
1 介绍

Modern machine learning approaches to text to image synthesis started with the work of Mansimov et al. (2015), who showed that the DRAW Gregor et al. (2015) generative model, when extended to condition on image captions, could also generate novel visual scenes. Reed et al. (2016b) later demonstrated that using a generative adversarial network (Goodfellow et al., 2014), rather than a recurrent variational auto-encoder, improved image fidelity. Reed et al. (2016b) showed that this system could not only generate objects with recognizable properties, but also could zero-shot generalize to held-out categories.
现代机器学习方法在文本到图像合成方面的研究始于Mansimov 等人(2015的工作,他们展示了 DRAWGregor 等人(2015生成模型在扩展到图像标题的条件时,也能够生成新颖的视觉场景。Reed 等人(2016b随后证明,使用生成对抗网络(Goodfellow 等人,2014而不是递归变分自编码器,提高了图像的保真度。Reed 等人(2016b表明,这一系统不仅可以生成具有可识别特性的物体,还能够零样本泛化到未见过的类别。

Over the next few years, progress continued using a combination of methods. These include improving the generative model architecture with modifications like multi-scale generators (Zhang et al., 2017, 2018), integrating attention and auxiliary losses (Xu et al., 2018), and leveraging additional sources of conditioning information beyond just text (Reed et al., 2016a; Li et al., 2019; Koh et al., 2021).
在接下来的几年里,进展持续通过多种方法的结合实现。这些方法包括改进生成模型架构,例如使用多尺度生成器(Zhang et al., 2017, 2018),整合注意力和辅助损失(Xu et al., 2018),以及利用除文本之外的其他条件信息来源(Reed et al., 2016a; Li et al., 2019; Koh et al., 2021)

Refer to caption
Figure 1: Comparison of original images (top) and reconstructions from the discrete VAE (bottom). The encoder downsamples the spatial resolution by a factor of 8. While details (e.g., the texture of the cat’s fur, the writing on the storefront, and the thin lines in the illustration) are sometimes lost or distorted, the main features of the image are still typically recognizable. We use a large vocabulary size of 8192 to mitigate the loss of information.
图 1:原始图像(上)与离散变分自编码器重建图像(下)的比较。编码器将空间分辨率降低了 8 倍。虽然细节(例如,猫毛的纹理、店面上的文字以及插图中的细线)有时会丢失或失真,但图像的主要特征通常仍可识别。我们使用 8192 的大词汇量来减轻信息损失。

Separately, Nguyen et al. (2017) propose an energy-based framework for conditional image generation that obtained a large improvement in sample quality relative to contemporary methods. Their approach can incorporate pretrained discriminative models, and they show that it is capable of performing text-to-image generation when applied to a captioning model pretrained on MS-COCO. More recently, Cho et al. (2020) also propose a method that involves optimizing the input to a pretrained cross-modal masked language model. While significant increases in visual fidelity have occurred as a result of the work since Mansimov et al. (2015), samples can still suffer from severe artifacts such as object distortion, illogical object placement, or unnatural blending of foreground and background elements.
另外,阮等人(2017 提出了一个基于能量的条件图像生成框架,相较于当代方法在样本质量上取得了巨大提升。他们的方法可以结合预训练的判别模型,并且他们展示了当应用于在 MS-COCO 上预训练的图像描述模型时,能够进行文本到图像的生成。最近,Cho 等人(2020 也提出了一种涉及优化输入到预训练的跨模态掩蔽语言模型的方法。尽管自 Mansimov 等人(2015 以来,视觉逼真度有了显著提高,但样本仍可能遭受严重的伪影,例如物体扭曲、不合逻辑的物体位置,或前景和背景元素的不自然融合。

Recent advances fueled by large-scale generative models suggest a possible route for further improvements. Specifically, when compute, model size, and data are scaled carefully, autoregressive transformers (Vaswani et al., 2017) have achieved impressive results in several domains such as text (Radford et al., 2019), images (Chen et al., 2020), and audio (Dhariwal et al., 2020).
最近的进展得益于大规模生成模型,暗示了进一步改进的可能途径。具体而言,当计算、模型规模和数据得到谨慎扩展时,自回归变换器 (Vaswani et al., 2017) 在文本 (Radford et al., 2019)、图像 (Chen et al., 2020) 和音频 (Dhariwal et al., 2020) 等多个领域取得了令人印象深刻的成果。

Refer to caption Refer to caption
Refer to caption Refer to caption
(a) a tapir made of accordion. a tapir with the texture of an accordion.
(a) 由手风琴制成的貘。一只具有手风琴质感的貘。
Refer to caption Refer to caption
Refer to caption Refer to caption
(b) an illustration of a baby hedgehog in a christmas sweater walking a dog
(b) 一只穿着圣诞毛衣的小刺猬在遛狗的插图
Refer to caption Refer to caption
Refer to caption Refer to caption
(c) a neon sign that reads “backprop”. a neon sign that reads “backprop”. backprop neon sign
(c) 一个霓虹灯招牌写着“backprop”。一个霓虹灯招牌写着“backprop”。backprop 霓虹灯招牌
Refer to caption Refer to caption
Refer to caption Refer to caption
(d) the exact same cat on the top as a sketch on the bottom
(d) 顶部和底部草图上是完全一样的猫
Figure 2: With varying degrees of reliability, our model appears to be able to combine distinct concepts in plausible ways, create anthropomorphized versions of animals, render text, and perform some types of image-to-image translation.
图 2: 我们的模型似乎能够以不同程度的可靠性将不同的概念以合理的方式组合起来,创建动物的人形版本,渲染文本,并执行某些类型的图像到图像的转换。

By comparison, text-to-image generation has typically been evaluated on relatively small datasets such as MS-COCO and CUB-200 (Welinder et al., 2010). Could dataset size and model size be the limiting factor of current approaches? In this work, we demonstrate that training a 12-billion parameter autoregressive transformer on 250 million image-text pairs collected from the internet results in a flexible, high fidelity generative model of images controllable through natural language.
通过比较,文本生成图像通常是在相对较小的数据集上进行评估,例如 MS-COCO 和 CUB-200 (Welinder et al., 2010)。数据集的大小和模型的大小是否是当前方法的限制因素?在这项工作中,我们证明了在从互联网收集的 2.5 亿图像-文本对上训练一个 120 亿参数的自回归变换器,会产生一个可通过自然语言控制的灵活、高保真的图像生成模型。

The resulting system achieves high quality image generation on the popular MS-COCO dataset zero-shot, without using any of the training labels. It is preferred over prior work trained on the dataset by human evaluators 90% of the time. We also find that it is able to perform complex tasks such as image-to-image translation at a rudimentary level. This previously required custom approaches (Isola et al., 2017), rather emerging as a capability of a single, large generative model.
生成的系统在流行的 MS-COCO 数据集 零样本 上实现了高质量的图像生成,而无需使用任何训练标签。在人类评估者中,它在 90%的时间里被认为优于之前在该数据集上训练的工作。我们还发现,它能够在基础水平上执行复杂任务,如图像到图像的转换。这在以前需要定制的方法 (Isola et al., 2017),而现在则成为单个大型生成模型的一种能力。

Refer to caption
Figure 3: Comparison of samples from our model to those from prior approaches on captions from MS-COCO. Each of our model samples is the best of 512 as ranked by the contrastive model. We do not use any manual cherrypicking with the selection of either the captions or the samples from any of the models.
图 3: 比较了我们模型的样本与先前方法在 MS-COCO 标题上的样本。我们模型的每个样本都是 512 个样本中由对比模型排名最高的。在选择标题或任何模型的样本时,我们没有使用任何手动挑选。

2 Method
2 方法

Our goal is to train a transformer (Vaswani et al., 2017) to autoregressively model the text and image tokens as a single stream of data. However, using pixels directly as image tokens would require an inordinate amount of memory for high-resolution images. Likelihood objectives tend to prioritize modeling short-range dependencies between pixels (Salimans et al., 2017), so much of the modeling capacity would be spent capturing high-frequency details instead of the low-frequency structure that makes objects visually recognizable to us.
我们的目标是训练一个变换器 (Vaswani et al., 2017),以自回归方式将文本和图像令牌建模为单一数据流。然而,直接使用像素作为图像令牌会对高分辨率图像要求过多的内存。似然目标往往优先建模像素之间的短程依赖关系 (Salimans et al., 2017),因此,大部分建模能力将用于捕捉高频细节,而不是使物体对我们视觉上可识别的低频结构。

We address these issues by using a two-stage training procedure, similar to (Oord et al., 2017; Razavi et al., 2019):
我们通过使用两阶段训练程序来解决这些问题,类似于(Oord et al., 2017; Razavi et al., 2019):

  • Stage 1. We train a discrete variational autoencoder (dVAE)111https://github.com/openai/DALL-E
    阶段 1. 我们训练了一个离散变分自编码器 (dVAE)1
    to compress each 256×256256256256\times 256 RGB image into a 32×32323232\times 32 grid of image tokens, each element of which can assume 819281928192 possible values. This reduces the context size of the transformer by a factor of 192192192 without a large degradation in visual quality (see Figure 1).
    将每个 256×256256\times 256 RGB 图像压缩成 32×3232\times 32 图像标记网格,每个元素可以采用 81928192 种可能值。这将变压器的上下文大小降低了 192192 倍,且视觉质量没有大幅下降(见图 1)。

  • Stage 2. We concatenate up to 256 BPE-encoded text tokens with the 32×32=10243232102432\times 32=1024 image tokens, and train an autoregressive transformer to model the joint distribution over the text and image tokens.
    阶段 2. 我们将最多 256 个 BPE 编码的文本标记与 32×32=102432\times 32=1024 个图像标记连接在一起,并训练一个自回归变换器来建模文本和图像标记的联合分布。

The overall procedure can be viewed as maximizing the evidence lower bound (ELB) (Kingma & Welling, 2013; Rezende et al., 2014) on the joint likelihood of the model distribution over images x𝑥x, captions y𝑦y, and the tokens z𝑧z for the encoded RGB image. We model this distribution using the factorization pθ,ψ(x,y,z)=pθ(x|y,z)pψ(y,z)subscript𝑝𝜃𝜓𝑥𝑦𝑧subscript𝑝𝜃conditional𝑥𝑦𝑧subscript𝑝𝜓𝑦𝑧p_{\theta,\psi}(x,y,z)=p_{\theta}(x\,|\,y,z)p_{\psi}(y,z), which yields the lower bound
整体过程可以看作是最大化证据下限(ELB) (Kingma & Welling, 2013; Rezende et al., 2014),该下限针对模型分布在图像 xx 、标题 yy 和编码的 RGB 图像的令牌 zz 的联合似然性。我们使用因式分解 pθ,ψ(x,y,z)=pθ(x|y,z)pψ(y,z)p_{\theta,\psi}(x,y,z)=p_{\theta}(x\,|\,y,z)p_{\psi}(y,z) 对该分布建模,产生下限。

lnpθ,ψ(x,y)𝔼zqϕ(z|x)(lnpθ(x|y,z)βDKL(qϕ(y,z|x),pψ(y,z))),subscript𝑝𝜃𝜓𝑥𝑦subscript𝔼missing-subexpressionsimilar-to𝑧subscript𝑞italic-ϕconditional𝑧𝑥subscript𝑝𝜃|𝑥𝑦𝑧𝛽subscript𝐷KLsubscript𝑞italic-ϕ𝑦|𝑧𝑥subscript𝑝𝜓𝑦𝑧\ln p_{\theta,\psi}(x,y)\geqslant\!\!\!\!\!\!\!\!\mathop{\mathbb{E}}_{\begin{subarray}{c}\vspace{0.1mm}\\ z\sim q_{\phi}(z\,|\,x)\end{subarray}}\!\!\!\!\!\!\!\!\big{(}\ln p_{\theta}(x\,|\,y,z)\;-\\ \beta\,D_{\mathrm{KL}}(q_{\phi}(y,z\,|\,x),p_{\psi}(y,z))\big{)}, (1)

where: 哪里:

  • qϕsubscript𝑞italic-ϕq_{\phi} denotes the distribution over the 32×32323232\times 32 image tokens generated by the dVAE encoder given the RGB image x𝑥x222We assume that y𝑦y is conditionally independent of x𝑥x given z𝑧z.
    我们假设在给定 zz 的情况下, yyxx 条件独立。

    qϕq_{\phi} 表示 dVAE 编码器根据 RGB 图像 xx 生成的 32×3232\times 32 图像标记的分布 2
    ;

  • pθsubscript𝑝𝜃p_{\theta} denotes the distribution over the RGB images generated by the dVAE decoder given the image tokens; and
    pθp_{\theta} 表示在给定图像令牌的情况下,dVAE 解码器生成的 RGB 图像的分布;并且

  • pψsubscript𝑝𝜓p_{\psi} denotes the joint distribution over the text and image tokens modeled by the transformer.
    pψp_{\psi} 表示由 Transformer 建模的文本和图像标记的联合分布。

Note that the bound only holds for β=1𝛽1\beta=1, while in practice we find it helpful to use larger values (Higgins et al., 2016). The following subsections describe both stages in further detail.333In preliminary experiments on ImageNet (Deng et al., 2009), we attempted to maximize the ELB with respect to ϕitalic-ϕ\phi, θ𝜃\theta, and ψ𝜓\psi jointly, but were unable to improve on two-stage training.
在 ImageNet (Deng 等人,2009) 的初步实验中,我们尝试最大化 ELB 相对于 ϕ\phiθ\thetaψ\psi 的联合,但无法改进两阶段训练。

请注意,该约束仅适用于 β=1\beta=1 ,而在实践中,我们发现使用更大的值更有帮助(Higgins 等人,2016)。以下小节将更详细地描述这两个阶段。3

2.1 Stage One: Learning the Visual Codebook
2.1 第一阶段:学习视觉码本

In the first stage of training, we maximize the ELB with respect to ϕitalic-ϕ\phi and θ𝜃\theta, which corresponds to training a dVAE on the images alone. We set the initial prior pψsubscript𝑝𝜓p_{\psi} to the uniform categorical distribution over the K=8192𝐾8192K=$8192$ codebook vectors, and qϕsubscript𝑞italic-ϕq_{\phi} to be categorical distributions parameterized by the 819281928192 logits at the same spatial position in the 32×32323232\times 32 grid output by the encoder.
在训练的第一阶段,我们最大化 ELB 关于 ϕ\phiθ\theta ,这对应于仅对图像训练 dVAE。我们将初始先验 pψp_{\psi} 设置为在 K=8192K=$8192$ 代码本向量上的均匀分类分布,并将 qϕq_{\phi} 设置为由编码器输出的 32×3232\times 32 网格中相同空间位置的 81928192 logits 参数化的分类分布。

The ELB now becomes difficult to optimize: as qψsubscript𝑞𝜓q_{\psi} is a discrete distribution, and we cannot use the reparameterization gradient to maximize it. Oord et al. (2017); Razavi et al. (2019) address this using an online cluster assignment procedure coupled with the straight-through estimator (Bengio et al., 2013). We instead use the gumbel-softmax relaxation (Jang et al., 2016; Maddison et al., 2016), replacing the expectation over qϕsubscript𝑞italic-ϕq_{\phi} with one over qϕτsubscriptsuperscript𝑞𝜏italic-ϕq^{\tau}_{\phi}, where the relaxation becomes tight as the temperature τ0𝜏0\tau\to 0. The likelihood for pθsubscript𝑝𝜃p_{\theta} is evaluated using the log-laplace distribution (see Appendix A.3 for a derivation).
ELB 现在变得难以优化:因为 qψq_{\psi} 是离散分布,我们无法使用重参数化梯度来最大化它。Oord 等人(2017);Razavi 等人(2019通过使用在线集群分配程序以及直通估计器来解决这个问题(Bengio 等人,2013)。我们则使用 gumbel-softmax 松弛(Jang 等人,2016; Maddison 等人,2016),用 qϕq_{\phi} 的期望替换为 qϕτq^{\tau}_{\phi} 的期望,其中松弛在温度 τ0\tau\to 0 时变得紧致。 pθp_{\theta} 的似然使用对数拉普拉斯分布进行评估(详见附录A.3的推导)。

The relaxed ELB is maximized using Adam (Kingma & Ba, 2014) with exponentially weighted iterate averaging. Appendix A.2 gives a complete description of the hyperparameters, but we found the following to be especially important for stable training:
使用 Adam (Kingma & Ba, 2014) 和指数加权迭代平均来最大化松弛 ELB。附录 A.2 给出了超参数的完整描述,但我们发现以下内容对于稳定训练尤其重要:

  • Specific annealing schedules for the relaxation temperature and step size. We found that annealing τ𝜏\tau to 1/161161/16 was sufficient to close the gap between the relaxed validation ELB and the true validation ELB with qϕsubscript𝑞italic-ϕq_{\phi} intsead of qϕτsuperscriptsubscript𝑞italic-ϕ𝜏q_{\phi}^{\tau}.
    特定的退火时间表用于松弛温度和步长。我们发现将退火 τ\tau1/161/16 足以缩小松弛验证 ELB 与真实验证 ELB 之间的差距,使用 qϕq_{\phi} 而非 qϕτq_{\phi}^{\tau}

  • The use of 1×1111\times 1 convolutions at the end of the encoder and the beginning of the decoder. We found that reducing the receptive field size for the convolutions around the relaxation led to it generalizing better to the true ELB.
    在编码器末端和解码器开头使用 1×11\times 1 卷积。我们发现,缩小围绕松弛的卷积的感受野大小有助于它更好地泛化到真实的 ELB。

  • Multiplication of the outgoing activations from the encoder and decoder resblocks by a small constant, to ensure stable training at initialization.
    将编码器和解码器残差块的输出激活乘以一个小常数,以确保在初始化时训练稳定。

We also found that increasing the KL weight to β=6.6𝛽6.6\beta=6.6 promotes better codebook usage and ultimately leads to a smaller reconstruction error at the end of training.444This is contrary to the usual tradeoff between the two terms. We speculate that for smaller values of β𝛽\beta, the noise from the relaxation causes the optimizer to reduce codebook usage toward the beginning of training, resulting in worse ELB at convergence.
这与这两个术语之间通常的权衡相反。我们推测,对于较小的 β\beta 值,来自松弛的噪声会导致优化器在训练开始时减少码本的使用,从而导致收敛时 ELB 较差。

我们还发现,将 KL 权重提高到 β=6.6\beta=6.6 可以促进更好的字典使用,并最终在训练结束时导致更小的重建误差。4

2.2 Stage Two: Learning the Prior
2.2 第二阶段:学习先验

In the second stage, we fix ϕitalic-ϕ\phi and θ𝜃\theta, and learn the prior distribution over the text and image tokens by maximizing the ELB with respect to ψ𝜓\psi. Here, pψsubscript𝑝𝜓p_{\psi} is represented by a 12-billion parameter sparse transformer (Child et al., 2019).
在第二阶段,我们固定 ϕ\phiθ\theta ,通过最大化 ELB 以学习文本和图像标记的先验分布,关于 ψ\psi 。在这里, pψp_{\psi} 由一个具有 120 亿参数的稀疏变换器表示 (Child et al., 2019)

Given a text-image pair, we BPE-encode (Sennrich et al., 2015) the lowercased caption using at most 256 tokens555During training, we apply 10% BPE dropout (Provilkov et al., 2019), whose use is common in the neural machine translation literature.
在训练过程中,我们应用 10%的 BPE dropout (Provilkov et al., 2019),这种方法在神经机器翻译文献中很常见。

给定一个文本-图像对,我们使用 BPE 编码(Sennrich 等人,2015)对小写标题进行编码,最多使用 256 个标记。
with vocabulary size 163841638416384, and encode the image using 32×32=10243232102432\times 32=1024 tokens with vocabulary size 819281928192. The image tokens are obtained using argmax sampling from the dVAE encoder logits, without adding any gumbel noise.666Strictly speaking, Equation 1 requires us to sample from the categorical distribution specified by the dVAE encoder logits, rather than taking the argmax. In preliminary experiments on ImageNet, we found that this was a useful regularizer in the overparameterized regime, and allows the transformer to be trained using soft targets for the cross-entropy loss. We decided against this here since the model in consideration is in the underparameterized regime.
严格来说,公式 1 要求我们从 dVAE 编码器 logits 指定的类别分布中进行采样,而不是取 argmax。在 ImageNet 上的初步实验中,我们发现这在过度参数化的情况下是一个有用的正则化器,并且允许使用交叉熵损失的软目标来训练 Transformer。我们在这里放弃了这种方法,因为所考虑的模型处于欠参数化状态。

词汇大小为 1638416384 ,并使用 32×32=102432\times 32=1024 个词元对图像进行编码,词汇大小为 81928192 。图像词元是通过从 dVAE 编码器的 logits 中进行 argmax 采样获得的,没有添加任何 gumbel 噪声。6
Finally, the text and image tokens are concatenated and modeled autoregressively as a single stream of data.
最后,文本和图像标记被串联并作为一条数据流进行自回归建模。

The transformer is a decoder-only model in which each image token can attend to all text tokens in any one of its 64 self-attention layers. The full architecture is described in Appendix B.1. There are three different kinds of self-attention masks used in the model. The part of the attention masks corresponding to the text-to-text attention is the standard causal mask, and the part for the image-to-image attention uses either a row, column, or convolutional attention mask.777We found using a single attention operation for all three interactions – “text attends to text”, “image attends to text”, and “image attends to image” – to perform better than using separate attention operations that are independently normalized.
我们发现对所有三种交互使用单一的注意力操作——“文本关注文本”,“图像关注文本”和“图像关注图像”——比使用独立归一化的单独注意力操作效果更好。

该变换器是一个仅解码的模型,每个图像标记可以在其 64 个自注意层中的任何一个层中关注所有文本标记。完整的架构在附录 B.1 中描述。该模型使用三种不同类型的自注意力掩码。对应于文本到文本注意力的掩码部分是标准的因果掩码,而图像到图像注意力的掩码部分使用行、列或卷积注意力掩码。

We limit the length of a text caption to 256 tokens, though it is not totally clear what to do for the “padding” positions in between the last text token and the start-of-image token. One option is to set the logits for these tokens to -\infty in the self-attention operations. Instead, we opt to learn a special padding token separately for each of the 256 text positions. This token is used only when no text token is available. In preliminary experiments on Conceptual Captions (Sharma et al., 2018), we found that this resulted in higher validation loss, but better performance on out-of-distribution captions.
我们限制文本标题的长度为 256 个标记,但对于最后一个文本标记和图像开始标记之间的“填充”位置应该怎么做还不完全清楚。一个选择是在自注意力操作中将这些标记的 logits 设置为 -\infty 。相反,我们选择为 256 个文本位置中的每一个单独学习一个特殊的填充标记。此标记仅在没有文本标记可用时使用。在对概念字幕 (Sharma 等人,2018) 的初步实验中,我们发现这会导致更高的验证损失,但在分布外字幕上表现更好。

We normalize the cross-entropy losses for the text and image tokens by the total number of each kind in a batch of data. Since we are primarily interested in image modeling, we multiply the cross-entropy loss for the text by 1/8181/8 and the cross-entropy loss for the image by 7/8787/8. The objective is optimized using Adam with exponentially weighted iterate averaging; Appendix B.2 describes the training procedure in more detail. We reserved about 606000606000606000 images for validation, and found no signs of overfitting at convergence.
我们通过每批数据中每种类型的总数来规范化文本和图像标记的交叉熵损失。由于我们主要关注图像建模,我们将文本的交叉熵损失乘以  1/81/8 ,将图像的交叉熵损失乘以  7/87/8 。优化目标使用 Adam 进行指数加权迭代平均;附录 B.2 详细描述了训练过程。我们保留了大约  606000606000 张图像用于验证,并在收敛时没有发现过拟合的迹象。

Refer to caption
Figure 4: Illustration of per-resblock gradient scaling for a transformer resblock. The solid line indicates the sequence of operations for forward propagation, and the dashed line the sequence of operations for backpropagation. We scale the incoming gradient for each resblock by its gradient scale, and unscale the outgoing gradient before it is added to the sum of the gradients from the successive resblocks. The activations and gradients along the identity path are stored in 32-bit precision. The “filter” operation sets all Inf and NaN values in the activation gradient to zero. Without this, a nonfinite event in the current resblock would cause the gradient scales for all preceding resblocks to unnecessarily drop, thereby resulting in underflow.
图 4:变压器残差块的每个残差块梯度缩放的示意图。实线表示前向传播的操作顺序,虚线表示反向传播的操作顺序。我们通过每个残差块的梯度缩放来缩放传入的梯度,并在将其添加到后续残差块的梯度总和之前对输出梯度进行反缩放。身份路径上的激活值和梯度以 32 位精度存储。“过滤”操作将激活梯度中的所有 Inf 和 NaN 值设置为零。如果没有这个,当前残差块的非有限事件将导致所有前面的残差块的梯度缩放不必要地下降,从而导致下溢。
Refer to caption
Figure 5: Communication patterns used for distributed training. Each parameter array in the model is sharded among the eight GPUs on each machine. During forward propagation, we prefetch the parameter shards for the next resblock (using all-gather) while computing the activations for the current resblock. To conserve memory, the parameter shards from the other GPUs are immediately discarded. Similarly, during backpropagation, we prefetch the parameter shards for the previous resblock while computing the activations and gradients for the current resblock. After all GPUs have computed the gradient with respect to an all-gathered parameter, the reduce-scatter operation leaves each GPU with only one slice – i.e., the gradient for its parameter shard, averaged over the eight GPUs.
图 5: 用于分布式训练的通信模式。模型中的每个参数数组在每台机器上的八个 GPU 之间进行分片。在正向传播期间,我们在计算当前残差块的激活时,预取下一个残差块的参数分片(使用 all-gather)。为了节省内存,来自其他 GPU 的参数分片会立即丢弃。类似地,在反向传播期间,我们在计算当前残差块的激活和梯度时,预取前一个残差块的参数分片。在所有 GPU 都计算了关于所有收集参数的梯度之后,reduce-scatter 操作使每个 GPU 仅保留一个切片 - 即其参数分片的梯度,在八个 GPU 上取平均值。

2.3 Data Collection
2.3 数据收集

Our preliminary experiments for models up to 1.21.21.2 billion parameters were carried out on Conceptual Captions, a dataset of 3.3 million text-image pairs that was developed as an extension to MS-COCO (Lin et al., 2014).
我们对参数数量达到 1.21.2 亿的模型的初步实验是在 Conceptual Captions 数据集上进行的,该数据集包含 330 万个文本图像对,是 MS-COCO 的扩展 (Lin et al., 2014)

To scale up to 121212-billion parameters, we created a dataset of a similar scale to JFT-300M (Sun et al., 2017) by collecting 250 million text-images pairs from the internet. This dataset does not include MS-COCO, but does include Conceptual Captions and a filtered subset of YFCC100M (Thomee et al., 2016). As MS-COCO was created from the latter, our training data includes a fraction of the MS-COCO validation images (but none of the captions). We control for this in the quantitative results presented in Section 3 and find that it has no appreciable bearing on the results. We provide further details about the data collection process in Appendix C.
为了扩展到 1212 十亿个参数,我们创建了一个与 JFT-300M 规模类似的数据集(Sun 等人,2017),从互联网上收集了 2.5 亿个文本-图像对。此数据集不包括 MS-COCO,但包含概念字幕和 YFCC100M 的过滤子集(Thomee 等人,2016)。由于 MS-COCO 是从后者创建的,因此我们的训练数据包含一小部分 MS-COCO 验证图像(但没有字幕)。我们在第3 节中介绍的定量结果中对此进行了控制,发现它对结果没有明显影响。我们在附录C 中提供了有关数据收集过程的更多详细信息。

2.4 Mixed-Precision Training
2.4 混合精度训练

To save GPU memory and increase throughput, most parameters, Adam moments, and activations are stored in 16-bit precision. We also use activation checkpointing and recompute the activations within the resblocks during the backward pass. Getting the model to train in 16-bit precision past one billion parameters, without diverging, was the most challenging part of this project.
为了节省 GPU 内存并提高吞吐量,大多数参数、Adam 动量和激活都以 16 位精度存储。我们还使用激活检查点,并在反向传播过程中重新计算 resblocks 中的激活。让模型在超过 10 亿个参数的情况下以 16 位精度进行训练,而不会发散,是这个项目中最具挑战性的部分。

We believe the root cause of this instability to be underflow in the 16-bit gradients. Appendix D presents a set of guidelines we developed to avoid underflow when training large-scale generative models. Here, we describe one of these guidelines: per-resblock gradient scaling.
我们认为这种不稳定性的根本原因是 16 位梯度下溢。附录 D 提供了一套我们为避免在训练大型生成模型时出现下溢而制定的指南。这里,我们描述了其中一项指南:每个残差块梯度缩放。

Similar to prior work (Liu et al., 2020), we found that the norms of the activation gradients from the resblocks decrease monotonically as we move from the earlier resblocks to the later ones.888It is possible that better initialization schemes (Liu et al., 2020) might be able to avoid this, but we did not have success with alternative schemes in our experiments.
有可能更好的初始化方案(Liu et al., 2020)能够避免这个问题,但在我们的实验中,替代方案并没有取得成功。

与先前的工作(Liu 等人,2020)类似,我们发现,从较早的残差块到较晚的残差块,激活梯度的范数单调递减。8
As the model is made deeper and wider, the true exponents of the activation gradients for later resblocks can fall below the minimum exponent of the 16-bit format. Consequently, they get rounded to zero, a phenomenon called underflow. We found that eliminating underflow allowed for stable training to convergence.
随着模型的加深和扩展,后续残差块的激活梯度的真实指数可能会低于 16 位格式的最小指数。因此,它们会被舍入为零,这种现象称为下溢。我们发现消除下溢使得训练能够稳定收敛。

Standard loss scaling (Micikevicius et al., 2017) is able to avoid underflow when the range spanned by the smallest and largest activation gradients (in absolute value) fits within the exponent range of the 16-bit format. On NVIDIA V100 GPUs, this exponent range is specified by five bits. While this is sufficient for training vanilla language models of the same size, we found the range to be too small for the text-to-image model.
标准损失缩放 (Micikevicius 等人,2017) 能够避免下溢,当最小和最大激活梯度(绝对值)跨越的范围适合于 16 位格式的指数范围时。在 NVIDIA V100 GPU 上,此指数范围由五个位指定。虽然这足以训练相同大小的普通语言模型,但我们发现该范围对于文本到图像模型来说太小了。

Our fix, which is shown in Figure 4, involves using a separate “gradient scale” for each resblock in the model. This can be seen as a practical alternative to a more general framework for mixed-precision training called Flexpoint (Köster et al., 2017), with the advantage that specialized GPU kernels are not required. We found that Sun et al. (2020) had independently developed similar procedure for training convolutional networks in 4-bit precision.
我们的修复方法如图 4 所示,涉及为模型中的每个残差块使用单独的“梯度尺度”。这可以视为一种实用的替代方案,替代一种称为 Flexpoint 的更通用的混合精度训练框架 (Köster et al., 2017),其优点是不需要专用的 GPU 内核。我们发现 Sun et al. (2020) 独立开发了类似的程序,用于以 4 位精度训练卷积网络。

2.5 Distributed Optimization
2.5 分布式优化

Effective Parameter Count
有效参数数量
Compression Rank 压缩排名 Compression Rate 压缩率
2.81092.8superscript1092.8\cdot 10^{9} (dmodel=1920subscript𝑑model1920d_{\mathrm{model}}=1920) 512 83%absentpercent83\approx\!83\%
5.61095.6superscript1095.6\cdot 10^{9} (dmodel=2688subscript𝑑model2688d_{\mathrm{model}}=2688) 640 85%absentpercent85\approx\!85\%
12.010912.0superscript10912.0\cdot 10^{9} (dmodel=3968subscript𝑑model3968d_{\mathrm{model}}=3968) 896 86%absentpercent86\approx\!86\%
Table 1: We show the relationship between model size and the minimum compression rank for the gradients (up to a multiple of 128) necessary to avoid a gap in the training loss during the first 10%percent1010\% of training. These results suggest that in our setting, we can achieve a compression rate of about 85%percent8585\%, independent of model size.
表 1: 我们展示了模型大小与避免训练损失在前 10%10\% 次训练期间出现间隙所需的梯度最小压缩秩(最多为 128 的倍数)之间的关系。这些结果表明,在我们的设定中,我们可以实现约 85%85\% 的压缩率,与模型大小无关。
Refer to caption
Figure 6: Effect of increasing the number of images for the contrastive reranking procedure on MS-COCO captions.
图 6: 增加图像数量对 MS-COCO 字幕的对比重排序过程的影响。

Our 12-billion parameter model consumes about 24 GB of memory when stored in 16-bit precision, which exceeds the memory of a 16 GB NVIDIA V100 GPU. We address this using parameter sharding (Rajbhandari et al., 2019). As shown in Figure 5, parameter sharding allows us to almost completely hide the latency of the intra-machine communication by overlapping it with compute-intensive operations.
我们 120 亿参数的模型以 16 位精度存储时占用约 24 GB 内存,超过了 16 GB NVIDIA V100 GPU 的内存。我们使用参数分片 (Rajbhandari 等人,2019) 来解决这个问题。如图 5 所示,参数分片使我们能够通过将其与计算密集型操作重叠来几乎完全隐藏机器内通信的延迟。

On the cluster used to train the model, the bandwidth between machines is much lower than the bandwidth among GPUs on the same machine. This makes the cost of the operation used to average the gradient among the machines (all-reduce) the main bottleneck during training. We were able to drastically reduce this cost by compressing the gradients using PowerSGD (Vogels et al., 2019).
在用于训练模型的集群中,机器之间的带宽远低于同一台机器上 GPU 之间的带宽。这使得用于在机器之间平均梯度(全减少)的操作成本成为训练过程中的主要瓶颈。我们通过使用 PowerSGD (Vogels et al., 2019)压缩梯度,能够大幅降低这种成本。

In our implementation, each GPU in a machine computes the low-rank factors for its parameter shard gradients independently of its neighboring GPUs.999There is still intra-machine communication for other operations; what we mean is that the low-rank factors across the shards, when concatenated, are not regarded as collectively approximating the gradient for the full parameter matrix.
仍然存在其他操作的机器内部通信;我们所指的是,当拼接切片之间的低秩因素时,并不被视为共同近似完整参数矩阵的梯度。

在我们的实现中,机器中的每个 GPU 独立于其相邻的 GPU 计算其参数分片梯度的低秩因子。
Once the low-rank factors are computed, each machine sets its error buffer to the residual between the uncompressed gradient averaged over its eight GPUs (obtained from reduce-scatter), and the decompressed gradient obtained from the low-rank factors.
一旦计算出低秩因子,每台机器都会将其误差缓冲区设置为未压缩梯度(通过其八个 GPU 的 reduce-scatter 获得)与从低秩因子获得的解压缩梯度之间的残差。

PowerSGD replaces the large communication operation for an uncompressed parameter gradient with two, much smaller communication operations for its low-rank factors. For a given compression rank r𝑟r and transformer activation size dmodelsubscript𝑑modeld_{\mathrm{model}}, the compression rate is given by 15r/(8dmodel)15𝑟8subscript𝑑model1-5r/(8d_{\textrm{model}}) (see Appendix E.1). Table 1 shows that we can achieve a compression rate of about 85%percent8585\%, independent of model size.
PowerSGD 用两个更小的通信操作替代了未压缩参数梯度的大型通信操作,这两个操作用于其低秩因子。对于给定的压缩秩 rr 和变换器激活大小 dmodeld_{\mathrm{model}} ,压缩比由 15r/(8dmodel)1-5r/(8d_{\textrm{model}}) 给出(见附录 E.1)。表 1 显示我们可以实现约 85%85\% 的压缩比,与模型大小无关。

In Appendix E.2, we describe various details that were necessary to get PowerSGD to perform well at scale. These include:
在附录 E.2 中,我们描述了为了使 PowerSGD 在大规模下表现良好所需的各种细节。这些包括:

  • Saving memory by accumulating the gradient into the error buffers during backpropagation, rather than allocating separate buffers.
    在反向传播期间将梯度累积到误差缓冲区中,而不是分配单独的缓冲区,从而节省内存。

  • Minimizing instances in which we zero out the error buffers (e.g., due to nonfinite values encountered during mixed-precision backpropagation, or when resuming training from a checkpoint).
    尽量减少将错误缓冲区清零的情况(例如,由于混合精度反向传播过程中遇到非有限值,或从检查点恢复训练时)。

  • Improving numerical stability by using Householder orthogonalization instead of Gram-Schmidt, together with the addition of a small multiple of the identity matrix to the input.
    通过使用豪斯霍尔德正交化代替 Gram-Schmidt 方法,并向输入中添加一个小的单位矩阵的倍数,从而提高数值稳定性。

  • Avoiding underflow by using a custom 16-bit floating point format for the error buffers, their low-rank factors, and the all-reduce communication operations involving them.
    通过为错误缓冲区、它们的低秩因子以及涉及它们的全规约通信操作使用自定义的 16 位浮点格式来避免下溢。

We also found the warm-start procedure for the Q𝑄Q matrix described in Vogels et al. (2019) to be unnecessary: we were able to get equivalent results by fixing Q𝑄Q to a random gaussian matrix at the start of training, and never updating it.101010We verified that the error in reconstructing the true gradient is higher when Q𝑄Q is fixed as opposed to being updated using warm-starting, so it is interesting that this does not affect the loss. By contrast, resampling Q𝑄Q at every update causes a large performance hit.
我们验证了在重建真实梯度时,当 QQ 固定时,错误率高于使用热启动更新时,因此有趣的是,这不会影响损失。相比之下,在每次更新时对 QQ 进行重采样会导致性能大幅下降。

我们还发现 Vogels 等人(2019)所描述的 QQ 矩阵的热启动过程是多余的:我们通过在训练开始时将 QQ 固定为一个随机高斯矩阵,并且从未更新它,得到了等效的结果。

2.6 Sample Generation
2.6 样本生成

Similar to Razavi et al. (2019), we rerank the samples drawn from the transformer using a pretrained contrastive model (Radford et al., 2021). Given a caption and a candidate image, the contrastive model assigns a score based on how well the image matches the caption. Figure 6 shows the effect of increasing the number of samples N𝑁N from which we select the top k𝑘k images. This process can be seen as a kind of language-guided search (Andreas et al., 2017), and is also similar to the auxiliary text-image matching loss proposed by Xu et al. (2018). Unless otherwise stated, all samples used for both qualitative and quantitative results are obtained without temperature reduction (i.e., using t=1𝑡1t=1) (except for Figure 2) and use reranking with N=512𝑁512N=512.
Razavi 等人 (2019) 类似,我们使用预训练的对比模型 (Radford 等人,2021) 对从 Transformer 中抽取的样本进行重新排序。给定一个标题和一个候选图像,对比模型根据图像与标题匹配程度分配一个分数。图 6 显示了增加样本数量 NN 的影响,我们从中选择前 kk 张图像。此过程可以看作是一种语言引导搜索 (Andreas 等人,2017),并且类似于 Xu 等人 (2018) 提出的辅助文本图像匹配损失。除非另有说明,所有用于定性和定量结果的样本均在不进行温度降低的情况下获得(即使用 t=1t=1 )(图 2 除外),并使用 N=512N=512 进行重新排序。

3 Experiments
3 实验

Refer to caption
Figure 7: Human evaluation of our model (evaluated zero-shot without temperature reduction) vs prior work (DF-GAN) on captions from MS-COCO. In a best-of-five vote, our model’s sample was chosen as the most realistic 90.0% of the time, and was chosen as the image best matching a shared caption 93.3% of the time.
图 7: 我们模型(在没有降温的情况下进行零样本评估)与先前工作(DF-GAN)在 MS-COCO 图像的字幕上的对比。在五选一投票中,我们的模型样本被选为最真实的 90.0%,并且被选为最符合共享字幕的图像 93.3%。

3.1 Quantitative Results
3.1 定量结果

We evaluate our model zero-shot by comparing it to three prior approaches: AttnGAN (Xu et al., 2018), DM-GAN (Zhu et al., 2019), and DF-GAN (Tao et al., 2020), the last of which reports the best Inception Score (Salimans et al., 2016) and Fréchet Inception Distance (Heusel et al., 2017) on MS-COCO. Figure 3 qualitatively compares samples from our model to those from prior work.
我们通过将我们的模型与三种先前的方法进行比较来进行零样本评估:AttnGAN (Xu et al., 2018),DM-GAN (Zhu et al., 2019),和 DF-GAN (Tao et al., 2020),其中最后一种报告了最佳的 Inception Score (Salimans et al., 2016) 和 Fréchet Inception Distance (Heusel et al., 2017) 在 MS-COCO 上。图 3 从定性的角度比较了我们模型的样本与先前工作的样本。

We also conduct a human evaluation similar to the one used in Koh et al. (2021) to compare our approach to DF-GAN, the results of which are shown in Figure 7. Given a caption, the sample from our model receives the majority vote for better matching the caption 93% of the time. It also receives the majority vote for being more realistic 90% of the time.
我们还进行了一项类似于Koh et al. (2021)中使用的人类评估,以将我们的方法与 DF-GAN 进行比较,结果如图7所示。给定一个标题,我们模型生成的样本 93%的情况下获得了大多数投票,认为其与标题更匹配。90%的情况下,它也获得了大多数投票,认为其更真实。

Figure 9(a) shows that our model also obtains an FID score on MS-COCO within 2 points of the best prior approach, despite having never been trained on the captions. Our training data incorporates a filtered subset of YFCC100M, and we found that it includes about 21%percent2121\% of the images in the MS-COCO validation set from a de-duplication procedure described in the next section. To isolate this effect, we compute the FID statistics for the validation set both with these images (solid lines) and without them (dashed lines), finding no significant change in the results.
9(a)显示,我们的模型在 MS-COCO 上获得的 FID 分数也在最佳先前方法的 2 分之内,尽管从未在标题上进行训练。我们的训练数据包含 YFCC100M 的一个过滤子集,我们发现它包括约 21%21\% 的 MS-COCO 验证集中的图像,这些图像是通过下一节描述的去重程序获得的。为了隔离这一影响,我们计算了验证集的 FID 统计数据,包括这些图像(实线)和不包括它们(虚线),发现结果没有显著变化。

Training the transformer on the tokens from the dVAE encoder allows us to allocate its modeling capacity to the low-frequency information that makes images visually recognizable to us. However, it also disadvantages the model, since the heavy compression renders it unable to produce high-frequency details. To test the effect of this on the quantitative evaluations, we compute the FID and IS in Figure 9(a) after applying a Gaussian filter with varying radius to both the validation images and samples from the models. Our approach achieves the best FID by a margin of about 6 points with a slight blur of radius 1. The gap between our approach and others tends to widen as the blur radius is increased. We also obtain the highest IS when the blur radius is greater than or equal to two.
训练变换器使用来自 dVAE 编码器的标记,使我们能够将其建模能力分配给使图像对我们可视化可识别的低频信息。然而,这也使模型处于不利地位,因为重压缩使其无法产生高频细节。为了测试这一点对定量评估的影响,我们计算了图中9(a)的 FID 和 IS,在对验证图像和模型样本应用不同半径的高斯滤波器后。我们的方法在半径为 1 的情况下,比其他方法获得了约 6 分的最佳 FID。随着模糊半径的增加,我们的方法与其他方法之间的差距往往会加大。当模糊半径大于或等于 2 时,我们还获得了最高的 IS。

Refer to caption
Figure 8: Zero-shot samples from our model on the CUB dataset.
图 8: CUB 数据集上我们模型的零样本示例。
Refer to caption
Refer to caption
(a) FID and IS on MS-COCO as a function of blur radius.
(a) FID 和 IS 在 MS-COCO 上作为模糊半径的函数。
Refer to caption
Refer to caption
(b) FID and IS on CUB as a function of blur radius.
(b) CUB 上的 FID 和 IS 作为模糊半径的函数。
Refer to caption
Refer to caption
(c) FID and IS on MS-COCO as a function of the sample size used for reranking.
(c) FID 和 IS 在 MS-COCO 中的重排序样本大小的关系。
Figure 9: Quantitative results on MS-COCO and CUB. Solid lines represent FID computed against the original validation sets, and dashed lines represent FID computed against validation sets with overlapping images removed (see Section 3.2). For MS-COCO, we evaluate all models on a subset of 300003000030000 captions sampled from the validation set. For CUB, we evaluate all models on all of the unique captions in the test set. We compute the FID and IS using the DM-GAN code, which is available at https://github.com/MinfengZhu/DM-GAN.
图 9:MS-COCO 和 CUB 的定量结果。实线表示与原始验证集计算的 FID,虚线表示与去除重复图像的验证集计算的 FID (见第 3.2 节)。对于 MS-COCO,我们在从验证集中采样的 3000030000 个图例的子集上评估所有模型。对于 CUB,我们在测试集中所有独特的图例上评估所有模型。我们使用 DM-GAN 代码计算 FID 和 IS,该代码可在 https://github.com/MinfengZhu/DM-GAN 获得。

Our model fares significantly worse on the CUB dataset, for which there is a nearly 40-point gap in FID between our model and the leading prior approach (Figure 9(b)). We found an 12%percent1212\% overlap rate for this dataset, and again observed no significant difference in the results after removing these images. We speculate that our zero-shot approach is less likely to compare favorably on specialized distributions such as CUB. We believe that fine-tuning is a promising direction for improvement, and leave this investigation to future work. Samples from our model for captions in this dataset are shown in Figure 8.
我们的模型在 CUB 数据集上的表现显著较差,FID 指标在我们的模型与领先的先前方法之间存在近 40 分的差距(图9(b))。我们发现该数据集的 12%12\% 重叠率,并且在去掉这些图像后,结果再次没有显著差异。我们推测我们的零-shot 方法在像 CUB 这样的专业分布上的比较结果不太可能令人满意。我们相信微调是一个有前景的改进方向,并将这一研究留待未来工作处理。我们模型在该数据集中生成的图例示例如图8所示。

Finally, Figure 9(c) shows clear improvements in FID and IS for MS-COCO as the sample size used for reranking with the contrastive model is increased. This trend continues up to a sample size of 32, after which we observe diminishing returns.
最后,图 9(c) 显示,随着用于对比模型重新排序的样本大小增加,MS-COCO 的 FID 和 IS 都有明显改善。这种趋势一直持续到样本大小为 32,之后我们观察到收益递减。

3.2 Data Overlap Analysis
3.2 数据重叠分析

We used the deduplication procedure described in Radford et al. (2021) to determine which images to remove. For each validation image, we find the closest image in the training data using a contrastive model specifically trained for this task. We then sort the images in descending order by closeness to their nearest matches in the training data. After inspecting the results by hand, we determine the images to remove by manually selecting a conservative threshold designed to minimize the false negative rate.
我们使用了Radford et al. (2021)中描述的去重程序来确定要删除的图像。对于每个验证图像,我们使用专门为此任务训练的对比模型找到训练数据中最接近的图像。然后,我们根据与训练数据中最近匹配图像的接近程度将图像按降序排列。在手动检查结果后,我们通过手动选择一个旨在最小化假阴性率的保守阈值来确定要删除的图像。

3.3 Qualitative Findings
3.3 定性研究结果

We found that our model has the ability to generalize in ways that we did not originally anticipate. When given the caption “a tapir made of accordion…” (Figure 2a), the model appears to draw a tapir with an accordion for a body, or an accordion whose keyboard or bass are in the shape of a tapir’s trunk or legs. This suggests that it has developed a rudimentary ability to compose unusual concepts at high levels of abstraction.
我们发现我们的模型能够以我们最初未预料到的方式进行概括。当给定标题“一个用手风琴制作的貘……”(图 2a)时,模型似乎绘制了一个有着手风琴身体的貘,或者一个其键盘或低音部分呈现为貘的鼻子或腿的手风琴。这表明它已经发展出了一种初步的能力,可以在高层次的抽象中组合不寻常的概念。

Our model also appears to be capable of combinatorial generalization, such as when rendering text (Figure 2b) or when probed on sentences like “an illustration of a baby hedgehog in a christmas sweater walking a dog” (Figure 2c). Prompts like the latter require the model to perform variable binding (Smolensky, 1990; Greff et al., 2020) – it is the hedgehog that is in the christmas sweater, not the dog. We note, however, that the model performs inconsistently on the task, sometimes drawing both animals with christmas sweaters, or drawing a hedgehog walking a smaller hedgehog.
我们的模型似乎也能够实现组合泛化,例如在呈现文本时(图 2b)或在被询问“穿着圣诞毛衣的刺猬走一只狗的插图”这样的句子时(图 2c)。像后者这样的提示要求模型进行变量绑定 (Smolensky, 1990; Greff et al., 2020)——是刺猬穿着圣诞毛衣,而不是狗。然而,我们注意到模型在这个任务上的表现不一致,有时会同时画出两只动物都穿着圣诞毛衣,或者画出一只刺猬牵着一只小刺猬。

To a limited degree of reliability, we also find our model to be capable of zero-shot image-to-image translation controllable by natural language (Figure 2d). When the model is given the caption “the exact same cat on the top as a sketch at the bottom” and the top 15×32153215\times 32 part of the image token grid for a photo of a cat, it is able to draw a sketch of a similar looking cat on the bottom.
在一定程度的可靠性下,我们发现我们的模型能够进行由自然语言控制的零-shot 图像到图像的转换(图 2d)。当模型接收到标题“顶部的确切相同的猫与底部的素描”以及图像标记网格顶部的 15×3215\times 32 部分(猫的照片)时,它能够在底部绘制一个外观相似的猫的素描。

This works with several other kinds of transformations, including image operations (e.g., changing the color of the image, converting it to grayscale, or flipping it upside-down) and style transfer (e.g., drawing the cat on a greeting card, a postage stamp, or a cell phone case). Some transformations, such as those that involve only changing the color of the animal, suggest that the model is capable of performing a rudimentary kind of object segmentation. We provide additional examples of zero-shot image-to-image translation in Section G.
这适用于几种其他类型的变换,包括图像操作(例如,改变图像的颜色、将其转换为灰度图像或将其倒转)和风格迁移(例如,在贺卡、邮票或手机壳上绘制猫)。一些变换,例如仅涉及改变动物颜色的那些,表明该模型能够执行一种基本的对象分割。我们在第 G 节提供了零-shot 图像到图像转换的其他示例。

4 Conclusion
4 结论

We investigate a simple approach for text-to-image generation based on an autoregressive transformer, when it is executed at scale. We find that scale can lead to improved generalization, both in terms of zero-shot performance relative to previous domain-specific approaches, and in terms of the range of capabilities that emerge from a single generative model. Our findings suggest that improving generalization as a function of scale may be a useful driver for progress on this task.
我们研究了一种基于自回归 Transformer 的简单文本到图像生成方法,当它在规模化执行时。我们发现规模化可以提高泛化能力,无论是在与之前特定领域方法相比的零样本性能方面,还是在从单个生成模型中出现的各种能力方面。我们的发现表明,作为规模化函数提高泛化能力可能是解决此任务进展的有用驱动力。

Acknowledgements 致谢

We would like to thank Matthew Knight for reviewing the code release for this work, and Rewon Child, John Schulman, Heewoo Jun, and Prafulla Dhariwal for helpful early feedback on the paper. We would also like to thank Jong Wook Kim for writing the PyTorch package for the contrastive model described in Radford et al. (2019) that we used to rerank the samples from our model.
我们感谢 Matthew Knight 对这项工作的代码发布进行了审查,并感谢 Rewon Child、John Schulman、Heewoo Jun 和 Prafulla Dhariwal 在论文早期阶段提供的宝贵反馈。我们还要感谢 Jong Wook Kim 为 Radford 等人(2019 中描述的对比模型编写了 PyTorch 包,我们用它来重新排序来自我们模型的样本。

References

  • Abadi et al. (2016) Abadi, M., Barham, P., Chen, J., Chen, Z., Davis, A., Dean, J., Devin, M., Ghemawat, S., Irving, G., Isard, M., et al. Tensorflow: A system for large-scale machine learning. In 12th {{\{USENIX}}\} symposium on operating systems design and implementation ({{\{OSDI}}\} 16), pp.  265–283, 2016.
  • Andreas et al. (2017) Andreas, J., Klein, D., and Levine, S. Learning with latent language. arXiv preprint arXiv:1711.00482, 2017.
  • Bengio et al. (2013) Bengio, Y., Léonard, N., and Courville, A. Estimating or propagating gradients through stochastic neurons for conditional computation. arXiv preprint arXiv:1308.3432, 2013.
  • Bowman et al. (2015) Bowman, S. R., Vilnis, L., Vinyals, O., Dai, A. M., Jozefowicz, R., and Bengio, S. Generating sentences from a continuous space. arXiv preprint arXiv:1511.06349, 2015.
  • Chen et al. (2020) Chen, M., Radford, A., Child, R., Wu, J., Jun, H., Luan, D., and Sutskever, I. Generative pretraining from pixels. In International Conference on Machine Learning, pp. 1691–1703. PMLR, 2020.
  • Child et al. (2019) Child, R., Gray, S., Radford, A., and Sutskever, I. Generating long sequences with sparse transformers. arXiv preprint arXiv:1904.10509, 2019.
  • Cho et al. (2020) Cho, J., Lu, J., Schwenk, D., Hajishirzi, H., and Kembhavi, A. X-lxmert: Paint, caption and answer questions with multi-modal transformers. arXiv preprint arXiv:2009.11278, 2020.
  • Deng et al. (2009) Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., and Fei-Fei, L. Imagenet: A large-scale hierarchical image database. In 2009 IEEE conference on computer vision and pattern recognition, pp.  248–255. Ieee, 2009.
  • Dhariwal et al. (2020) Dhariwal, P., Jun, H., Payne, C., Kim, J. W., Radford, A., and Sutskever, I. Jukebox: A generative model for music. arXiv preprint arXiv:2005.00341, 2020.
  • Goodfellow et al. (2014) Goodfellow, I. J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A., and Bengio, Y. Generative adversarial networks. arXiv preprint arXiv:1406.2661, 2014.
  • Greff et al. (2020) Greff, K., van Steenkiste, S., and Schmidhuber, J. On the binding problem in artificial neural networks. arXiv preprint arXiv:2012.05208, 2020.
  • Gregor et al. (2015) Gregor, K., Danihelka, I., Graves, A., Rezende, D., and Wierstra, D. Draw: A recurrent neural network for image generation. In International Conference on Machine Learning, pp. 1462–1471. PMLR, 2015.
  • He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Identity mappings in deep residual networks. In European conference on computer vision, pp.  630–645. Springer, 2016.
  • Heusel et al. (2017) Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. arXiv preprint arXiv:1706.08500, 2017.
  • Higgins et al. (2016) Higgins, I., Matthey, L., Pal, A., Burgess, C., Glorot, X., Botvinick, M., Mohamed, S., and Lerchner, A. beta-vae: Learning basic visual concepts with a constrained variational framework. 2016.
  • Isola et al. (2017) Isola, P., Zhu, J.-Y., Zhou, T., and Efros, A. A. Image-to-image translation with conditional adversarial networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp.  1125–1134, 2017.
  • Jang et al. (2016) Jang, E., Gu, S., and Poole, B. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
  • Kingma & Ba (2014) Kingma, D. P. and Ba, J. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • Kingma & Welling (2013) Kingma, D. P. and Welling, M. Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114, 2013.
  • Koh et al. (2021) Koh, J. Y., Baldridge, J., Lee, H., and Yang, Y. Text-to-image generation grounded by fine-grained user attention. In Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp.  237–246, 2021.
  • Köster et al. (2017) Köster, U., Webb, T. J., Wang, X., Nassar, M., Bansal, A. K., Constable, W. H., Elibol, O. H., Gray, S., Hall, S., Hornof, L., et al. Flexpoint: An adaptive numerical format for efficient training of deep neural networks. arXiv preprint arXiv:1711.02213, 2017.
  • LeCun et al. (1998) LeCun, Y., Bottou, L., Bengio, Y., and Haffner, P. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • Li et al. (2019) Li, W., Zhang, P., Zhang, L., Huang, Q., He, X., Lyu, S., and Gao, J. Object-driven text-to-image synthesis via adversarial training. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  12174–12182, 2019.
  • Lin et al. (2014) Lin, T.-Y., Maire, M., Belongie, S., Hays, J., Perona, P., Ramanan, D., Dollár, P., and Zitnick, C. L. Microsoft coco: Common objects in context. In European conference on computer vision, pp.  740–755. Springer, 2014.
  • Liu et al. (2020) Liu, L., Liu, X., Gao, J., Chen, W., and Han, J. Understanding the difficulty of training transformers. arXiv preprint arXiv:2004.08249, 2020.
  • Loshchilov & Hutter (2017) Loshchilov, I. and Hutter, F. Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101, 2017.
  • Maddison et al. (2016) Maddison, C. J., Mnih, A., and Teh, Y. W. The concrete distribution: A continuous relaxation of discrete random variables. arXiv preprint arXiv:1611.00712, 2016.
  • Mansimov et al. (2015) Mansimov, E., Parisotto, E., Ba, J. L., and Salakhutdinov, R. Generating images from captions with attention. arXiv preprint arXiv:1511.02793, 2015.
  • Micikevicius et al. (2017) Micikevicius, P., Narang, S., Alben, J., Diamos, G., Elsen, E., Garcia, D., Ginsburg, B., Houston, M., Kuchaiev, O., Venkatesh, G., et al. Mixed precision training. arXiv preprint arXiv:1710.03740, 2017.
  • Nguyen et al. (2017) Nguyen, A., Clune, J., Bengio, Y., Dosovitskiy, A., and Yosinski, J. Plug & play generative networks: Conditional iterative generation of images in latent space. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp.  4467–4477, 2017.
  • Oord et al. (2017) Oord, A. v. d., Vinyals, O., and Kavukcuoglu, K. Neural discrete representation learning. arXiv preprint arXiv:1711.00937, 2017.
  • Provilkov et al. (2019) Provilkov, I., Emelianenko, D., and Voita, E. Bpe-dropout: Simple and effective subword regularization. arXiv preprint arXiv:1910.13267, 2019.
  • Radford et al. (2019) Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., and Sutskever, I. Language models are unsupervised multitask learners. 2019.
  • Radford et al. (2021) Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., Krueger, G., and Sutskever, I. Learning transferable visual models from natural language supervision. 2021.
  • Rajbhandari et al. (2019) Rajbhandari, S., Rasley, J., Ruwase, O., and He, Y. Zero: Memory optimization towards training a trillion parameter models. arXiv preprint arXiv:1910.02054, 2019.
  • Razavi et al. (2019) Razavi, A., Oord, A. v. d., and Vinyals, O. Generating diverse high-fidelity images with vq-vae-2. arXiv preprint arXiv:1906.00446, 2019.
  • Reed et al. (2016a) Reed, S., Akata, Z., Mohan, S., Tenka, S., Schiele, B., and Lee, H. Learning what and where to draw. arXiv preprint arXiv:1610.02454, 2016a.
  • Reed et al. (2016b) Reed, S., Akata, Z., Yan, X., Logeswaran, L., Schiele, B., and Lee, H. Generative adversarial text to image synthesis. In International Conference on Machine Learning, pp. 1060–1069. PMLR, 2016b.
  • Rezende et al. (2014) Rezende, D. J., Mohamed, S., and Wierstra, D. Stochastic backpropagation and approximate inference in deep generative models. In International conference on machine learning, pp. 1278–1286. PMLR, 2014.
  • Salimans et al. (2016) Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans. arXiv preprint arXiv:1606.03498, 2016.
  • Salimans et al. (2017) Salimans, T., Karpathy, A., Chen, X., and Kingma, D. P. Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications. arXiv preprint arXiv:1701.05517, 2017.
  • Sennrich et al. (2015) Sennrich, R., Haddow, B., and Birch, A. Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909, 2015.
  • Sharma et al. (2018) Sharma, P., Ding, N., Goodman, S., and Soricut, R. Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp.  2556–2565, 2018.
  • Smolensky (1990) Smolensky, P. Tensor product variable binding and the representation of symbolic structures in connectionist systems. Artificial intelligence, 46(1-2):159–216, 1990.
  • Sun et al. (2017) Sun, C., Shrivastava, A., Singh, S., and Gupta, A. Revisiting unreasonable effectiveness of data in deep learning era. In Proceedings of the IEEE international conference on computer vision, pp.  843–852, 2017.
  • Sun et al. (2020) Sun, X., Wang, N., Chen, C.-Y., Ni, J., Agrawal, A., Cui, X., Venkataramani, S., El Maghraoui, K., Srinivasan, V. V., and Gopalakrishnan, K. Ultra-low precision 4-bit training of deep neural networks. Advances in Neural Information Processing Systems, 33, 2020.
  • Tao et al. (2020) Tao, M., Tang, H., Wu, S., Sebe, N., Wu, F., and Jing, X.-Y. Df-gan: Deep fusion generative adversarial networks for text-to-image synthesis. arXiv preprint arXiv:2008.05865, 2020.
  • Thomee et al. (2016) Thomee, B., Shamma, D. A., Friedland, G., Elizalde, B., Ni, K., Poland, D., Borth, D., and Li, L.-J. Yfcc100m: The new data in multimedia research. Communications of the ACM, 59(2):64–73, 2016.
  • Vaswani et al. (2017) Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., and Polosukhin, I. Attention is all you need. arXiv preprint arXiv:1706.03762, 2017.
  • Vogels et al. (2019) Vogels, T., Karimireddy, S. P., and Jaggi, M. Powersgd: Practical low-rank gradient compression for distributed optimization. arXiv preprint arXiv:1905.13727, 2019.
  • Welinder et al. (2010) Welinder, P., Branson, S., Mita, T., Wah, C., Schroff, F., Belongie, S., and Perona, P. Caltech-ucsd birds 200. 2010.
  • Xu et al. (2018) Xu, T., Zhang, P., Huang, Q., Zhang, H., Gan, Z., Huang, X., and He, X. Attngan: Fine-grained text to image generation with attentional generative adversarial networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp.  1316–1324, 2018.
  • Zhang et al. (2017) Zhang, H., Xu, T., Li, H., Zhang, S., Wang, X., Huang, X., and Metaxas, D. N. Stackgan: Text to photo-realistic image synthesis with stacked generative adversarial networks. In Proceedings of the IEEE international conference on computer vision, pp.  5907–5915, 2017.
  • Zhang et al. (2018) Zhang, H., Xu, T., Li, H., Zhang, S., Wang, X., Huang, X., and Metaxas, D. N. Stackgan++: Realistic image synthesis with stacked generative adversarial networks. IEEE transactions on pattern analysis and machine intelligence, 41(8):1947–1962, 2018.
  • Zhu et al. (2019) Zhu, M., Pan, P., Chen, W., and Yang, Y. Dm-gan: Dynamic memory generative adversarial networks for text-to-image synthesis. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp.  5802–5810, 2019.

Appendix A Details for Discrete VAE

A.1 Architecture

The dVAE encoder and decoder are convolutional (LeCun et al., 1998) ResNets (He et al., 2016) with bottleneck-style resblocks. The models primarily use 3×3333\times 3 convolutions, with 1×1111\times 1 convolutions along skip connections in which the number of feature maps changes between the input and output of a resblock. The first convolution of the encoder is 7×7777\times 7, and the last convolution of the encoder (which produces the 32×32×81923232819232\times 32\times 8192 output used as the logits for the categorical distributions for the image tokens) is 1×1111\times 1. Both the first and last convolutions of the decoder are 1×1111\times 1. The encoder uses max-pooling (which we found to yield better ELB than average-pooling) to downsample the feature maps, and the decoder uses nearest-neighbor upsampling. The precise details for the architectures are given in the files dvae/encoder.py and dvae/decoder.py of the code release.

A.2 Training

def preprocess_image(img, target_res):
h, w = tf.shape(img)[0], tf.shape(img)[1]
s_min = tf.minimum(h, w)
img = tf.image.random_crop(img, 2 * [s_min] + [3])
t_min = tf.minimum(s_min, round(9 / 8 * target_res))
t_max = tf.minimum(s_min, round(12 / 8 * target_res))
t = tf.random.uniform([], t_min, t_max + 1, dtype=tf.int32)
img = tf.image.resize_images(img, [t, t], method=tf.image.ResizeMethod.AREA,
align_corners=True)
img = tf.cast(tf.rint(tf.clip_by_value(img, 0, 255)), tf.uint8)
img = tf.image.random_crop(img, 2 * [target_res] + [channel_count])
return tf.image.random_flip_left_right(img)
Listing 1: TensorFlow (Abadi et al., 2016) image preprocessing code for training dVAE. We use target_res = 256 and channel_count = 3.

The dVAE is trained on the same dataset as the transformer, using the data augmentation code given in Listing 1. Several quantities are decayed during training, all of which use a cosine schedule:

  1. 1.

    The KL weight β𝛽\beta is increased from 00 to 6.66.66.6 over the first 500050005000 updates. Bowman et al. (2015) use a similar schedule based on the sigmoid function.

  2. 2.

    The relaxation temperature τ𝜏\tau is annealed from 111 to 1/161161/16 over the first 150000150000150000 updates. Using a linear annealing schedule for this typically led to divergence.

  3. 3.

    The step size is annealed from 11041superscript1041\cdot 10^{-4} to 1.251061.25superscript1061.25\cdot 10^{-6} over 120000012000001200000 updates.

The decay schedules for the relaxation temperature and the step size are especially important for stability and successful optimization.

We update the parameters using AdamW (Loshchilov & Hutter, 2017) with β1=0.9subscript𝛽10.9\beta_{1}=0.9, β2=0.999subscript𝛽20.999\beta_{2}=0.999, ϵ=108italic-ϵsuperscript108\epsilon=10^{-8}, and weight decay multiplier 104superscript10410^{-4}. We use exponentially weighted iterate averaging for the parameters with decay coefficient 0.9990.9990.999. The reconstruction term in the ELB is a joint distribution over the 256×256×32562563256\times 256\times 3 values for the image pixels, and the KL term is a joint distribution over the 32×32323232\times 32 positions in the spatial grid output by the encoder. We divide the overall loss by 256×256×32562563256\times 256\times 3, so that the weight of the KL term becomes β/192𝛽192\beta/192, where β𝛽\beta is the KL weight. The model is trained in mixed-precision using standard (i.e., global) loss scaling on 646464 16 GB NVIDIA V100 GPUs, with a per-GPU batch size of 888, resulting in a total batch size of 512. It is trained for a total of 300000030000003000000 updates.

A.3 The Logit-Laplace Distribution

The 1subscript1\ell_{1} and 2subscript2\ell_{2} reconstruction objectives are commonly used when training VAEs. These objectives correspond to using Laplace and Gaussian distributions for lnpθ(x|y,z)subscript𝑝𝜃conditional𝑥𝑦𝑧\ln p_{\theta}(x\,|\,y,z) in Equation 1, respectively. There is a strange mismatch in this modeling choice: pixel values lie within a bounded interval, but both of these distributions are supported by the entire real line. Hence, some amount of likelihood will be placed outside the admissible range of pixel values.

We present a variant of the Laplace distribution that is also supported by a bounded interval. This resolves the discrepancy between the range of the pixel values being modeled and the support of the distribution used to model them. We consider the pdf of the random variable obtained by applying the sigmoid function to a Laplace-distributed random variable. This pdf is defined on (0,1)01(0,1) and is given by

f(x|μ,b)=12bx(1x)exp(|logit(x)μ|b);𝑓conditional𝑥𝜇𝑏12𝑏𝑥1𝑥logit𝑥𝜇𝑏f(x\,|\,\mu,b)=\frac{1}{2bx(1-x)}\exp\left(-\frac{|\operatorname{logit}(x)-\mu|}{b}\right); (2)

we call it the logit-Laplace distribution. We use the logarithm of the RHS of Equation 2 as the reconstruction term for the training objective of the dVAE.

The decoder of the dVAE produces six feature maps representing the sufficient statistics of the logit-Laplace distribution for the RGB channels of the image being reconstructed. The first three feature maps represent the μ𝜇\mu parameter for the RGB channels, and the last three represent lnb𝑏\ln b. Before feeding an image into the dVAE encoder, we transform its values using φ:[0,255](ϵ,1ϵ):𝜑0255italic-ϵ1italic-ϵ\varphi:[0,255]\to(\epsilon,1-\epsilon), which is given by

φ:x12ϵ255x+ϵ.:𝜑maps-to𝑥12italic-ϵ255𝑥italic-ϵ\varphi:x\mapsto\frac{1-2\epsilon}{255}x+\epsilon. (3)

This restricts the range of the pixel values to be modeled by the dVAE decoder to (ϵ,1ϵ)italic-ϵ1italic-ϵ(\epsilon,1-\epsilon), which avoids numerical problems arising from the x(1x)𝑥1𝑥x(1-x) in Equation 2. We use ϵ=0.1italic-ϵ0.1\epsilon=0.1. To reconstruct an image for manual inspection or computing metrics, we ignore lnb𝑏\ln b and compute x^=φ1(sigmoid(μ))^𝑥superscript𝜑1sigmoid𝜇\hat{x}=\varphi^{-1}(\operatorname{sigmoid}(\mu)), where μ𝜇\mu is given by the first three feature maps output by the dVAE decoder.111111See notebooks/usage.ipynb of the code release for an example.

Appendix B Details for Transformer

B.1 Architecture

Refer to caption
Figure 10: Illustration of the embedding scheme for a hypothetical version of our transformer with a maximum text length of 6 tokens. Each box denotes a vector of size dmodel=3968subscript𝑑model3968d_{\mathrm{model}}=3968. In this illustration, the caption has a length of 4 tokens, so 2 padding tokens are used (as described in Section 2.2). Each image vocabulary embedding is summed with a row and column embedding.
Refer to caption
(a) Row attention mask.
Refer to caption
(b) Column attention mask.
Refer to caption
(c) Column attention mask with transposed image states.
Refer to caption
(d) Convolutional attention mask.
Figure 11: Illustration of the three types of attention masks for a hypothetical version of our transformer with a maximum text length of 6 tokens and image length of 16 tokens (i.e., corresponding to a 4×4444\times 4 grid). Mask (a) corresponds to row attention in which each image token attends to the previous 5 image tokens in raster order. The extent is chosen to be 5, so that the last token being attended to is the one in the same column of the previous row. To obtain better GPU utilization, we transpose the row and column dimensions of the image states when applying column attention, so that we can use mask (c) instead of mask (b). Mask (d) corresponds to a causal convolutional attention pattern with wraparound behavior (similar to the row attention) and a 3×3333\times 3 kernel. Our model uses a mask corresponding to an 11×11111111\times 11 kernel.

Our model is a decoder-only sparse transformer of the same kind described in Child et al. (2019), with broadcasted row and column embeddings for the part of the context for the image tokens. A complete description of the embedding scheme used in our model is shown in Figure 10. We use 64 attention layers, each of which uses 62 attention heads with a per-head state size of 64.

The model uses three kinds of sparse attention masks, which we show in Figure 11. The convolutional attention mask (Figure 11(d)) is only used in the last self-attention layer. Otherwise, given the index i𝑖i of a self-attention layer (with i[1,63]𝑖163i\in[1,63]), we use the column attention mask (Figure 11(c)) if i2mod4=0modulo𝑖240i-2\!\!\mod 4=0, and row attention otherwise. E.g., the first four self-attention layers use “row, column, row, row”, respectively. With the exception of the convolutional attention mask, which we found to provide a small boost in performance over the row and dense causal attention masks when used in the final self-attention layer, this is the same configuration used in Child et al. (2019).

B.2 Training

def preprocess_image(img, target_res):
h, w = tf.shape(img)[0], tf.shape(img)[1]
s_min = tf.minimum(h, w)
off_h = tf.random.uniform([], 3 * (h - s_min) // 8,
tf.maximum(3 * (h - s_min) // 8 + 1, 5 * (h - s_min) // 8),
dtype=tf.int32)
off_w = tf.random.uniform([], 3 * (w - s_min) // 8,
tf.maximum(3 * (w - s_min) // 8 + 1, 5 * (w - s_min) // 8),
dtype=tf.int32)
# Random full square crop.
img = tf.image.crop_to_bounding_box(img, off_h, off_w, s_min, s_min)
t_max = tf.minimum(s_min, round(9 / 8 * target_res))
t = tf.random.uniform([], target_res, t_max + 1, dtype=tf.int32)
img = tf.image.resize_images(img, [t, t], method=tf.image.ResizeMethod.AREA,
align_corners=True)
img = tf.cast(tf.rint(tf.clip_by_value(img, 0, 255)), tf.uint8)
# We don’t use hflip aug since the image may contain text.
return tf.image.random_crop(img, 2 * [target_res] + [channel_count])
Listing 2: TensorFlow (Abadi et al., 2016) image preprocessing code for training the transformer. We use target_res = 256 and channel_count = 3.

When training the transformer, we apply data augmentation to the images before encoding them using the dVAE encoder. We use slightly different augmentations from the ones used to train the dVAE; the code used for this is given in Listing 2. We also apply 10% BPE dropout when BPE-encoding the captions for training. The model is trained using per-resblock scaling (see Section 2.4) and gradient compression (see Section 2.5) with total compression rank 896 (so that each GPU uses a compression rank of 112 for its parameter shards). As shown in Table 1, this results in a compression rate of about 86%, which we analyze in Section E.1.

We update the parameters using AdamW with β1=0.9subscript𝛽10.9\beta_{1}=0.9, β2=0.96subscript𝛽20.96\beta_{2}=0.96, ϵ=108italic-ϵsuperscript108\epsilon=10^{-8}, and weight decay multiplier 4.51024.5superscript1024.5\cdot 10^{-2}. We clip the decompressed gradients by norm using a threshold of 4, prior to applying the Adam update. Gradient clipping is only triggered during the warm-up phase at the start of training. To conserve memory, most Adam moments (see Section D for details) are stored in 16-bit formats, with a 1-6-9 format for the running mean (i.e., 1 bit for the sign, 6 bits for the exponent, and 9 bits for the significand), and a 0-6-10 format for the running variance. We clip the estimate for running variance by value to 5 before it is used to update the parameters or moments. Finally, we apply exponentially weighted iterate averaging by asynchronously copying the model parameters from the GPU to the CPU once every 25 updates, using a decay coefficient of 0.99.

We trained the model using 1024, 16 GB NVIDIA V100 GPUs and a total batch size of 102410241024, for a total of 430000430000430000 updates. At the start of training, we use a linear schedule to ramp up the step size to 4.51044.5superscript1044.5\cdot 10^{-4} over 500050005000 updates, and halved the step size each time the training loss appeared to plateau. We did this a total of five times, ending training with a final step size that was 32 times smaller than the initial one. We reserved about 606000606000606000 images for validation, and did not observe overfitting at any point during training.

Appendix C Details for Data Collection

In order to train the 12-billion parameter transformer, we created a dataset of a similar scale to JFT-300M by collecting 250 million text-image pairs from the internet. As described in Section 2.3, this dataset incorporates Conceptual Captions, the text-image pairs from Wikipedia, and a filtered subset of YFCC100M. We use a subset of the text, image, and joint text and image filters described in Sharma et al. (2018) to construct this dataset. These filters include discarding instances whose captions are too short, are classified as non-English by the Python package cld3, or that consist primarily of boilerplate phrases such as “photographed on <date>”, where <date> matches various formats for dates that we found in the data. We also discard instances whose images have aspect ratios not in [1/2,2]122[1/2,2]. If we were to use to very tall or wide images, then the square crops used during training would likely exclude objects mentioned in the caption.

Appendix D Guidelines for Mixed-Precision Training

Refer to caption
Figure 12: Plot of per-resblock gradient scales for a 2.8-billion parameter text-to-image transformer trained without gradient compression. The x𝑥x-axis is parameter updates, and the y𝑦y-axis is the base-2 logarithm of the gradient scale. Darkest violet corresponds to the first resblock, and brightest yellow corresponds to the last (of which there are 128 total). The gradient scale for the second MLP resblock hovers at around 224superscript2242^{24}, while the others stay within a 4-bit range. The extent of this range increases as the model is made larger.

The most challenging part of this project was getting the model to train in 16-bit precision past one billion parameters. We were able to do this after detecting for underflow in various parts of training, and revising the code to eliminate it. We developed a set of guidelines as a result of this process that we present here.121212Fewer of these guidelines may be necessary on hardware like the TPU that has native support for the bfloat16 format, since the larger 8-bit exponent range makes underflow less likely to occur.

  1. 1.

    Use per-resblock gradient scaling (Figure 4) instead of standard loss scaling. Our model uses 128 gradient scales, one for each of its resblocks. All of the gradient scales are initialized to M213𝑀superscript213M\cdot 2^{13}, where M𝑀M is the number of data-parallel replicas (i.e., the number of GPUs). In our setup, each grad scale is multiplied by 21/1000superscript2110002^{1/1000} at every parameter update when there are no nonfinite values for any parameter gradient in that resblock. Otherwise, we divide the grad scale by 22\sqrt{2} and skip the update. We also disallow consecutive divisions of the same grad scale within a window of 125125125 updates. All grad scales are clamped to the range [M27,M224]𝑀superscript27𝑀superscript224[M\cdot 2^{7},M\cdot 2^{24}] after being updated. Figure 12 shows the gradient scales in the early phase of training for a 2.8-billion parameter model.

  2. 2.

    Only use 16-bit precision where it is really necessary for performance. In particular, store all gains, biases, embeddings, and unembeddings in 32-bit precision, with 32-bit gradients (including for remote communication) and 32-bit Adam moments. We disable gradient compression for these parameters (though PowerSGD would not make sense for 1D parameters like gains and biases). The logits for the text and image tokens are computed and stored in 32-bit precision. We found that storing the embeddings in 16-bit precision sometimes caused divergence early in optimization, and using 16-bit logits resulted in a small shift in the training curve, so we switched to use 32-bit precision out of an abundance of caution.

  3. 3.

    Avoid underflow when dividing the gradient. For data-parallel training, we need to divide the gradients by the total number of data-parallel workers M𝑀M. One way to do this is to divide the loss by the per-machine batch size, and then divide the parameter gradients by M𝑀M before summing them over the machines (using all-reduce). To save time and space, the gradients are usually computed and stored in 16-bit precision. When M𝑀M is large, this division could result in underflow before the gradients are summed. On the other hand, if we attempt to sum the gradients first and then divide them later, we could encounter overflow in the all-reduce.

    Our solution for this problem attempts to minimize the loss of information in the division prior to the all-reduce, without danger of overflow. To do this, we divide the loss by the overall batch size (which includes M𝑀M as a factor) rather than the per-machine batch size, and multiply the gradient scales by M𝑀M to compensate, as described in (1). Then, prior to the all-reduce operation, we divide the gradients by a constant that was tuned by hand to avoid both underflow and overflow. This was done by inspecting histograms of the exponents (i.e., base-2 logarithms) of the absolute values of the scalar components of the per-parameter gradients. Since the gradient scaling keeps the gradients close to right end of the exponent range of the 16-bit format, we found that the same constant worked well for all parameters in the model with 16-bit gradients. When using PowerSGD, we chose different constants for the P𝑃P and Q𝑄Q matrices.

Appendix E Details for Distributed Optimization

We use PowerSGD (Vogels et al., 2019) to compress the gradients with respect to all parameters except the embeddings, unembeddings, gains, and biases. In Section E.1, we derive an expression for the reduction in the amount of data communicated as a function of the compression rank and model size. In Section E.2, we present a detailed overview of our adaptation of PowerSGD, and the modifications we had to make in order to fix performance regressions, some of which only manifest at billion-parameter scale.

E.1 Bandwidth Analysis

Parameter Names Parameter Shard Gradient Shape (No Compression) P𝑃P shape Q𝑄Q shape
qkv and post-attention matrices d×(d/m)𝑑𝑑𝑚d\times(d/m) d×(r/m)𝑑𝑟𝑚d\times(r/m) (r/m)×(d/m)𝑟𝑚𝑑𝑚(r/m)\times(d/m)
First MLP matrix d×(4d/m)𝑑4𝑑𝑚d\times(4d/m) d×(r/m)𝑑𝑟𝑚d\times(r/m) (r/m)×(4d/m)𝑟𝑚4𝑑𝑚(r/m)\times(4d/m)
Second MLP matrix (4d/m)×d4𝑑𝑚𝑑(4d/m)\times d (4d/m)×(r/m)4𝑑𝑚𝑟𝑚(4d/m)\times(r/m) (r/m)×d𝑟𝑚𝑑(r/m)\times d
Total size 12d2/m12superscript𝑑2𝑚12d^{2}/m (5drm+4dr)/m25𝑑𝑟𝑚4𝑑𝑟superscript𝑚2(5drm+4dr)/m^{2} (drm+8dr)/m2𝑑𝑟𝑚8𝑑𝑟superscript𝑚2(drm+8dr)/m^{2}
Table 2: We analyze the amount of data sent from each GPU on a given machine to GPUs on other machines, in the case where we shard the parameters among the m𝑚m GPUs on each machine. Here, r𝑟r denotes the rank used for compression, and d𝑑d the transformer hidden size. The compression ratio is given by the sum of the last two columns of the last row, divided by the first column of the last row. This comes out to r(m+2)/(2dm)𝑟𝑚22𝑑𝑚r(m+2)/(2dm), which for m=8𝑚8m=8 is 5r/8d5𝑟8𝑑5r/8d.

Gradient compression uses the factorization GPQt𝐺𝑃superscript𝑄𝑡G\approx PQ^{t}, where P𝑃P and Q𝑄Q both have rank r𝑟r. Instead of using a single all-reduce to transmit G𝐺G, we use two, smaller all-reduces to transmit both P𝑃P and Qtsuperscript𝑄𝑡Q^{t} in succession. Hence, the compression ratio is the sum of the sizes of the P𝑃P and Q𝑄Q matrices divided by the sum of the sizes of the G𝐺G matrices. We shard along axis 1 for all parameters except for the second MLP matrix. The derivation of the compression ratio in our setup is given in Table 2. We note that the choice of shard axis changes the compression ratio for the MLP matrices. Finally, this analysis excludes the embeddings, unembeddings, gains, and biases, for which we do not use compression. The total fraction of the bandwidth used by these parameters becomes smaller as the model size is increased.

E.2 Implementation Details

We describe the steps in our implementation of PowerSGD in detail, since these details were crucial in getting it to work efficiently and reliably at billion-parameter scale.

  1. 1.

    Our training setup uses a combination of parameter sharding and gradient compression, as described in Section 2.5. During backpropagation, while recomputing the activations and computing the gradients for the current resblock, we prefetch the parameters for the preceding resblock using all-gather. Once each GPU has computed the gradient with respect to a full parameter matrix, we compute the average of the slice of the gradient corresponding to the GPU’s parameter shard, and discard the full gradient immediately to conserve memory. This average is taken over all of the GPUs on a machine using reduce-scatter.

  2. 2.

    If there are no nonfinite values in the result of the reduce-scatter (which could be caused by overflow in backpropagation or the reduce-scatter), we divide the result by the resblock’s gradient scale, and add it to the error buffer (i.e., the buffer used for error correction). Otherwise, we do nothing and proceed with backpropagation; a single nonfinite value in the gradient means that the entire update will be skipped, which happens about 5% of the time. The error buffer uses the same 1-6-9 format used for the Adam mean, which we describe in Section B.2; the larger exponent range ensures that this division does not result in underflow. Adding the gradients directly to the error buffers avoids redundantly allocating another set of buffers of size equal to the parameter shard gradients.

  3. 3.

    Once the reduce-scatter operations for the resblock have finished, we schedule the operations to compute the P𝑃P matrices from the errors buffers and the Q𝑄Q matrices, whose values are fixed at the start of training (see Section 2.5). Both the P𝑃P and Q𝑄Q matrices are stored in 1-6-9 format and have their values scaled by predetermined constants, as discussed in Section D.

  4. 4.

    Once each GPU has computed the P𝑃P matrices for the parameter shards in a resblock, they are averaged with the P𝑃P matrices from the GPUs with the same ordinal on all other machines, using a single, grouped all-reduce operation. This all-reduce is carried out in the 1-6-9 format, using a custom kernel. The grouping results in better bandwidth utilization, since it avoids scheduling many all-reduce calls for smaller, individual parameters, each of which carries some overhead. We clamp any infinities in the results of the all-reduce to the maximum value of the 1-6-9 format (which is slightly less than 16), retaining the sign. With our choice of scaling factors for the P𝑃P and Q𝑄Q matrices, this clamping happens very rarely.

  5. 5.

    Once the all-reduce operation for the P𝑃P matrices for a resblock have finished, we orthogonalize the columns of the resulting matrices. We use a custom Householder orthogonalization kernel rather than Gram-Schmidt, as we found the latter to be numerically unstable. We also add ϵIm×ritalic-ϵsubscript𝐼𝑚𝑟\epsilon I_{m\times r} to P𝑃P in order to ensure that the result is not near rank-deficient, where ϵ=106italic-ϵsuperscript106\epsilon=10^{-6}. Here, Im×rsubscript𝐼𝑚𝑟I_{m\times r} is a rectangular matrix of the same size as the P𝑃P matrix to which it is added; it contains the r×r𝑟𝑟r\times r identity matrix and has zeros elsewhere. The orthogonalizalied P𝑃P matrices are stored in 1-6-9 format, but without scaling.

  6. 6.

    Once the P𝑃P matrices for a resblock have been orthogonalized, we schedule the operations to compute the new Q𝑄Q matrices from the error buffers and the P𝑃P matrices.

  7. 7.

    Once the new Q𝑄Q matrices for a resblock have been computed, we schedule another grouped all-reduce, similar to what we did for the P𝑃P matrices. As in step (4), we clamp all infinities in the results of the all-reduce to the maximum value of the 1-6-9 format, retaining the sign. The error buffers for the resblock have now been decomposed into low-rank factors P𝑃P and Qtsuperscript𝑄𝑡Q^{t}.

  8. 8.

    The gradients for all parameters that are not compressed are grouped together into a single, 32-bit precision all-reduce. Section D explains why we use 32-bit precision for these parameters and their gradients.

  9. 9.

    Once all GPUs on a machine have finished steps (7) and (8) for every resblock in the model, the values of the P𝑃P and Q𝑄Q matrices for the same parameter shard on all machines will be identical. We then compute the global gradient norm, which is the sum of two quantities: (a) the sum of the squared Frobenius norms of the Q𝑄Q matrices over all of the parameter shards on a machine, and (b) the sum of the squared norms of the gradients for the parameter shards that do not use compression, taken over all such parameter shards on a machine. We need to compute this value for gradient clipping (see Section B.2).

  10. 10.

    While computing the global norm, we also synchronize the information from step (2) about which parameter shard gradients contained nonfinite values after the reduce-scatter. After doing this, we have two pieces of information for each parameter shard: (a) whether its error buffer from step (2) contains nonfinite values on the current GPU, and (b) whether P𝑃P or Q𝑄Q contains nonfinite values. We cannot rely on the values of the P𝑃P and Q𝑄Q matrices to determine (b), since we clamp infinities as described in step (4). If we find that the gradient with respect to any parameter shard on the machine contains nonfinite values, then we set the global norm to infinity.

  11. 11.

    Once all of the all-reduces have finished and the global norm has been computed, we can apply the parameter updates. Like backpropagation, the parameter updates proceed resblock-by-resblock. The first step is to compute the decompressed gradients by forming the product PQt𝑃superscript𝑄𝑡PQ^{t} for all parameters in a given resblock. To avoid overflow, these products are computed in 32-bit precision. We can then apply the Adam update to the parameters using the decompressed gradients and the global norm computed in step (9). If the global norm is not finite, then the update to the parameters and Adam moments is skipped. We note that the decompressed gradient must be divided by the scale of the Q𝑄Q matrix (the P𝑃P matrix is stored without scaling after orthogonalization).

  12. 12.

    The second step is the update to the error buffers. First, we use the results from step (10) to check if the P𝑃P and Q𝑄Q matrices for a given parameter shard contain only finite values. If this is the case, then we divide the decompressed gradient by the total number of machines, and subtract it from the current value for the error buffer. This sets the error buffer to the difference between the “local” gradient averaged over the GPUs on the machine using reduce-scatter, and the “remote” decompressed gradient (i.e., the “error”). If either P𝑃P or Q𝑄Q contains nonfinite values, then we check if the error buffer computed in step (2) contains only finite values. If it does, then we preserve its value and do nothing. If it does not, then we set it to zero. The purpose of this tedious logic is to set an error buffer to zero only when we must do so, because it has been contaminated with nonfinite values. We found that error buffers getting set to zero too frequently by gradient scaling events leads to performance regressions.

  13. 13.

    The parameter shards whose gradients are not compressed are updated separately.

We also note the following important optimizations:

  1. 1.

    There are several opportunities for overlap between compute and communication in the above steps. For example, while we are running step (2) for resblock i𝑖i, we can proceed to steps (3)–(8) for all resblocks j>i𝑗𝑖j>i. Exploiting opportunities for overlap is necessary to achieve good performance.

  2. 2.

    We throttle specific operations that are liable to exhaust all available memory. For example, we only prefetch the parameters from the preceding resblock when the reduce-scatter operations have finished for the current one. Otherwise, we risk running out of memory by holding on to the full parameters. We also throttle the Adam updates, so that we do not decompress all of the gradients at once.

  3. 3.

    There are two places in the implementation where the transposition matters: (a) the choice of shard axis for the MLP matrices and (b) whether we compute the low-rank factorization for a gradient or its transpose. The former influences the bandwidth analysis, which we present in Section E.1. The latter influences the cost of the orthogonalization. Suppose that the gradient G𝐺G is m×n𝑚𝑛m\times n and its low-rank factors P𝑃P and Qtsuperscript𝑄𝑡Q^{t} are m×r𝑚𝑟m\times r and r×n𝑟𝑛r\times n, respectively, with rm,nmuch-less-than𝑟𝑚𝑛r\ll m,n. To make orthogonalization cheaper, we transpose G𝐺G appropriately so that mn𝑚𝑛m\leqslant n.

    At first glance, it may seem like a limitation that the NCCL all-gather and reduce-scatter primitives shard along axis 0 only. We may need to transpose some matrices before and after communication operations because of (a) and (b), which would require additional time and potentially special care to avoid out-of-memory errors. In fact, we never actually needed to do this. This is because we stored some of the parameters in their transposed formats and exploited the transpose_a and transpose_b parameters of the matrix multiplication kernels used in forward propagation, backpropagation, and steps (1)–(13) above. This allowed us to avoid explicit transposition while retaining the freedom to choose how to handle (a) and (b).

  4. 4.

    In step (12) above, we note that setting the error buffers to zero too often can cause performance regressions. We wanted to avoid doing this when resuming training from a checkpoint, which happens more frequently for larger jobs as it is likely that a machine will periodically fail. Naively, this would require uploading the error buffers from all of the machines along with the model checkpoints. Since we use a total of 128 machines for training, this would lead to 128 times greater storage usage, which is extremely wasteful.

    Fortunately, this is unnecessary, as error correction depends only on the sum of the error buffers. This property follows from linearity and the sequence of operations used by PowerSGD. Hence, it suffices to store the sums of the errors buffers taken across all GPUs with the same ordinal. When resuming from a checkpoint, we can divide the error buffers by the total number of machines and broadcast them.

Appendix F Details for Human Evaluation Experiments

Refer to caption
Figure 13: Example task interface shown to workers.

We start with a list of 100010001000 captions and generate one sample image per model per caption. Captions and sample images are then used to create 100010001000 image comparison tasks per experiment, which we submitted to Amazon’s Mechanical Turk. Each task was answered by five distinct workers. Workers were asked to compare two images and answer two questions about them: (1) which image is most realistic, and (2) which image best matches the shared caption. The experimental setup provided to workers is shown in Figure 13. One worker’s answers were disqualified due to a high rate of disagreement with other workers combined with a fast answer velocity (with many submission times under 4 seconds); all other worker answers were kept.

Appendix G Zero-Shot Image-to-Image Translation

Refer to caption Refer to caption
Refer to caption Refer to caption
(a) “the exact same cat on the top as a sketch on the bottom”
Refer to caption Refer to caption
Refer to caption Refer to caption
(b) “the exact same photo on the top reflected upside-down on the bottom”
Refer to caption Refer to caption
Refer to caption Refer to caption
(c) “2 panel image of the exact same cat. on the top, a photo of the cat. on the bottom, an extreme close-up view of the cat in the photo.”
Refer to caption Refer to caption
Refer to caption Refer to caption
(d) “the exact same cat on the top colored red on the bottom”
Refer to caption Refer to caption
Refer to caption Refer to caption
(e) “2 panel image of the exact same cat. on the top, a photo of the cat. on the bottom, the cat with sunglasses.”
Refer to caption Refer to caption
Refer to caption Refer to caption
(f) “the exact same cat on the top as a postage stamp on the bottom”
Figure 14: Further examples of zero-shot image-to-image translation.

Figure 14 shows further examples of zero-shot image-to-image translation, which we discussed in Section 3.3. We did not anticipate that this capability would emerge, and made no modifications to the training procedure to encourage it.