一文打通PyTorch与flax (JAX的神经网络库)
上一篇博客一文打通PyTorch与JAX介绍了PyTorch与JAX的底层API之间的关系,打通了PyTorch与JAX的底层联系。然而这还不足以流畅地使用JAX。本文继续深入分析,JAX的神经网络库flax与PyTorch的nn模块之间的关系。
回顾PyTorch nn模块的使用方式
在PyTorch中,我们新建一个模块的方式很简单,只需要继承nn.Module
,并写好__init__
和forward
两个函数:
from torch import nn
class MyMod(nn.Module):
def __init__(self, arg_model):
super().__init__()
self.params = init_params(arg_model)
def forward(self, x):
y = forward_func(self.params, x)
这里的重点在于super().__init__()
函数,它隐藏了非常多的nn.Module
细节。我们可以初始化一个空的模块mod = nn.Module()
,通过dir(mod)
看到nn.Module
非常多的内部状态,包括但不限于_buffers/_parameters/_forward_hooks
等等。
flax Module的初始化方式
JAX的函数式编程最害怕隐藏的状态,因此基本不允许一个没有参数的函数返回一个包含状态的对象。所以,直接初始化flax中的神经网络模块基础类flax.linen.Module
会报错,它的正确初始化方式十分复杂。
在flax中,要新建一个模块,也是继承Module
类,但是我们不能写__init__
函数。继承的方式是通过类型标注增加新的配置参数:
import flax
import typing
class MyMod(flax.linen.Module):
arg_model: typing.Any
flax
的主要工作发生在继承时,它为flax.linen.Module
注册了__init_subclass__
函数,当我们继承flax.linen.Module
得到一个新的类时,这个类已经被flax
改得面目全非了。
最重要的一点修改,就是所有继承的类被自动改成了dataclass
类型:
import dataclasses
dataclasses.is_dataclass(MyMod) # True
其初始化函数MyMod.__init__
,则被改成了接收三个参数的函数:
MyMod.__init__(
self,
arg_model: typing.Any,
parent: Union[Type[flax.linen.module.Module], flax.core.scope.Scope, Type[flax.linen.module._Sentinel], NoneType] = <flax.linen.module._Sentinel object at 0x11a166d40>,
name: Optional[str] = None,
) -> None
其中,第一个参数arg_model
是我们通过类型标注生成的,剩下两个参数是自动生成的。
这里尤其要注意,dataclass
非常依赖类型标注。我们可以不提供参数的默认值,但必须提供类型标注。下面这个例子中,没有提供类型标注,生成的初始化函数就不对了。
class MyMod(flax.linen.Module):
arg_model = 2
MyMod(arg_model=3) # TypeError: MyMod.__init__() got an unexpected keyword argument 'arg_model'
如果不知道什么类型或者不想标注类型,可以像上面的例子一样,直接标注为任意类型Any
。这些类型标注只是用来帮助理解代码与生成函数,并不会对实际运行产生限制(标注为int
类型的参数可以初始化为任意类型的值)。
注意,flax
模块的初始化只是简单地记住了一些构造参数,我们需要在另外的函数里面使用这些参数。
flax Module的计算函数
flax
的计算函数比较特别,参数初始化与计算在同一个函数里面:
import flax
import typing
import jax.numpy as jnp
import jax
class MyMod(flax.linen.Module):
arg_model: typing.Any
@flax.linen.compact
def __call__(self, inputs):
params = self.param(
'whatever',
flax.linen.initializers.uniform(),
(jnp.shape(inputs)[-1], self.arg_model),
jnp.float32,
)
y = inputs @ params
return y
也就是说,在初始化参数时,使用定制的self.param
函数得到参数,然后用这个参数去进行计算,得到输出。注意,这里的参数"whatever"
可以是任意名字,但是不能是"params"
。
有了这个函数,上文一文打通PyTorch与JAX提到的init/apply
等函数都可以用了。
mod = MyMod(arg_model=233)
key1, key2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(key1, (64, 10))
output1, params = mod.init_with_output(key2, x) # Initialization call
output2 = mod.apply(params, x) # forward call
print(((output1 - output2) ** 2).sum()) # they are the same
在flax中使用现有模块
下面展示一个示例,如何在flax中构建一个模块并使用它:
import jax.numpy as jnp
from flax import linen as nn
class ConvLayerNormLeakyReLU(nn.Module):
output_channels: int
kernel_size: int
leaky_slope: float
@nn.compact
def __call__(self, x):
# Convolution layer
x = nn.Conv(
features=self.output_channels,
kernel_size=self.kernel_size,
kernel_init=nn.initializers.lecun_normal()
)(x)
# Layer normalization
x = nn.LayerNorm()(x)
# Leaky ReLU activation
return nn.leaky_relu(x, negative_slope=self.leaky_slope)
input_data = jnp.ones((1, 32, 32, 3)) # Batch of 1, 32x32 image with 3 channels
model = ConvLayerNormLeakyReLU(output_channels=64, kernel_size=(3, 3), leaky_slope=0.01)
output_data, params = model.init_with_output(key, input_data)
output_data = model.apply(params, input_data)
jax/flax
使用的是类似TensorFlow形式的API,模块只包含输出的参数(比如卷积的输出channel数),不包含输入的参数,输入的参数由输入数据直接计算。
还有另一种更加类似PyTorch的使用方式:
import jax.numpy as jnp
from flax import linen as nn
class ConvLayerNormLeakyReLU(nn.Module):
output_channels: int
kernel_size: int
leaky_slope: float
def setup(self):
# Define the convolution and layer normalization sub-modules in setup
self.conv = nn.Conv(
features=self.output_channels,
kernel_size=self.kernel_size,
kernel_init=nn.initializers.lecun_normal()
)
self.layer_norm = nn.LayerNorm()
@nn.compact
def __call__(self, x):
x = self.conv(x) # Apply convolution
x = self.layer_norm(x) # Apply layer normalization
return nn.leaky_relu(x, negative_slope=self.leaky_slope) # Apply leaky ReLU
input_data = jnp.ones((1, 32, 32, 3)) # Batch of 1, 32x32 image with 3 channels
model = ConvLayerNormLeakyReLU(output_channels=64, kernel_size=(3, 3), leaky_slope=0.01)
output_data, params = model.init_with_output(key, input_data)
output_data = model.apply(params, input_data)
以上两种方式基本上是等价的。
有几点需要澄清的地方:
- 不使用
setup
函数时,nn.Conv
生成的模块实际存储在self.Conv_0
里面,Conv_0
是自动生成的名字;使用了setup
函数时,我们手动指定了子模块名conv
。不论是否使用setup
函数,直接输出model.conv
或者model.Conv_0
都会报错,flax
禁止直接访问这些子模块,它们只能在apply/init
等函数调用时被使用。 - 在我们调用
output_data = model.apply(params, input_data)
时,flax
会递归地调用子模块的apply
函数,并且为每个函数的调用维护一个self.scope
,里面记录了每个模块的一些临时状态,用以在创建新的模块时自动记录父模块及子模块的名字,这就是上文中__init__
函数剩余两个参数的意义。
为了说明scope
的作用,我们可以构造嵌套的模块并在模块中打印scope
信息:
import jax
import jax.numpy as jnp
from flax import linen as nn
class Leaf(nn.Module):
slope: float
@nn.compact
def __call__(self, x):
print((self.scope, self.scope.name, self.scope.parent))
return x * self.slope
class Compound(nn.Module):
n: int
def setup(self):
# Create a list of 'Leaf' modules
self.leaves = [Leaf(2.0) for _ in range(self.n)]
def __call__(self, x):
for leaf in self.leaves:
x = leaf(x)
return x
# Example usage:
model = Compound(n=3)
params = model.init(jax.random.PRNGKey(0), jnp.array(1.0))
# Applying Leaf with slope=2.0 for three times, so result should be 8.0 (doubling thrice)
result = model.apply(params, jnp.array(1.0))
print(result)
输出结果:
(<flax.core.scope.Scope object at 0x2982d7d30>, 'leaves_0', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x298377a90>, 'leaves_1', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x298377ca0>, 'leaves_2', <flax.core.scope.Scope object at 0x2983bbb80>)
(<flax.core.scope.Scope object at 0x2983b96c0>, 'leaves_0', <flax.core.scope.Scope object at 0x2982d7d60>)
(<flax.core.scope.Scope object at 0x2983bb550>, 'leaves_1', <flax.core.scope.Scope object at 0x2982d7d60>)
(<flax.core.scope.Scope object at 0x2983bb400>, 'leaves_2', <flax.core.scope.Scope object at 0x2982d7d60>)
从上述代码的输出,我们可以看出,scope
是每个模块在它被调用时创建的临时变量(init/apply
两次调用的scope对象不同),它的parent
指向调用它的父模块对应的scope,name
则为它在父模块中的名字。画成一棵树的样子,大概就是:
Compound model (parent=None, name=None)
|
|-----< parent=Compond
name=leaves_0>
| |
| v
| Leaf
|
|-----< parent=Compond
name=leaves_0>
| |
| v
| Leaf
|
|-----< parent=Compond
name=leaves_2>
|
v
Leaf
实际在运行时,当前存在的scope
总是连成一条链表,对应于当前活跃的模块调用栈。
PyTorch Module 与 flax Module 之间的转换
以上内容可以总结为一张图:

想把PyTorch的Module
转成flax
的代码,需要注意以下五点:
- 使用类型注释来声明构造参数,不得定义
__init__
函数 - 如果模块本身有参数(除了子模块的参数之外),参数初始化与后续查找均用
self.param
函数 - 在
__call__
函数里同时定义参数初始化与计算逻辑,并用@flax.linen.compact
修饰(也可以把子模块的初始化放在setup
函数里) - 模型初始化时需要显式声明随机种子与样例输入
- 后续每次调用时需要输入参数
params
总结
PyTorch框架的核心内容发生在模块对象被创建时的super().__init__()
函数,大部分学习Python的人都能理解初始化时的继承逻辑;而flax
框架的核心工作发生在模块类的继承时,大部分学习Python的人从来没见过继承一个类的时候还能发生这么多事情,所以导致flax
看起来就像是黑魔法。但是本质上来说,flax
使用的技术没有超过Python的范畴,因此我认为flax
依然是一个设计良好的Python库。
有了以上解读内容,结合上一篇博客一文打通PyTorch与JAX,相信熟悉使用PyTorch的朋友可以较为简单地上手jax/flax
生态里面的代码了。
可以出一期jax怎么debug嘛![[爱]](https://pic1.zhimg.com/v2-0942128ebfe78f000e84339fbb745611.png)
leaves_1
太牛了大佬,清晰易懂,爱看
加入TODO list
我也想了解Pallas![[爱]](https://pic1.zhimg.com/v2-0942128ebfe78f000e84339fbb745611.png)