这是用户在 2024-6-16 20:55 为 https://zhuanlan.zhihu.com/p/660671886 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?
一文打通PyTorch与flax (JAX的神经网络库)

一文打通PyTorch与flax (JAX的神经网络库)

59 人赞同了该文章
发布于 2023-10-11 11:53・IP 属地北京

上一篇博客一文打通PyTorch与JAX介绍了PyTorch与JAX的底层API之间的关系,打通了PyTorch与JAX的底层联系。然而这还不足以流畅地使用JAX。本文继续深入分析,JAX的神经网络库flax与PyTorch的nn模块之间的关系。

回顾PyTorch nn模块的使用方式

在PyTorch中,我们新建一个模块的方式很简单,只需要继承nn.Module,并写好__init__forward两个函数:

from torch import nn

class MyMod(nn.Module):
    def __init__(self, arg_model):
        super().__init__()
        self.params = init_params(arg_model)
    
    def forward(self, x):
        y = forward_func(self.params, x)

这里的重点在于super().__init__()函数,它隐藏了非常多的nn.Module细节。我们可以初始化一个空的模块mod = nn.Module(),通过dir(mod)看到nn.Module非常多的内部状态,包括但不限于_buffers/_parameters/_forward_hooks等等。

flax Module的初始化方式

JAX的函数式编程最害怕隐藏的状态,因此基本不允许一个没有参数的函数返回一个包含状态的对象。所以,直接初始化flax中的神经网络模块基础类flax.linen.Module会报错,它的正确初始化方式十分复杂。

在flax中,要新建一个模块,也是继承Module类,但是我们不能写__init__函数。继承的方式是通过类型标注增加新的配置参数:

import flax
import typing

class MyMod(flax.linen.Module):
    arg_model: typing.Any

flax的主要工作发生在继承时,它为flax.linen.Module注册了__init_subclass__函数,当我们继承flax.linen.Module得到一个新的类时,这个类已经被flax改得面目全非了。

最重要的一点修改,就是所有继承的类被自动改成了dataclass类型:

import dataclasses
dataclasses.is_dataclass(MyMod) # True

其初始化函数MyMod.__init__,则被改成了接收三个参数的函数:

MyMod.__init__(
    self,
    arg_model: typing.Any,
    parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x11a166d40>,
    name: Optional[str] = None,
) -> None

其中,第一个参数arg_model是我们通过类型标注生成的,剩下两个参数是自动生成的。

这里尤其要注意,dataclass非常依赖类型标注。我们可以不提供参数的默认值,但必须提供类型标注。下面这个例子中,没有提供类型标注,生成的初始化函数就不对了。

class MyMod(flax.linen.Module):
    arg_model = 2
MyMod(arg_model=3) # TypeError: MyMod.__init__() got an unexpected keyword argument 'arg_model'

如果不知道什么类型或者不想标注类型,可以像上面的例子一样,直接标注为任意类型Any。这些类型标注只是用来帮助理解代码与生成函数,并不会对实际运行产生限制(标注为int类型的参数可以初始化为任意类型的值)。

注意,flax模块的初始化只是简单地记住了一些构造参数,我们需要在另外的函数里面使用这些参数。

flax Module的计算函数

flax的计算函数比较特别,参数初始化与计算在同一个函数里面:

import flax
import typing
import jax.numpy as jnp
import jax

class MyMod(flax.linen.Module):
    arg_model: typing.Any

    @flax.linen.compact
    def __call__(self, inputs):
        params = self.param(
            'whatever',
            flax.linen.initializers.uniform(),
            (jnp.shape(inputs)[-1], self.arg_model),
            jnp.float32,
        )

        y = inputs @ params
        return y

也就是说,在初始化参数时,使用定制的self.param函数得到参数,然后用这个参数去进行计算,得到输出。注意,这里的参数"whatever"可以是任意名字,但是不能是"params"

有了这个函数,上文一文打通PyTorch与JAX提到的init/apply等函数都可以用了。

mod = MyMod(arg_model=233)
key1, key2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(key1, (64, 10))
output1, params = mod.init_with_output(key2, x) # Initialization call
output2 = mod.apply(params, x) # forward call
print(((output1 - output2) ** 2).sum()) # they are the same

在flax中使用现有模块

下面展示一个示例,如何在flax中构建一个模块并使用它:

import jax.numpy as jnp
from flax import linen as nn

class ConvLayerNormLeakyReLU(nn.Module):
    output_channels: int
    kernel_size: int
    leaky_slope: float

    @nn.compact
    def __call__(self, x):
        # Convolution layer
        x = nn.Conv(
            features=self.output_channels,
            kernel_size=self.kernel_size,
            kernel_init=nn.initializers.lecun_normal()
        )(x)

        # Layer normalization
        x = nn.LayerNorm()(x)

        # Leaky ReLU activation
        return nn.leaky_relu(x, negative_slope=self.leaky_slope)

input_data = jnp.ones((1, 32, 32, 3))  # Batch of 1, 32x32 image with 3 channels
model = ConvLayerNormLeakyReLU(output_channels=64, kernel_size=(3, 3), leaky_slope=0.01)
output_data, params = model.init_with_output(key, input_data)
output_data = model.apply(params, input_data)

jax/flax使用的是类似TensorFlow形式的API,模块只包含输出的参数(比如卷积的输出channel数),不包含输入的参数,输入的参数由输入数据直接计算。

还有另一种更加类似PyTorch的使用方式:

import jax.numpy as jnp
from flax import linen as nn

class ConvLayerNormLeakyReLU(nn.Module):
    output_channels: int
    kernel_size: int
    leaky_slope: float

    def setup(self):
        # Define the convolution and layer normalization sub-modules in setup
        self.conv = nn.Conv(
            features=self.output_channels,
            kernel_size=self.kernel_size,
            kernel_init=nn.initializers.lecun_normal()
        )
        self.layer_norm = nn.LayerNorm()

    @nn.compact
    def __call__(self, x):
        x = self.conv(x)          # Apply convolution
        x = self.layer_norm(x)    # Apply layer normalization
        return nn.leaky_relu(x, negative_slope=self.leaky_slope)  # Apply leaky ReLU

input_data = jnp.ones((1, 32, 32, 3))  # Batch of 1, 32x32 image with 3 channels
model = ConvLayerNormLeakyReLU(output_channels=64, kernel_size=(3, 3), leaky_slope=0.01)
output_data, params = model.init_with_output(key, input_data)
output_data = model.apply(params, input_data)

以上两种方式基本上是等价的。

有几点需要澄清的地方:

  1. 不使用setup函数时,nn.Conv生成的模块实际存储在self.Conv_0里面,Conv_0是自动生成的名字;使用了setup函数时,我们手动指定了子模块名conv。不论是否使用setup函数,直接输出model.conv或者model.Conv_0都会报错,flax禁止直接访问这些子模块,它们只能在apply/init等函数调用时被使用。
  2. 在我们调用output_data = model.apply(params, input_data)时,flax会递归地调用子模块的apply函数,并且为每个函数的调用维护一个self.scope,里面记录了每个模块的一些临时状态,用以在创建新的模块时自动记录父模块及子模块的名字,这就是上文中__init__函数剩余两个参数的意义。

为了说明scope的作用,我们可以构造嵌套的模块并在模块中打印scope信息:

import jax
import jax.numpy as jnp
from flax import linen as nn

class Leaf(nn.Module):
    slope: float

    @nn.compact
    def __call__(self, x):
        print((self.scope, self.scope.name, self.scope.parent))
        return x * self.slope

class Compound(nn.Module):
    n: int

    def setup(self):
        # Create a list of 'Leaf' modules
        self.leaves = [Leaf(2.0) for _ in range(self.n)]

    def __call__(self, x):
        for leaf in self.leaves:
            x = leaf(x)
        return x

# Example usage:
model = Compound(n=3)
params = model.init(jax.random.PRNGKey(0), jnp.array(1.0))

# Applying Leaf with slope=2.0 for three times, so result should be 8.0 (doubling thrice)
result = model.apply(params, jnp.array(1.0))
print(result)

输出结果:

(<flax.core.scope.Scope object at 0x2982d7d30>, 'leaves_0', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x298377a90>, 'leaves_1', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x298377ca0>, 'leaves_2', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x2983b96c0>, 'leaves_0', <flax.core.scope.Scope object at 0x2982d7d60>)
(<flax.core.scope.Scope object at 0x2983bb550>, 'leaves_1', <flax.core.scope.Scope object at 0x2982d7d60>)
(<flax.core.scope.Scope object at 0x2983bb400>, 'leaves_2', <flax.core.scope.Scope object at 0x2982d7d60>)

从上述代码的输出,我们可以看出,scope是每个模块在它被调用时创建的临时变量(init/apply两次调用的scope对象不同),它的parent指向调用它的父模块对应的scope,name则为它在父模块中的名字。画成一棵树的样子,大概就是:

 Compound model (parent=None, name=None)
    |
    |-----< parent=Compond
            name=leaves_0>
    |       |
    |       v
    |     Leaf
    |
    |-----< parent=Compond
            name=leaves_0>
    |       |
    |       v
    |     Leaf
    |
    |-----< parent=Compond
            name=leaves_2>
            |
            v
          Leaf

实际在运行时,当前存在的scope总是连成一条链表,对应于当前活跃的模块调用栈。

PyTorch Module 与 flax Module 之间的转换

以上内容可以总结为一张图:

想把PyTorch的Module转成flax的代码,需要注意以下五点:

  1. 使用类型注释来声明构造参数,不得定义__init__函数
  2. 如果模块本身有参数(除了子模块的参数之外),参数初始化与后续查找均用self.param函数
  3. __call__函数里同时定义参数初始化与计算逻辑,并用@flax.linen.compact修饰(也可以把子模块的初始化放在setup函数里)
  4. 模型初始化时需要显式声明随机种子与样例输入
  5. 后续每次调用时需要输入参数params

总结

PyTorch框架的核心内容发生在模块对象被创建时的super().__init__()函数,大部分学习Python的人都能理解初始化时的继承逻辑;而flax框架的核心工作发生在模块类的继承时,大部分学习Python的人从来没见过继承一个类的时候还能发生这么多事情,所以导致flax看起来就像是黑魔法。但是本质上来说,flax使用的技术没有超过Python的范畴,因此我认为flax依然是一个设计良好的Python库。

有了以上解读内容,结合上一篇博客一文打通PyTorch与JAX,相信熟悉使用PyTorch的朋友可以较为简单地上手jax/flax生态里面的代码了。

发布于 2023-10-11 11:53・IP 属地北京
欢迎参与讨论

6 条评论
默认
最新
A ONE

可以出一期jax怎么debug嘛[爱]

01-28 · IP 属地美国
lym

leaves_1

2023-12-27 · IP 属地广东
sssssss

太牛了大佬,清晰易懂,爱看

2023-12-05 · IP 属地北京