GDL - Steerable CNNs
During the lectures, you have learnt that the symmetries of a machine learning task can be modelled with groups. In the previous tutorial, you have also studied the framework of Group-Convolutional Neural Networks (GCNNs), which describes a neural architecture design equivariant to general groups.
The feature maps of a GCNN are functions over the elements of the group. A naive implementation of group convolution requires computing and storing a response for each group element. For this reason, the GCNN framework is not particularly convenient for implementing networks equivariant to groups with infinitely many elements.
Steerable CNNs are a more general framework which solves this issue. The key idea is that, instead of storing the value of a feature map on each group element, the model stores the Fourier transform of this feature map, up to a finite number of frequencies.
In this tutorial, we will first introduce some representation theory and Fourier theory (non-commutative harmonic analysis) and, then, explore how this idea is used in practice to implement Steerable CNNs.
Prerequisite Knowledge
Throughout this tutorial, we will assume you are already familiar with some concepts of group theory, such as groups, group actions (in particular on functions), semi-direct product and order of a group, as well as basic linear algebra.
We start by importing the necessary packages. You can run the following command to install all the requirements:
> pip install torch torchvision numpy matplotlib git+https://github.com/AMLab-Amsterdam/lie_learn escnn scipy
[1]:
import torch
import numpy as np
import scipy
import os
np.set_printoptions(precision=3, suppress=True, linewidth=10000, threshold=100000)
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
# If the fonts in the plots are incorrectly rendered, comment out the next two lines
from matplotlib_inline.backend_inline import set_matplotlib_formats  # replaces the deprecated IPython.display version
set_matplotlib_formats('svg', 'pdf') # For export
matplotlib.rcParams['lines.linewidth'] = 2.0
import urllib.request
from urllib.error import HTTPError
CHECKPOINT_PATH = "../../saved_models/DL2/GDL"
[2]:
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
# Files to download
pretrained_files = [
    "steerable_c4-pretrained.ckpt",
    "steerable_so2-pretrained.ckpt",
    "steerable_c4-accuracies.npy",
    "steerable_so2-accuracies.npy",
]
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/DL2/GDL/"
# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please contact the author with the full output including the following error:\n", e)
1. Representation Theory and Harmonic Analysis of Compact Groups
We will make use of the escnn library throughout this tutorial. You can also find its documentation here.
[3]:
try:
    from escnn.group import *
except ModuleNotFoundError: # Google Colab does not have escnn installed by default. Hence, we do it here if necessary
    !pip install --quiet git+https://github.com/AMLab-Amsterdam/lie_learn escnn
    from escnn.group import *
First, let's create a group. We will use the cyclic group $G = C_8$ as an example. In escnn, groups are instances of the abstract class escnn.group.Group, which provides some useful functionalities. We instantiate groups via a factory method. To build the cyclic group of order $8$, we use:
[4]:
G = cyclic_group(N=8)
# We can verify that the order of this group is 8:
G.order()
[4]:
8
A group is a collection of group elements together with a binary operation to combine them. This is implemented in the class escnn.group.GroupElement. We can access the identity element $e \in G$ as
[5]:
G.identity
[5]:
0[2pi/8]
or sample a random element as
[6]:
G.sample()
[6]:
1[2pi/8]
Group elements can be combined via the binary operator @; we can also take the inverse of an element using ~:
[7]:
a = G.sample()
b = G.sample()
print(a)
print(b)
print(a @ b)
print(~a)
6[2pi/8]
1[2pi/8]
7[2pi/8]
2[2pi/8]
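To make the group law concrete, here is a minimal sketch that does not depend on any library. It models $C_8$ as the integers modulo 8, where $i$ stands for the rotation by $i \cdot \frac{2\pi}{8}$. The helpers compose and inverse are hypothetical names that only mimic the @ and ~ operators above; this is not how escnn implements them.

```python
N = 8  # order of the cyclic group C_8

def compose(a: int, b: int) -> int:
    """Group law of C_N: rotations by a and b steps combine into a rotation by a + b steps (mod N)."""
    return (a + b) % N

def inverse(a: int) -> int:
    """Inverse element: the unique b such that compose(a, b) is the identity 0."""
    return (-a) % N

identity = 0
print(compose(6, 1))  # 7, like 6[2pi/8] @ 1[2pi/8] above
print(inverse(6))     # 2, like ~6[2pi/8] above

# group axioms: the identity is neutral and every element has an inverse
assert all(compose(g, identity) == g for g in range(N))
assert all(compose(g, inverse(g)) == identity for g in range(N))
```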
Representation theory is fundamental both to Steerable CNNs and to the construction of a Fourier theory over groups. In this first section, we will introduce the essential concepts.
1.1 Group Representation
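A linear group representation $\rho$ of a compact group $G$ on a vector space (called the representation space) $\mathbb{R}^d$ is a group homomorphism from $G$ to the general linear group $GL(\mathbb{R}^d)$. In other words, it is a map $\rho: G \to \mathbb{R}^{d \times d}$ such that:
$$\rho(g_1 g_2) = \rho(g_1)\rho(g_2) \quad \forall g_1, g_2 \in G.$$
In other words, $\rho(g)$ is a $d \times d$ invertible matrix. We refer to $d$ as the size of the representation.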
Example: the Trivial Representation
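The simplest example of a group representation is the trivial representation, which maps every element to $1 \in \mathbb{R}$, i.e. $\rho: g \mapsto 1$. One can verify that it satisfies the condition above. We construct this representation as follows: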
[8]:
rho = G.trivial_representation
rho is an instance of escnn.group.Representation. This class provides some functionalities to work with group representations. We can also use it as a callable function to compute the representation of a group element; this returns a square matrix as a numpy.array. Let's verify that the trivial representation does indeed satisfy the condition above:
[9]:
g1 = G.sample()
g2 = G.sample()
print(rho(g1) @ rho(g2))
print(rho(g1 @ g2))
[[1.]]
[[1.]]
Note that the trivial representation has size $1$:
[10]:
rho.size
[10]:
1
Example: rotations
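Another common example of group representations is given by 2D rotations. Let $SO(2)$ be the group of all planar rotations; we identify each rotation with an angle $\theta \in [0, 2\pi)$. Then, the standard representation of planar rotations as $2 \times 2$ rotation matrices is a representation of $SO(2)$:
$$\rho: r_\theta \mapsto \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix}$$
where $r_\theta \in SO(2)$ is a counter-clockwise rotation by $\theta$. Let's build this group and verify that this is a representation: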
[11]:
G = so2_group()
rho = G.standard_representation()
g1 = G.sample()
g2 = G.sample()
print(f'g1={g1}, g2={g2}, g1 * g2 = {g1 @ g2}')
print()
print('rho(g1) @ rho(g2)')
print(rho(g1) @ rho(g2))
print()
print('rho(g1 * g2)')
print(rho(g1 @ g2))
g1=4.83985258221817, g2=4.721165128388411, g1 * g2 = 3.277832403426995
rho(g1) @ rho(g2)
[[-0.991 0.136]
[-0.136 -0.991]]
rho(g1 * g2)
[[-0.991 0.136]
[-0.136 -0.991]]
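The sketch below is a sanity check that does not depend on escnn. It builds the standard representation of $SO(2)$ directly with NumPy (the helper rot is hypothetical) and checks the homomorphism property numerically, together with the two properties asked about in Question 1.

```python
import numpy as np

def rot(theta: float) -> np.ndarray:
    """Standard representation of SO(2): the 2x2 counter-clockwise rotation by theta."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

rng = np.random.default_rng(0)
t1, t2 = rng.uniform(0, 2 * np.pi, size=2)
t12 = (t1 + t2) % (2 * np.pi)  # composing rotations adds their angles (mod 2pi)

print(np.allclose(rot(t1) @ rot(t2), rot(t12)))                        # True: homomorphism
print(np.allclose(rot(0.0), np.eye(2)))                                # True: rho(e) = I
print(np.allclose(rot((-t1) % (2 * np.pi)), np.linalg.inv(rot(t1))))   # True: rho(g^-1) = rho(g)^-1
```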
QUESTION 1
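Show that any representation $\rho: G \to \mathbb{R}^{d \times d}$ also satisfies the following two properties:
- let $e \in G$ be the identity element. Then, $\rho(e)$ is the identity matrix of size $d$.
- let $g \in G$ and $g^{-1}$ be its inverse (i.e. $g \cdot g^{-1} = e$). Then, $\rho(g^{-1}) = \rho(g)^{-1}$.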
ANSWER 1
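First question. First, note that for any $g \in G$:
$$\rho(g) = \rho(g \cdot e) = \rho(g)\rho(e)$$
Because $\rho(g)$ is invertible, we can left-multiply both sides by $\rho(g)^{-1}$ to find that $\rho(e)$ is the identity matrix.
Second question. Note that
$$\rho(e) = \rho(g \cdot g^{-1}) = \rho(g)\rho(g^{-1})$$
Using the fact that $\rho(e)$ is the identity, left-multiplying by $\rho(g)^{-1}$ gives $\rho(g^{-1}) = \rho(g)^{-1}$.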
Direct Sum
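We can combine representations to build a larger representation via the direct sum.
Given representations $\rho_1: G \to \mathbb{R}^{d_1 \times d_1}$ and $\rho_2: G \to \mathbb{R}^{d_2 \times d_2}$, their direct sum $\rho_1 \oplus \rho_2: G \to \mathbb{R}^{(d_1 + d_2) \times (d_1 + d_2)}$ is defined as
$$(\rho_1 \oplus \rho_2)(g) = \begin{bmatrix} \rho_1(g) & 0 \\ 0 & \rho_2(g) \end{bmatrix}$$
Its action is therefore given by the independent actions of $\rho_1$ and $\rho_2$ on the orthogonal subspaces $\mathbb{R}^{d_1}$ and $\mathbb{R}^{d_2}$ of $\mathbb{R}^{d_1 + d_2}$.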
Let’s see an example:
[12]:
rho_sum = rho + rho
g = G.sample()
print(rho(g))
print()
print(rho_sum(g))
[[-0.943 -0.332]
[ 0.332 -0.943]]
[[-0.943 -0.332 0. 0. ]
[ 0.332 -0.943 0. 0. ]
[ 0. 0. -0.943 -0.332]
[ 0. 0. 0.332 -0.943]]
Note that the direct sum of two representations has size equal to the sum of their sizes:
[13]:
rho.size, rho_sum.size
[13]:
(2, 4)
We can combine arbitrarily many representations in this way, e.g.
[14]:
rho_sum = rho + rho + rho + rho
# or, more simply:
rho_sum = directsum([rho, rho, rho, rho])
rho_sum.size
[14]:
8
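The direct sum is simply a block-diagonal stacking, so it can be written in a few lines of NumPy. The helper block_direct_sum below is hypothetical; escnn's directsum works on Representation objects instead. The sketch only shows that a block-diagonal stack of representations is again a representation.

```python
import numpy as np

def rot(theta: float) -> np.ndarray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])

def block_direct_sum(*blocks: np.ndarray) -> np.ndarray:
    """Direct sum of square matrices: place them along the diagonal of a larger zero matrix."""
    size = sum(b.shape[0] for b in blocks)
    out = np.zeros((size, size))
    i = 0
    for b in blocks:
        d = b.shape[0]
        out[i:i + d, i:i + d] = b
        i += d
    return out

def rho_sum(theta: float) -> np.ndarray:
    # four copies of the standard representation of SO(2), as in rho + rho + rho + rho above
    copies = [rot(theta)] * 4
    return block_direct_sum(*copies)

print(rho_sum(0.3).shape)                                      # (8, 8)
print(np.allclose(rho_sum(0.3) @ rho_sum(1.1), rho_sum(1.4)))  # True: still a representation
```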
The Regular Representation
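Another important representation is the regular representation. The regular representation describes the action of a group $G$ on the vector space of functions over the group $G$. Assume for the moment that the group $G$ is finite, i.e. $|G| < \infty$.
The set of functions over $G$ is equivalent to the vector space $\mathbb{R}^{|G|}$. We can indeed interpret a vector $f \in \mathbb{R}^{|G|}$ as a function over $G$, where the $i$-th entry of $f$ is interpreted as the value of the function on the $i$-th element $g_i \in G$.
The regular representation of $G$ is a $|G|$-dimensional representation. Given a function $f: G \to \mathbb{R}$, the regular representation of $\tilde g \in G$ acts on $f$ as
$$[\tilde g.f](h) := f(\tilde g^{-1} h) \quad \forall h \in G$$
The new function $\tilde g.f$ is still a function over $G$ and belongs to the same vector space as $f$.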
QUESTION 2
Show that the space of functions over $G$ is a vector space. To do so, show that functions satisfy the properties of a vector space.
ANSWER 2
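Let $f_1, f_2, f_3: G \to \mathbb{R}$ be three functions and $\alpha, \beta \in \mathbb{R}$ scalars. The point-wise sum of two functions is the function $f_1 + f_2: G \to \mathbb{R}$ defined as $(f_1 + f_2)(g) = f_1(g) + f_2(g)$.
The scalar multiplication is also defined pointwise as $(\alpha \cdot f_1)(g) = \alpha f_1(g)$.
We now verify the required properties of a vector space. Each of them holds because the corresponding property of real numbers holds at every $g \in G$:
- associativity: $((f_1 + f_2) + f_3)(g) = f_1(g) + f_2(g) + f_3(g) = (f_1 + (f_2 + f_3))(g)$
- commutativity: $(f_1 + f_2)(g) = f_1(g) + f_2(g) = f_2(g) + f_1(g) = (f_2 + f_1)(g)$
- identity: define the function $O: G \to \mathbb{R},\ g \mapsto 0$; then $(f_1 + O)(g) = f_1(g)$
- inverse: define $-f_1 := (-1) \cdot f_1$; then $(f_1 + (-f_1))(g) = f_1(g) - f_1(g) = 0 = O(g)$
- compatibility: $(\alpha \cdot (\beta \cdot f_1))(g) = \alpha\beta f_1(g) = ((\alpha\beta) \cdot f_1)(g)$
- identity (multiplication): $(1 \cdot f_1)(g) = f_1(g)$
- distributivity (vector): $(\alpha \cdot (f_1 + f_2))(g) = \alpha f_1(g) + \alpha f_2(g) = (\alpha \cdot f_1 + \alpha \cdot f_2)(g)$
- distributivity (scalar): $((\alpha + \beta) \cdot f_1)(g) = \alpha f_1(g) + \beta f_1(g) = (\alpha \cdot f_1 + \beta \cdot f_1)(g)$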
For finite groups, we can generate this representation. We assume that the $i$-th entry is associated with the element of $G = C_8$ corresponding to a rotation by $i \cdot \frac{2\pi}{8}$.
[15]:
G = cyclic_group(8)
rho = G.regular_representation
[16]:
# note that the size of the representation is equal to the group's order |G|
rho.size
[16]:
8
The identity element maps a function to itself, so the entries are not permuted:
[17]:
rho(G.identity)
[17]:
array([[ 1., 0., -0., 0., -0., -0., 0., -0.],
[ 0., 1., 0., -0., -0., -0., -0., -0.],
[-0., 0., 1., -0., -0., 0., -0., 0.],
[ 0., -0., -0., 1., 0., -0., -0., -0.],
[-0., -0., -0., 0., 1., 0., -0., -0.],
[-0., -0., 0., -0., 0., 1., -0., 0.],
[ 0., -0., -0., -0., -0., -0., 1., -0.],
[-0., -0., 0., -0., -0., 0., -0., 1.]])
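The regular representation of the rotation by $1 \cdot \frac{2\pi}{8}$ just cyclically shifts each entry to the next position, since $r_{1 \cdot 2\pi/8}^{-1} r_{i \cdot 2\pi/8} = r_{(i-1) \cdot 2\pi/8}$: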
[18]:
rho(G.element(1))
[18]:
array([[ 0., -0., 0., -0., -0., 0., -0., 1.],
[ 1., 0., -0., -0., -0., -0., 0., -0.],
[-0., 1., 0., -0., 0., -0., -0., -0.],
[-0., 0., 1., 0., -0., 0., -0., 0.],
[-0., -0., -0., 1., 0., -0., 0., -0.],
[-0., 0., -0., 0., 1., 0., 0., -0.],
[-0., -0., -0., -0., 0., 1., 0., 0.],
[-0., 0., -0., -0., 0., -0., 1., 0.]])
Let's see an example of the action on a function. We consider a function which is zero on all group elements apart from the identity ($i = 0$).
[19]:
f = np.zeros(8)
f[0] = 1
f
[19]:
array([1., 0., 0., 0., 0., 0., 0., 0.])
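Observe that $\rho(e)f = f$. For $g = r_{j \cdot 2\pi/8}$, on the other hand, $\rho(g)f$ is the function which is zero everywhere apart from the $j$-th element, i.e. $g$ itself: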
[20]:
rho(G.identity) @ f
[20]:
array([ 1., 0., -0., 0., -0., -0., 0., -0.])
[21]:
rho(G.element(1)) @ f
[21]:
array([ 0., 1., -0., -0., -0., -0., -0., -0.])
[22]:
rho(G.element(6)) @ f
[22]:
array([ 0., -0., 0., -0., -0., -0., 1., -0.])
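For $C_N$, the regular representation can also be built by hand as a permutation matrix. The sketch below assumes the same ordering as above, where entry $i$ corresponds to the rotation by $i \cdot \frac{2\pi}{N}$. regular_rep is a hypothetical helper and not an escnn function. The sketch reproduces the shifted delta functions computed above.

```python
import numpy as np

N = 8  # C_8; entry i of a vector f is the value of f on the rotation by i * 2pi/8

def regular_rep(g: int, n: int = N) -> np.ndarray:
    """Regular representation of C_n as an n x n permutation matrix.

    Implements [g.f](h) = f(g^{-1} h), i.e. (rho(g) f)[h] = f[(h - g) mod n].
    """
    P = np.zeros((n, n))
    for h in range(n):
        P[h, (h - g) % n] = 1.0
    return P

f = np.zeros(N)
f[0] = 1.0  # delta function on the identity element

print(regular_rep(1) @ f)  # [0. 1. 0. 0. 0. 0. 0. 0.]
print(regular_rep(6) @ f)  # [0. 0. 0. 0. 0. 0. 1. 0.]
# homomorphism property: rho(g1) rho(g2) = rho(g1 g2)
print(np.allclose(regular_rep(3) @ regular_rep(7), regular_rep((3 + 7) % N)))  # True
```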
QUESTION 3
Prove the result above.
ANSWER 3
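Let's call $\delta_g: G \to \mathbb{R}$ the function defined as
$$\delta_g(h) = \begin{cases} 1 & \text{if } h = g \\ 0 & \text{if } h \neq g \end{cases}$$
which is zero everywhere apart from $g \in G$, where it is $1$. The function $f$ above is $\delta_e$ by definition.
We now want to show that $\rho(g)\delta_e = \delta_g$. By definition of the regular representation, $[g.\delta_e](h) = \delta_e(g^{-1}h)$. This is $1$ if and only if $g^{-1}h = e$, i.e. $h = g$, and $0$ otherwise. Hence $[g.\delta_e] = \delta_g$.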
Equivalent Representations
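Two representations $\rho$ and $\rho'$ of a group $G$ on the same vector space $\mathbb{R}^d$ are called equivalent (or isomorphic) if and only if they are related by a change of basis $Q \in \mathbb{R}^{d \times d}$, i.e.
$$\forall g \in G: \quad \rho'(g) = Q\rho(g)Q^{-1}.$$
Equivalent representations behave similarly, since their composition is basis-independent:
$$\rho'(g_1)\rho'(g_2) = Q\rho(g_1)Q^{-1}Q\rho(g_2)Q^{-1} = Q\rho(g_1)\rho(g_2)Q^{-1}$$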
Direct sums and change of basis matrices provide a way to combine representations into larger and more complex ones. In the next example, we concatenate two trivial representations and two regular representations and apply a random change of basis $Q$. The final representation is formally defined as:
$$\rho(g) = Q\,(\rho_{\text{trivial}} \oplus \rho_{\text{regular}} \oplus \rho_{\text{regular}} \oplus \rho_{\text{trivial}})(g)\,Q^{-1}$$
[23]:
d = G.trivial_representation.size * 2 + G.regular_representation.size * 2
Q = np.random.randn(d, d)
rho = directsum(
[G.trivial_representation, G.regular_representation, G.regular_representation, G.trivial_representation],
change_of_basis=Q
)
[24]:
rho.size
[24]:
18
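A change of basis preserves the homomorphism property, as the identity $\rho'(g_1)\rho'(g_2) = Q\rho(g_1)\rho(g_2)Q^{-1}$ shows. The NumPy sketch below mirrors the escnn example above for $C_8$ using hand-written helpers; the names regular_rep, trivial_rep and block_diag are hypothetical. It then checks the property numerically.

```python
import numpy as np

N = 8  # C_8

def regular_rep(g: int) -> np.ndarray:
    return np.roll(np.eye(N), g, axis=0)  # permutation matrix with (rho(g) f)[h] = f[(h - g) mod N]

def trivial_rep(g: int) -> np.ndarray:
    return np.ones((1, 1))

def block_diag(*blocks: np.ndarray) -> np.ndarray:
    size = sum(b.shape[0] for b in blocks)
    out = np.zeros((size, size))
    i = 0
    for b in blocks:
        out[i:i + b.shape[0], i:i + b.shape[0]] = b
        i += b.shape[0]
    return out

rng = np.random.default_rng(0)
d = 1 + N + N + 1                # trivial + regular + regular + trivial = 18
Q = rng.standard_normal((d, d))  # random change of basis (invertible with probability 1)
Q_inv = np.linalg.inv(Q)

def rho(g: int) -> np.ndarray:
    # Q (trivial + regular + regular + trivial)(g) Q^{-1}
    blocks = block_diag(trivial_rep(g), regular_rep(g), regular_rep(g), trivial_rep(g))
    return Q @ blocks @ Q_inv

print(rho(0).shape)                                    # (18, 18)
print(np.allclose(rho(3) @ rho(6), rho((3 + 6) % N)))  # True
```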
Irreducible Representations (or Irreps)
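Under minor conditions, any representation can be decomposed in this way. That is, any representation $\rho$ of a compact group $G$ can be written as a direct sum of a number of smaller representations, up to a change of basis. These "smallest" representations are called Irreducible Representations (or irreps) and cannot be decomposed any further.
The set of irreducible representations of a group $G$ is generally denoted as $\hat G$. We will often use the notation $\hat G = \{\rho_j\}_j$ to index this set.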
We can access the irreps of a group via the irrep() method. The trivial representation is always an irreducible representation. For $G = C_8$, we access it with the index $j = 0$:
[25]:
rho_0 = G.irrep(0)
print(rho_0 == G.trivial_representation)
rho_0(G.sample())
True
[25]:
array([[1.]])
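The next irrep, $j = 1$, gives the representation of $r_{i \cdot 2\pi/8}$ as the $2 \times 2$ rotation matrix by $\theta = i \cdot \frac{2\pi}{8}$: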
[26]:
rho = G.irrep(1)
g = G.sample()
print(g)
print()
print(rho(g))
1[2pi/8]
[[ 0.707 -0.707]
[ 0.707 0.707]]
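Irreducible representations provide the building blocks to construct any representation $\rho$ via direct sums and change of basis, i.e.:
$$\rho = Q\left(\bigoplus_{j \in I} \rho_j\right)Q^{-1}$$
where $I$ is an index set (possibly with repetitions) over $\hat G$.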
Internally, any escnn.group.Representation is indeed implemented as a list of irreps (representing the index set $I$) and a change of basis $Q$. An irrep is identified by a tuple id.
Let's see an example. Let's take the regular representation of $C_8$ and check its decomposition into irreps:
[27]:
rho = G.regular_representation
rho.irreps
[27]:
[(0,), (1,), (2,), (3,), (4,)]
[28]:
rho.change_of_basis
[28]:
array([[ 0.354, 0.5 , 0. , 0.5 , 0. , 0.5 , 0. , 0.354],
[ 0.354, 0.354, 0.354, 0. , 0.5 , -0.354, 0.354, -0.354],
[ 0.354, 0. , 0.5 , -0.5 , 0. , -0. , -0.5 , 0.354],
[ 0.354, -0.354, 0.354, -0. , -0.5 , 0.354, 0.354, -0.354],
[ 0.354, -0.5 , 0. , 0.5 , -0. , -0.5 , 0. , 0.354],
[ 0.354, -0.354, -0.354, 0. , 0.5 , 0.354, -0.354, -0.354],
[ 0.354, -0. , -0.5 , -0.5 , 0. , 0. , 0.5 , 0.354],
[ 0.354, 0.354, -0.354, -0. , -0.5 , -0.354, -0.354, -0.354]])
[29]:
# let's access the second irrep
rho_id = rho.irreps[1]
rho_1 = G.irrep(*rho_id)
# we verify it is the irrep j=1 we described before
rho_1(g)
[29]:
array([[ 0.707, -0.707],
[ 0.707, 0.707]])
Finally, let's verify that this direct sum and this change of basis indeed yield the regular representation:
[30]:
# evaluate all the irreps in rho.irreps:
irreps = [
G.irrep(*irrep)(g) for irrep in rho.irreps
]
# build the direct sum
direct_sum = np.asarray(scipy.sparse.block_diag(irreps, format='csc').todense())
print('Regular representation of', g)
print(rho(g))
print()
print('Direct sum of the irreps:')
print(direct_sum)
print()
print('Apply the change of basis on the direct sum of the irreps:')
print(rho.change_of_basis @ direct_sum @ rho.change_of_basis_inv)
print()
print('Are the two representations equal?', np.allclose(rho(g), rho.change_of_basis @ direct_sum @ rho.change_of_basis_inv))
Regular representation of 1[2pi/8]
[[ 0. -0. 0. -0. -0. 0. -0. 1.]
[ 1. 0. -0. -0. -0. -0. 0. -0.]
[-0. 1. 0. -0. 0. -0. -0. -0.]
[-0. 0. 1. 0. -0. 0. -0. 0.]
[-0. -0. -0. 1. 0. -0. 0. -0.]
[-0. 0. -0. 0. 1. 0. 0. -0.]
[-0. -0. -0. -0. 0. 1. 0. 0.]
[-0. 0. -0. -0. 0. -0. 1. 0.]]
Direct sum of the irreps:
[[ 1. 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0.707 -0.707 0. 0. 0. 0. 0. ]
[ 0. 0.707 0.707 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. -1. 0. 0. 0. ]
[ 0. 0. 0. 1. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. -0.707 -0.707 0. ]
[ 0. 0. 0. 0. 0. 0.707 -0.707 0. ]
[ 0. 0. 0. 0. 0. 0. 0. -1. ]]
Apply the change of basis on the direct sum of the irreps:
[[ 0. -0. 0. -0. -0. 0. -0. 1.]
[ 1. 0. -0. 0. -0. -0. 0. -0.]
[-0. 1. 0. -0. 0. -0. -0. -0.]
[-0. 0. 1. 0. -0. 0. -0. 0.]
[-0. -0. -0. 1. 0. -0. 0. -0.]
[-0. 0. -0. 0. 1. 0. 0. -0.]
[-0. -0. -0. -0. 0. 1. 0. 0.]
[-0. 0. -0. -0. 0. -0. 1. 0.]]
Are the two representations equal? True
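For $C_N$ with even $N$, the change of basis can be written down explicitly. A real Fourier basis block-diagonalizes the regular representation into the irreps listed above. This basis consists of a constant, cosine/sine pairs of frequency $k = 1, \dots, N/2 - 1$ and the alternating sign $(-1)^h$. The sketch uses our own basis ordering and normalization; escnn's change_of_basis follows its own conventions, so its entries need not coincide with ours.

```python
import numpy as np

N = 8
theta = 2 * np.pi / N
h = np.arange(N)  # index h stands for the group element r_{h * 2pi/N}

def regular_rep(g: int) -> np.ndarray:
    return np.roll(np.eye(N), g, axis=0)  # (rho(g) f)[h] = f[(h - g) mod N]

def rot(a: float) -> np.ndarray:
    return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])

# columns of Q: an orthonormal real Fourier basis of functions over C_N
cols = [np.ones(N) / np.sqrt(N)]                # frequency 0: constant
for k in range(1, N // 2):                      # frequencies 1 .. N/2 - 1: cosine and sine
    cols += [np.sqrt(2 / N) * np.cos(k * theta * h), np.sqrt(2 / N) * np.sin(k * theta * h)]
cols.append((-1.0) ** h / np.sqrt(N))           # frequency N/2: alternating sign
Q = np.stack(cols, axis=1)                      # orthogonal, so Q^{-1} = Q.T

def irreps_direct_sum(g: int) -> np.ndarray:
    """Direct sum of the irreps of C_N (N even): 1, rotations by k*theta*g, and (-1)^g."""
    out = np.zeros((N, N))
    out[0, 0] = 1.0
    for k in range(1, N // 2):
        out[2 * k - 1:2 * k + 1, 2 * k - 1:2 * k + 1] = rot(k * theta * g)
    out[N - 1, N - 1] = (-1.0) ** g
    return out

print(np.allclose(Q.T @ Q, np.eye(N)))  # True
print(all(np.allclose(Q.T @ regular_rep(g) @ Q, irreps_direct_sum(g)) for g in range(N)))  # True
```

As in the escnn output above, this gives $\rho_{\text{reg}}(g) = Q\,(\bigoplus_j \rho_j(g))\,Q^{-1}$ with $Q^{-1} = Q^T$.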
1.2 Fourier Transform
We can finally approach the harmonic analysis of functions over a group $G$.
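Note that a representation $\rho: G \to \mathbb{R}^{d \times d}$ can be interpreted as a collection of $d^2$ functions over $G$, one for each matrix entry of $\rho$. The Peter-Weyl theorem states that the collection of functions given by the matrix entries of all irreps $\hat G$ of a compact group $G$ spans the space of all square-integrable functions over $G$. For infinite groups, this holds in the sense that their linear combinations are dense in this space.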
This result gives us a way to parameterize functions over the group, which is the focus of this section. In particular, it is useful for parameterizing functions over groups with infinitely many elements.
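In this section, we will first consider the dihedral group $D_8$ as an example. We refer to the dihedral group as $D_N = C_N \rtimes C_2$, i.e. the group of the $N$ planar rotations by multiples of $\frac{2\pi}{N}$ together with reflections. Its order is $2N$, so $D_8$ contains $16$ elements: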
[31]:
G = dihedral_group(8)
G.order()
[31]:
16
[32]:
# element representing the reflection (-) and no rotations
G.reflection
[32]:
(-, 0[2pi/8])
[33]:
# element representing a rotation by pi/2 (i.e. 2 * 2pi/8) and no reflections (+)
G.element((0, 2))
[33]:
(+, 2[2pi/8])
[34]:
# reflection followed by a rotation by pi/2
print(G.element((0, 2)) @ G.reflection)
# we can also directly generate this element as
print(G.element((1, 2)))
(-, 2[2pi/8])
(-, 2[2pi/8])
[35]:
# a rotation by pi/2 followed by a reflection is equivalent to a reflection followed by a rotation by 6*2pi/8
G.reflection @ G.element((0, 2))
[35]:
(-, 6[2pi/8])
The list of all elements in the group is obtained as:
[36]:
G.elements
[36]:
[(+, 0[2pi/8]),
(+, 1[2pi/8]),
(+, 2[2pi/8]),
(+, 3[2pi/8]),
(+, 4[2pi/8]),
(+, 5[2pi/8]),
(+, 6[2pi/8]),
(+, 7[2pi/8]),
(-, 0[2pi/8]),
(-, 1[2pi/8]),
(-, 2[2pi/8]),
(-, 3[2pi/8]),
(-, 4[2pi/8]),
(-, 5[2pi/8]),
(-, 6[2pi/8]),
(-, 7[2pi/8])]
Fourier and Inverse Fourier Transform
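For most groups, the entries of the irreps not only span the space of functions but also form a basis (i.e. these functions are mutually orthogonal). Therefore, we can write a function $f: G \to \mathbb{R}$ as
$$f(g) = \sum_{\rho_j \in \hat G} \sum_{m, n < d_j} \hat f(\rho_j)_{n,m} \cdot \sqrt{d_j}\,[\rho_j(g)]_{n,m}$$
where $d_j$ is the dimension of the irrep $\rho_j$, while $m, n$ index the $d_j^2$ entries of $\rho_j$. The coefficients $\{\hat f(\rho_j) \in \mathbb{R}^{d_j \times d_j}\}_{\rho_j \in \hat G}$ parameterize the function $f$ on this basis. The factor $\sqrt{d_j}$ ensures that the basis is normalized.
We can rewrite this expression in a cleaner form using the following fact. If $A, B \in \mathbb{R}^{d \times d}$, then
$$\operatorname{Tr}(A^T B) = \sum_{m, n < d} A_{n,m} B_{n,m} \in \mathbb{R}.$$
By defining $\hat f(\rho_j) \in \mathbb{R}^{d_j \times d_j}$ as the matrix containing the $d_j^2$ coefficients $\hat f(\rho_j)_{n,m}$, we can express the Inverse Fourier Transform as:
$$f(g) = \sum_{\rho_j \in \hat G} \sqrt{d_j}\, \operatorname{Tr}\left(\rho_j(g)^T \hat f(\rho_j)\right)$$
Similarly, we can project a general function $f: G \to \mathbb{R}$ onto an element $\rho_{j,n,m}: g \mapsto \sqrt{d_j}\,[\rho_j(g)]_{n,m}$ of the basis via:
$$\hat f(\rho_j)_{n,m} = \frac{1}{|G|} \sum_{g \in G} f(g)\, \sqrt{d_j}\, [\rho_j(g)]_{n,m}$$
The projection over all entries of $\rho_j$ can be written more compactly as:
$$\hat f(\rho_j) = \frac{1}{|G|} \sum_{g \in G} f(g)\, \sqrt{d_j}\, \rho_j(g)$$
which we refer to as the Fourier Transform.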
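If the group $G$ is infinite, we replace the average over the group elements with an integral over them, taken with respect to the normalized Haar measure ($\int_G dg = 1$):
$$\hat f(\rho_j) = \int_G f(g)\, \sqrt{d_j}\, \rho_j(g)\, dg$$
For a finite group $G$, we can access all its irreps by using the Group.irreps() method. Let's see an example: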
[37]:
irreps = G.irreps()
print(f'The dihedral group D8 has {len(irreps)} irreps')
The dihedral group D8 has 7 irreps
[38]:
# the first one is the 1-dimensional trivial representation
print(irreps[0] == G.trivial_representation == G.irrep(0, 0))
True
QUESTION 4
We can now implement the Fourier Transform and the Inverse Fourier Transform for the dihedral group $D_8$, using the equations above:
[39]:
def fourier_transform_D8(f: np.ndarray):
    # the method takes as input a function on the elements of D_8
    # and should return a dictionary mapping each irrep's `id` to the corresponding Fourier Transform
    # The i-th element of `f` stores the value of the function on the group element `G.elements[i]`
    G = dihedral_group(8)
    assert f.shape == (16,), f.shape
    ft = {}
    ########################
    # INSERT YOUR CODE HERE:
    for rho in G.irreps():
        d = rho.size
        rho_g = np.stack([rho(g) for g in G.elements], axis=0)
        ft[rho.id] = (f.reshape(-1, 1, 1) * rho_g).mean(0) * np.sqrt(d)
    ########################
    return ft
[40]:
def inverse_fourier_transform_D8(ft: dict):
    # the method takes as input a dictionary mapping each irrep's `id` to the corresponding Fourier Transform
    # and should return the function `f` on the elements of D_8
    # The i-th element of `f` stores the value of the function on the group element `G.elements[i]`
    G = dihedral_group(8)
    f = np.zeros(16)
    ########################
    # INSERT YOUR CODE HERE:
    for rho in G.irreps():
        d = rho.size
        for i, g in enumerate(G.elements):
            f[i] += np.sqrt(d) * (ft[rho.id] * rho(g)).sum()
    ########################
    return f
We now want to verify that the Fourier Transform and the Inverse Fourier Transform are inverses of each other:
[41]:
f = np.random.randn(16)
ft = fourier_transform_D8(f)
new_f = inverse_fourier_transform_D8(ft)
assert np.allclose(f, new_f)
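The same transforms can be implemented without escnn by writing down a complete set of real irreps of $D_N$ (for even $N$) by hand. There are four 1-dimensional irreps and $N/2 - 1$ two-dimensional ones, $\rho_k(s, r) = R(k r \frac{2\pi}{N})\,F^s$, where $R$ is a rotation matrix and $F = \mathrm{diag}(1, -1)$ a reflection. The element $(s, r)$ means "reflection $s \in \{0, 1\}$ followed by the rotation by $r \cdot \frac{2\pi}{N}$", as in the escnn examples above. However, this parameterization of the irreps is our own assumption and does not necessarily match escnn's basis. The two functions implement $\hat f(\rho) = \frac{1}{|G|}\sum_g f(g)\sqrt{d}\,\rho(g)$ and $f(g) = \sum_\rho \sqrt{d}\,\mathrm{Tr}(\rho(g)^T \hat f(\rho))$.

```python
import numpy as np

N = 8                       # D_8 has 2 * N = 16 elements
theta = 2 * np.pi / N
# element (s, r): reflection s in {0, 1} followed by the rotation by r * 2pi/N
elements = [(s, r) for s in range(2) for r in range(N)]
F = np.diag([1.0, -1.0])    # a reflection of the plane

def rot(a: float) -> np.ndarray:
    return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])

# real irreps of D_N (N even), in our own parameterization: 4 one-dimensional + (N/2 - 1) two-dimensional
IRREPS = [
    lambda s, r: np.array([[1.0]]),                # trivial
    lambda s, r: np.array([[(-1.0) ** s]]),        # sign of the reflection
    lambda s, r: np.array([[(-1.0) ** r]]),        # parity of the rotation
    lambda s, r: np.array([[(-1.0) ** (r + s)]]),
] + [
    lambda s, r, k=k: rot(k * theta * r) @ np.linalg.matrix_power(F, s)
    for k in range(1, N // 2)
]

def fourier_transform(f: np.ndarray) -> list:
    # \hat f(rho) = 1/|G| sum_g f(g) sqrt(d) rho(g)
    ft = []
    for rho in IRREPS:
        d = rho(0, 0).shape[0]
        ft.append(sum(f[i] * np.sqrt(d) * rho(*g) for i, g in enumerate(elements)) / len(elements))
    return ft

def inverse_fourier_transform(ft: list) -> np.ndarray:
    # f(g) = sum_rho sqrt(d) Tr(rho(g)^T \hat f(rho))
    f = np.zeros(len(elements))
    for rho, coeff in zip(IRREPS, ft):
        d = coeff.shape[0]
        for i, g in enumerate(elements):
            f[i] += np.sqrt(d) * np.trace(rho(*g).T @ coeff)
    return f

rng = np.random.default_rng(0)
f = rng.standard_normal(len(elements))
ft = fourier_transform(f)
print(len(ft), sum(c.shape[0] ** 2 for c in ft))      # 7 16
print(np.allclose(inverse_fourier_transform(ft), f))  # True
```

Since $\sum_j d_j^2 = 4 \cdot 1 + 3 \cdot 4 = 16 = |D_8|$, the 16 Fourier coefficients carry exactly as much information as the 16 function values. This is why the round trip is exact.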
Parameterizing functions over infinite groups
This also allows us to parameterize functions over infinite groups, such as $O(2)$, i.e. the group of all planar rotations and reflections.
[42]:
G = o2_group()
[43]:
# the group has infinitely many elements, so the `order` method just returns -1
G.order()
[43]:
-1
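The equations remain the same, but this group has an infinite number of irreps. We can, however, parameterize a function over the group by only considering a finite number of irreps in the sum inside the definition of the Inverse Fourier Transform. Let $\tilde G \subset \hat G$ be a finite subset of the irreps of $G$. We can then write the following transforms within the subspace of functions spanned only by the entries of the irreps in $\tilde G$.
Inverse Fourier Transform:
$$f(g) = \sum_{\rho_j \in \tilde G} \sqrt{d_j}\, \operatorname{Tr}\left(\rho_j(g)^T \hat f(\rho_j)\right)$$
and Fourier Transform:
$$\hat f(\rho_j) = \int_G f(g)\, \sqrt{d_j}\, \rho_j(g)\, dg \qquad \forall \rho_j \in \tilde G$$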
QUESTION 5
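We can now implement the Inverse Fourier Transform for the orthogonal group $O(2)$. Since the group has infinitely many elements, we cannot store the value the function takes on each element. Instead, the method below evaluates the function on a single given group element $g$: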
[44]:
def inverse_fourier_transform_O2(g: GroupElement, ft: dict):
    # the method takes as input a dictionary mapping each irrep's `id` to the corresponding Fourier Transform
    # and a group element `g`
    # The method should return the value of the function evaluated on `g`.
    G = o2_group()
    f = 0
    ########################
    # INSERT YOUR CODE HERE:
    for rho, ft_rho in ft.items():
        rho = G.irrep(*rho)
        d = rho.size
        f += np.sqrt(d) * (ft_rho * rho(g)).sum()
    ########################
    return f
Let’s plot a function. First we generate a random function by using a few irreps.
[45]:
irreps = [G.irrep(0, 0)] + [G.irrep(1, j) for j in range(3)]
ft = {
rho.id: np.random.randn(rho.size, rho.size)
for rho in irreps
}
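Then, we generate a grid on the group on which to evaluate the function, i.e. we choose a finite set of elements of $G$. Like the dihedral group, $O(2)$ contains rotations (parameterized by an angle $\theta \in [0, 2\pi)$) and reflections followed by a rotation. For example: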
[46]:
G.sample()
[46]:
(+, 0.026961821470776897)
To build our grid, we sample $N = 100$ rotations and $N = 100$ reflections followed by rotations:
[47]:
N = 100
thetas = [i*2*np.pi/N for i in range(N)]
grid_rot = [G.element((0, theta)) for theta in thetas]
grid_refl = [G.element((1, theta)) for theta in thetas]
We now evaluate the function over all these elements and, finally, plot it:
[48]:
f_rot = [
inverse_fourier_transform_O2(g, ft) for g in grid_rot
]
f_refl = [
inverse_fourier_transform_O2(g, ft) for g in grid_refl
]
plt.plot(thetas, f_rot, label='rotations')
plt.plot(thetas, f_refl, label='reflection + rotations')
plt.xlabel('theta [0, 2pi)')
plt.ylabel('f(g)')
plt.legend()
plt.show()
Observe that using more irreps allows one to parameterize more flexible functions. Let’s try to add some more:
[49]:
irreps = [G.irrep(0, 0)] + [G.irrep(1, j) for j in range(8)]
ft = {
rho.id: np.random.randn(rho.size, rho.size)
for rho in irreps
}
f_rot = [
inverse_fourier_transform_O2(g, ft) for g in grid_rot
]
f_refl = [
inverse_fourier_transform_O2(g, ft) for g in grid_refl
]
plt.plot(thetas, f_rot, label='rotations')
plt.plot(thetas, f_refl, label='reflection + rotations')
plt.xlabel('theta [0, 2pi)')
plt.ylabel('f(g)')
plt.legend()
plt.show()
Fourier Transform of shifted functions
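Recall that a group element $g \in G$ can act on a function $f: G \to \mathbb{R}$ as:
$$[g.f](h) = f(g^{-1}h)$$
The Fourier transform defined before has the convenient property that the Fourier transforms of $f$ and of $[g.f]$ are related as follows:
$$\widehat{g.f}(\rho_j) = \rho_j(g)\, \hat f(\rho_j)$$
for any irrep $\rho_j$.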
QUESTION 6
Prove the property above.
ANSWER 6
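The following is one possible derivation. It assumes that the integral is taken with respect to the invariant Haar measure of $G$, so that the substitution $k = g^{-1}h$ leaves $dh$ unchanged; for a finite group, read the integral as the average over $G$. It also uses the fact that $\rho_j$ is a homomorphism, $\rho_j(gk) = \rho_j(g)\rho_j(k)$:

```latex
\begin{aligned}
\widehat{g.f}(\rho_j) &= \int_G [g.f](h)\,\sqrt{d_j}\,\rho_j(h)\,dh
                       = \int_G f(g^{-1}h)\,\sqrt{d_j}\,\rho_j(h)\,dh \\
                      &= \int_G f(k)\,\sqrt{d_j}\,\rho_j(gk)\,dk \qquad (k = g^{-1}h) \\
                      &= \rho_j(g) \int_G f(k)\,\sqrt{d_j}\,\rho_j(k)\,dk
                       = \rho_j(g)\,\hat f(\rho_j)
\end{aligned}
```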
We can verify this property visually:
[50]:
irreps = [G.irrep(0, 0)] + [G.irrep(1, j) for j in range(8)]
# first, we generate a random function, as earlier
ft = {
rho.id: np.random.randn(rho.size, rho.size)
for rho in irreps
}
# second, we sample a random group element `g`
g = G.sample()
print(f'Transforming the function with g={g}')
# finally, we transform the Fourier coefficients as in the equations above:
gft = {
rho.id: rho(g) @ ft[rho.id]
for rho in irreps
}
# Let's now visualize the two functions:
f_rot = [
inverse_fourier_transform_O2(g, ft) for g in grid_rot
]
f_refl = [
inverse_fourier_transform_O2(g, ft) for g in grid_refl
]
gf_rot = [
inverse_fourier_transform_O2(g, gft) for g in grid_rot
]
gf_refl = [
inverse_fourier_transform_O2(g, gft) for g in grid_refl
]
plt.plot(thetas, f_rot, label='rotations')
plt.plot(thetas, f_refl, label='reflection + rotations')
plt.xlabel('theta [0, 2pi)')
plt.ylabel('f(g)')
plt.title('f')
plt.legend()
plt.show()
plt.plot(thetas, gf_rot, label='rotations')
plt.plot(thetas, gf_refl, label='reflection + rotations')
plt.xlabel('theta [0, 2pi)')
plt.ylabel('f(g)')
plt.title('g.f')
plt.legend()
plt.show()
Transforming the function with g=(+, 0.4933335011719361)
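The same property can be checked numerically on $SO(2)$ without escnn. The sketch below assumes the irreps $\rho_0 = 1$ and $\rho_j(r_\theta) = R(j\theta)$ for $j \geq 1$, a band limit $L = 4$ (our own choice) and random coefficient matrices. It compares $f(\theta - \phi)$ with the function obtained from the transformed coefficients $\rho_j(r_\phi)\hat f(\rho_j)$.

```python
import numpy as np

def rot(a: float) -> np.ndarray:
    return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])

def irrep(j: int, theta: float) -> np.ndarray:
    """Irrep of SO(2) with frequency j: trivial (1x1) for j = 0, rotation by j*theta (2x2) otherwise."""
    return np.ones((1, 1)) if j == 0 else rot(j * theta)

def ift(ft: dict, theta: float) -> float:
    # band-limited inverse Fourier transform: f(theta) = sum_j sqrt(d_j) Tr(rho_j(theta)^T \hat f(rho_j))
    return sum(np.sqrt(c.shape[0]) * np.trace(irrep(j, theta).T @ c) for j, c in ft.items())

rng = np.random.default_rng(0)
L = 4  # band limit: frequencies 0..L (our own choice)
ft = {j: rng.standard_normal((1, 1) if j == 0 else (2, 2)) for j in range(L + 1)}

phi = 0.7                                              # act with the rotation r_phi
g_ft = {j: irrep(j, phi) @ c for j, c in ft.items()}   # \hat{g.f}(rho_j) = rho_j(g) \hat f(rho_j)

thetas = np.linspace(0, 2 * np.pi, 50, endpoint=False)
shifted = np.array([ift(ft, t - phi) for t in thetas])  # [g.f](theta) = f(theta - phi)
transformed = np.array([ift(g_ft, t) for t in thetas])
print(np.allclose(shifted, transformed))  # True
```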
From the Fourier Transform to the Regular Representation
For simplicity, we can stack all the Fourier coefficients (the output of the Fourier transform, that is, the input of the inverse Fourier transform) into a single vector. We define the vector $\hat f$ as the stack of the columns of each Fourier coefficient matrix $\hat f(\rho_j)$.
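Let's first introduce some notation. We denote the stack of two vectors $v_1, v_2$ as $v_1 \oplus v_2$. The use of $\oplus$ is not a coincidence. If $\rho_1$ is a representation acting on $v_1$ and $\rho_2$ is a representation acting on $v_2$, then the direct sum $\rho_1 \oplus \rho_2$ acts on the concatenated vector $v_1 \oplus v_2$.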
Second, we denote by $\operatorname{vec}(A)$ the vector obtained by stacking the columns of a matrix $A$. In numpy, this is written as A.T.reshape(-1), where the transpose is necessary since numpy stacks rows by default.
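Then, we write:
$$\hat f = \bigoplus_{\rho_j} \operatorname{vec}\big(\hat f(\rho_j)\big)$$
Moreover, by using $\widehat{g.f}(\rho_j) = \rho_j(g)\hat f(\rho_j)$, we see that the vector containing the coefficients of the function $[g.f]$ is:
$$\widehat{g.f} = \bigoplus_{\rho_j} \operatorname{vec}\big(\rho_j(g)\hat f(\rho_j)\big) = \bigoplus_{\rho_j} \Big(\bigoplus_{i=1}^{d_j} \rho_j(g)\Big) \operatorname{vec}\big(\hat f(\rho_j)\big)$$
In other words, the group $G$ acts on the vector $\hat f$ with the following representation:
$$\rho(g) = \bigoplus_{\rho_j} \bigoplus_{i=1}^{d_j} \rho_j(g)$$
i.e. $\widehat{g.f} = \rho(g)\hat f$.
Note that, essentially, the representation $\rho$ acting on the Fourier coefficients is equivalent to the regular representation acting on the function values:
$$\rho_{\text{reg}}(g) = Q\left(\bigoplus_{\rho_j \in \hat G} \bigoplus_{i=1}^{d_j} \rho_j(g)\right)Q^{-1}$$
Here, each irrep $\rho_j$ is repeated $d_j$ times, i.e. a number of times equal to its size. The change of basis $Q$ maps the Fourier coefficients back to the function values, i.e. it performs an Inverse Fourier Transform, up to normalization.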
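Intuition: recall that a function $f: G \to \mathbb{R}$ over a finite group is just a vector in $\mathbb{R}^{|G|}$, on which the regular representation acts. The Fourier transform is an invertible linear map from this vector to the vector $\hat f$ of its coefficients, i.e. a change of basis. Hence, the representation acting on $\hat f$ must be equivalent to the regular representation.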
Let's verify this equivalence for the dihedral group $D_8$:
[51]:
G = dihedral_group(8)
rho_irreps = []
for rho_j in G.irreps():
    d_j = rho_j.size
    # repeat each irrep a number of times equal to its size
    rho_irreps += [rho_j]*d_j
rho = directsum(rho_irreps)
print('The representations have the same size:')
print(rho.size, G.regular_representation.size)
print('And contain the same irreps:')
print(rho.irreps)
print(G.regular_representation.irreps)
# change of basis from the irreps' (Fourier) basis to the group elements' basis,
# i.e. an inverse Fourier transform (up to normalization):
Q = G.regular_representation.change_of_basis
# its inverse maps function values to Fourier coefficients, i.e. a Fourier transform:
Qinv = G.regular_representation.change_of_basis_inv
# let's check that the two representations are indeed equivalent
g = G.sample()
rho_g = rho(g)
reg_g = G.regular_representation(g)
print()
print('Are the two representations equivalent?', np.allclose(Q @ rho_g @ Qinv, reg_g))
The representations have the same size:
16 16
And contain the same irreps:
[(0, 0), (1, 0), (1, 1), (1, 1), (1, 2), (1, 2), (1, 3), (1, 3), (1, 4), (0, 4)]
[(0, 0), (1, 0), (1, 1), (1, 1), (1, 2), (1, 2), (1, 3), (1, 3), (1, 4), (0, 4)]
Are the two representations equivalent? True
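The key step in the derivation above is $\operatorname{vec}(\rho_j(g)\hat f(\rho_j)) = (\bigoplus_{i=1}^{d_j}\rho_j(g))\operatorname{vec}(\hat f(\rho_j))$, which can be checked in a few lines of NumPy. For illustration only, the sketch uses a random orthogonal matrix as a stand-in for $\rho_j(g)$. The direct sum of $d_j$ copies is written as np.kron(np.eye(d), rho_g).

```python
import numpy as np

def vec(A: np.ndarray) -> np.ndarray:
    """Stack the columns of A into one vector (numpy stacks rows by default, hence the transpose)."""
    return A.T.reshape(-1)

rng = np.random.default_rng(0)
d = 2
rho_g, _ = np.linalg.qr(rng.standard_normal((d, d)))  # stand-in for rho_j(g): a random orthogonal matrix
F_hat = rng.standard_normal((d, d))                    # stand-in for the coefficients \hat f(rho_j)

lhs = vec(rho_g @ F_hat)                      # vec(rho_j(g) \hat f(rho_j))
rhs = np.kron(np.eye(d), rho_g) @ vec(F_hat)  # d copies of rho_j(g) acting on vec(\hat f(rho_j))
print(np.allclose(lhs, rhs))  # True
```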
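When $G$ is not finite, we cannot explicitly store the regular representation $\rho_{\text{reg}}$ or the Fourier transform matrix, since $G$ contains infinitely many elements. Nevertheless, as done earlier, we can just consider a finite subset of irreps $\tilde G \subset \hat G$ and the representation
$$\rho(g) = \bigoplus_{\rho_j \in \tilde G} \bigoplus_{i=1}^{d_j} \rho_j(g)$$
which acts on the Fourier coefficients of the band-limited functions considered earlier.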
This is the underlying idea we will exploit later to build GCNNs equivariant to infinite groups.
We can easily generate this representation as follows (bl_regular_representation stands for "band-limited", since only a limited subset of irreps, i.e. frequencies, is used):
[52]:
G = o2_group()
rho = G.bl_regular_representation(7)
rho.irreps
[52]:
[(0, 0),
(1, 0),
(1, 1),
(1, 1),
(1, 2),
(1, 2),
(1, 3),
(1, 3),
(1, 4),
(1, 4),
(1, 5),
(1, 5),
(1, 6),
(1, 6),
(1, 7),
(1, 7)]
Irreps with redundant entries: the case of SO(2)
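We conclude with a final note about the Fourier transform. When we introduced it earlier, we said that the entries of the irreps form a basis for the functions over most groups. However, there exist some groups where the entries of the irreps are partially redundant and, therefore, form an overcomplete basis. This is the case, for example, of the group of planar rotations $SO(2)$ (or of the group $C_N$ of $N$ discrete rotations). Indeed, an irrep of $SO(2)$ has the form
$$\rho_j(r_\theta) = \begin{bmatrix} \cos(j\theta) & -\sin(j\theta) \\ \sin(j\theta) & \cos(j\theta) \end{bmatrix}$$
for $\theta \in [0, 2\pi)$, where the integer $j \in \mathbb{N}$ is interpreted as the rotational frequency.
You can observe that the two columns of $\rho_j(r_\theta)$ contain redundant elements and span the same 2-dimensional space of functions. It is therefore sufficient to consider only one of the two columns to parameterize functions over $SO(2)$. This means that the irrep $\rho_j$ appears only once (instead of $d_j = 2$ times) in the regular representation.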
We don’t generally need to worry much about this, since we can generate the representation as earlier:
[53]:
G = so2_group()
rho = G.bl_regular_representation(7)
# observe that each irrep is now repeated only once, even if some are 2-dimensional
rho.irreps
[53]:
[(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,)]
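The redundancy can be measured directly. Treat each of the 4 matrix entries of a 2-dimensional irrep as a function over the group and compute the rank of the resulting $4 \times |G|$ matrix. The sketch below does this for frequency $j = 1$ on $C_8$, where the irrep has the rotation-matrix form above. For comparison, it does the same on $D_8$ with the hand-written irrep $\rho_j(s, r) = R(j r \frac{2\pi}{8})F^s$; this $D_8$ parameterization is our own assumption, not escnn's.

```python
import numpy as np

N = 8
theta = 2 * np.pi / N
F = np.diag([1.0, -1.0])  # a reflection

def rot(a: float) -> np.ndarray:
    return np.array([[np.cos(a), -np.sin(a)], [np.sin(a), np.cos(a)]])

j = 1  # frequency of the 2-dimensional irrep
# one column per group element, one row per matrix entry of the irrep
entries_cyclic = np.stack([rot(j * theta * r).reshape(-1) for r in range(N)], axis=1)
entries_dihedral = np.stack(
    [(rot(j * theta * r) @ np.linalg.matrix_power(F, s)).reshape(-1) for s in range(2) for r in range(N)],
    axis=1,
)
print(entries_cyclic.shape, np.linalg.matrix_rank(entries_cyclic))      # (4, 8) 2
print(entries_dihedral.shape, np.linalg.matrix_rank(entries_dihedral))  # (4, 16) 4
```

For the rotations alone, the 4 entries only span 2 functions, so one column of $\rho_j$ suffices. Once reflections are added, all 4 entries are linearly independent.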
2. From Group CNNs to Steerable CNNs
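We consider a GCNN equivariant to a semi-direct product group $\mathbb{R}^n \rtimes G$, with a compact group $G \leq O(n)$. This means that the features of the network are functions over $\mathbb{R}^n \rtimes G$, i.e. functions of a position in $\mathbb{R}^n$ and of an element of $G$.
If $G$ is infinite, we cannot store the value of a feature map on every element of $G$. Instead, we parameterize the features over $G$ with their Fourier coefficients, as described in the previous section.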
2.1 Feature Fields
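In a GCNN, a feature map is a signal $f: \mathbb{R}^n \rtimes G \to \mathbb{R}$. The action of an element $(x, g) \in \mathbb{R}^n \rtimes G$ on $f$ is:
$$[(x, g).f](y, h) := f(g^{-1}(y - x),\, g^{-1}h)$$
where $x, y \in \mathbb{R}^n$ and $g, h \in G$.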
QUESTION 7
Prove the action has indeed this form.
ANSWER 7
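First, recall the group law: for any $(x, g), (y, h) \in \mathbb{R}^n \rtimes G$,
$$(x, g) \cdot (y, h) = (x + g.y,\ gh)$$
where $g.y$ denotes the action of $g \in G$ on $y \in \mathbb{R}^n$ (e.g. a rotation or reflection of the vector $y$). Hence, the inverse of $(x, g)$ is $(x, g)^{-1} = (-g^{-1}.x,\ g^{-1})$. Since, as for the regular representation, a group element acts on a function through the inverse of that element, we get:
$$[(x, g).f](y, h) = f\big((x, g)^{-1} \cdot (y, h)\big) = f\big(-g^{-1}.x + g^{-1}.y,\ g^{-1}h\big) = f\big(g^{-1}.(y - x),\ g^{-1}h\big)$$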
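In a GCNN, a feature map $f$ is stored as a multi-dimensional array with an axis for each of the $n$ spatial dimensions and one axis for the group $G$.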
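In a steerable CNN, we replace the $G$ axis with a "Fourier" axis. This axis contains the $c$ Fourier coefficients used to parameterize a function over $G$, as described in the previous section. The representation of $G$ acting on these $c$ coefficients is the (band-limited) regular representation $\rho$ introduced earlier.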
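A feature map $f: \mathbb{R}^n \to \mathbb{R}^c$ of this kind is called a feature field, since it assigns a $c$-dimensional feature vector $f(x) \in \mathbb{R}^c$ to each spatial position $x \in \mathbb{R}^n$.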
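The action of $(x, g) \in \mathbb{R}^n \rtimes G$ on a feature field $f$ is:
$$[(x, g).f](y) := \rho(g)\, f(g^{-1}(y - x))$$
where $\rho$ is the representation of $G$ acting on the Fourier coefficients, described above.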
QUESTION 8
Prove that this is indeed the right action of $\mathbb{R}^n \rtimes G$ on the Fourier coefficients of a GCNN feature map.
ANSWER 8
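With a slight abuse of notation, we use $f$ both for the GCNN feature map and for the feature field of its Fourier coefficients. We know from the previous question that $[(x, g).f](y, h) = f(g^{-1}(y - x), g^{-1}h)$. In other words, the function over $G$ stored at position $y$ after the transformation is the function over $G$ originally stored at $g^{-1}(y - x)$, shifted by $g$:
$$[(x, g).f](y, \cdot) = g.\big[f(g^{-1}(y - x), \cdot)\big]$$
Recall also that shifting a function over $G$ by $g$ multiplies its Fourier coefficients by $\rho(g)$, i.e. $\widehat{g.f} = \rho(g)\hat f$. Hence, if the feature field stores the Fourier coefficients of the function over $G$ at each position, the coefficients at $y$ after the transformation are $\rho(g)$ times those originally stored at $g^{-1}(y - x)$. This is exactly $[(x, g).f](y) = \rho(g)\, f(g^{-1}(y - x))$.
Note that in the equations above, the square brackets in $[(x, g).f](y)$ indicate that $(x, g).f$ is a new feature field, which is then evaluated at $y$.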
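For $G = C_4$ acting on a pixel grid, the action $[g.f](x) = \rho(g) f(g^{-1}x)$ on a regular feature field can be sketched with NumPy. np.rot90 moves the feature vectors to their rotated positions. Since the regular representation of $C_4$ is a cyclic permutation, np.roll permutes the 4 channels. Two things are assumptions of this sketch: which rotation direction np.rot90 corresponds to, and the $(H, W, 4)$ layout (escnn stores fields channels-first). A scalar field, which transforms under the trivial representation, is only moved.

```python
import numpy as np

def act_regular_field(f: np.ndarray, k: int) -> np.ndarray:
    """Act with the rotation by k * 90 degrees (element k of C_4) on a regular feature field.

    f has shape (H, W, 4): one 4-dimensional feature vector per pixel. Following
    [g.f](x) = rho(g) f(g^{-1} x), the vectors are moved to their rotated positions and
    their channels are cyclically permuted by the regular representation of C_4.
    """
    moved = np.rot90(f, k=k, axes=(0, 1))    # spatial part: f(g^{-1} x)
    return np.roll(moved, shift=k, axis=-1)  # fiber part: rho_reg(g) on every feature vector

def act_scalar_field(f: np.ndarray, k: int) -> np.ndarray:
    """Trivial representation (e.g. a gray-scale image): pixels move, their values do not change."""
    return np.rot90(f, k=k, axes=(0, 1))

rng = np.random.default_rng(0)
f = rng.standard_normal((5, 5, 4))
# it is a group action: two 90-degree rotations equal one 180-degree rotation ...
print(np.allclose(act_regular_field(act_regular_field(f, 1), 1), act_regular_field(f, 2)))  # True
# ... and four of them give back the original field
print(np.allclose(act_regular_field(f, 4), f))  # True
```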
General Steerable CNNs
The framework of Steerable CNNs is actually more general and allows any representation $\rho$ of $G$ to act on the feature vectors $f(x) \in \mathbb{R}^c$. Throughout the rest of this tutorial, we will assume for simplicity that $\rho$ is the regular representation.
2.2 Defining a Steerable CNN
We can now proceed with building a Steerable CNN. First we import some other useful packages.
[54]:
from escnn import group
from escnn import gspaces
from escnn import nn
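First, we need to choose the group $G$ of point symmetries (reflections and rotations) under consideration.
For simplicity, we first consider the finite group $G = C_4$, which models the 4 rotations by angles $\theta \in \{0, \frac{\pi}{2}, \pi, \frac{3\pi}{2}\}$.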
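Recall that a semi-direct product $\mathbb{R}^2 \rtimes G$ is defined not only by $G$ but also by the action of $G$ on $\mathbb{R}^2$. We determine both the point group $G$ and its action on $\mathbb{R}^2$ by instantiating a subclass of gspaces.GSpace. For the rotational action of $G = C_4$ on $\mathbb{R}^2$, this is done by: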
[55]:
r2_act = gspaces.rot2dOnR2(N=4)
r2_act
[55]:
C4_on_R2[(None, 4)]
[56]:
# we can access the group G as
G = r2_act.fibergroup
G
[56]:
C4
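Having specified the symmetry transformation on the base space $\mathbb{R}^2$, we next need to define the representation $\rho: G \to \mathbb{R}^{c \times c}$. It describes how a $c$-dimensional feature vector field $f: \mathbb{R}^2 \to \mathbb{R}^c$ transforms under the action of $G$. This transformation law of feature fields is implemented by nn.FieldType.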
We instantiate the nn.FieldType modeling a GCNN feature by passing it the gspaces.GSpace instance and the regular representation of $G = C_4$. We call a feature field associated with the regular representation $\rho_{\text{reg}}$ a regular feature field.
[57]:
feat_type = nn.FieldType(r2_act, [G.regular_representation])
feat_type
[57]:
[C4_on_R2[(None, 4)]: {regular (x1)}(4)]
Recall that the regular representation of a finite group $G$, built by G.regular_representation, is a permutation matrix of shape $|G| \times |G|$:
[58]:
G.regular_representation(G.sample())
[58]:
array([[ 1., 0., -0., -0.],
[ 0., 1., 0., -0.],
[-0., 0., 1., 0.],
[-0., -0., 0., 1.]])
Deep Feature spaces
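The deep feature spaces of a GCNN typically comprise multiple channels. Similarly, the feature spaces of a steerable CNN can include multiple independent feature fields. This is achieved via the direct sum, by stacking multiple copies of $\rho$.
For example, we can use $3$ copies of the regular representation $\rho_{\text{reg}}: G \to \mathbb{R}^{|G| \times |G|}$. In this case, the full feature space is modeled as a stacked field $f: \mathbb{R}^2 \to \mathbb{R}^{3|G|}$, which transforms according to the direct sum of three regular representations:
$$\rho(r_\theta) = \rho_{\text{reg}}(r_\theta) \oplus \rho_{\text{reg}}(r_\theta) \oplus \rho_{\text{reg}}(r_\theta)$$
We instantiate an nn.FieldType composed of $3$ regular representations by passing the full field representation as a list of three regular representations: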
[59]:
# Technically, one can also construct the direct-sum representation G.regular_representation + G.regular_representation + G.regular_representation as done
# before. Passing a list containing 3 copies of G.regular_representation allows for more efficient implementation of certain operations internally.
feat_type = nn.FieldType(r2_act, [G.regular_representation]*3)
feat_type
[59]:
[C4_on_R2[(None, 4)]: {regular (x3)}(12)]
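Numerically, the direct sum of 3 regular representations of $C_4$ is a $12 \times 12$ block-diagonal matrix. It permutes the 4 channels inside each of the 3 fields independently. The sketch below uses hand-written helpers, not escnn code, to check that both views agree.

```python
import numpy as np

n_fields, order = 3, 4  # 3 regular feature fields of C_4 -> 12 channels

def regular_rep(g: int) -> np.ndarray:
    return np.roll(np.eye(order), g, axis=0)  # cyclic permutation matrix: (rho(g) u)[h] = u[h - g]

def stacked_rep(g: int) -> np.ndarray:
    # direct sum of 3 copies of the regular representation: a 12 x 12 block-diagonal matrix
    return np.kron(np.eye(n_fields), regular_rep(g))

rng = np.random.default_rng(0)
v = rng.standard_normal(n_fields * order)  # the 12-dimensional feature vector at one pixel
g = 1
# the same action, seen as permuting the 4 channels inside each of the 3 fields independently
per_field = np.roll(v.reshape(n_fields, order), g, axis=1).reshape(-1)
print(stacked_rep(g).shape)                        # (12, 12)
print(np.allclose(stacked_rep(g) @ v, per_field))  # True
```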
Input Features
Each hidden layer of a steerable CNN has its own transformation law which the user needs to specify (equivalent to the choice of number of channels in each layer of a conventional CNN). The input and output of a steerable CNN are also feature fields and their type (i.e. transformation law) is typically determined by the inference task.
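The most common example is that of gray-scale input images. A gray-scale image is rotated by moving each pixel to a new position without changing its intensity value. The invariance of the scalar pixel values under rotations is modeled by the trivial representation $\rho_0: G \to \mathbb{R},\ g \mapsto 1$ of $G$, which identifies them as scalar fields. Formally, a scalar field is a function $f: \mathbb{R}^2 \to \mathbb{R}$, i.e. a feature field with $c = 1$ channel. A rotation $r_\theta \in C_4$ transforms this scalar field as
$$[r_\theta.f](x) := \rho_0(r_\theta)\, f(r_\theta^{-1}x) = 1 \cdot f(r_\theta^{-1}x) = f(r_\theta^{-1}x)$$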
We instantiate the nn.FieldType modeling a gray-scale image by passing it the trivial representation of $G$:
[60]:
feat_type_in = nn.FieldType(r2_act, [G.trivial_representation])
feat_type_in
[60]:
[C4_on_R2[(None, 4)]: {irrep_0 (x1)}(1)]
Equivariant Layers
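When we build a model equivariant to a group $G$, we require that the output produced by the model transforms consistently when the input transforms under the action of an element $g \in G$. For a function $F$ (e.g. a neural network), the equivariance constraint requires:
$$\mathcal{T}^{\text{out}}_g\big[F(x)\big] = F\big(\mathcal{T}^{\text{in}}_g[x]\big) \quad \forall g \in G$$
where $\mathcal{T}^{\text{in}}_g$ is the transformation of the input by the group element $g$, while $\mathcal{T}^{\text{out}}_g$ is the transformation of the output by the same element. The field type feat_type_in we have just defined above precisely describes $\mathcal{T}^{\text{in}}$. The transformation law $\mathcal{T}^{\text{out}}$ of the output of the first layer is chosen in the same way, by defining an instance feat_type_out of nn.FieldType.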
For example, let's use $3$ regular feature fields in the output:
[61]:
feat_type_out = nn.FieldType(r2_act, [G.regular_representation]*3)
As a shortcut, we can also use:
[62]:
feat_type_in = nn.FieldType(r2_act, [r2_act.trivial_repr])
feat_type_out = nn.FieldType(r2_act, [r2_act.regular_repr]*3)
Once we have defined how the input and output feature spaces should transform, we can build neural network functions as equivariant modules. These are implemented as subclasses of the abstract base class nn.EquivariantModule, which itself inherits from torch.nn.Module.
Equivariant Convolution Layer: We start by instantiating a convolutional layer that maps between fields of types feat_type_in and feat_type_out.
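Let $\rho_{\text{in}}: G \to \mathbb{R}^{c_{\text{in}} \times c_{\text{in}}}$ and $\rho_{\text{out}}: G \to \mathbb{R}^{c_{\text{out}} \times c_{\text{out}}}$ be, respectively, the representations of $G$ associated with feat_type_in and feat_type_out. Then, an equivariant convolution layer is a standard convolution layer with a filter $k: \mathbb{R}^2 \to \mathbb{R}^{c_{\text{out}} \times c_{\text{in}}}$ (note the number of input and output channels) which satisfies a particular steerability constraint:
$$\forall g \in G,\ x \in \mathbb{R}^2: \quad k(g.x) = \rho_{\text{out}}(g)\, k(x)\, \rho_{\text{in}}(g)^{-1}.$$
In particular, the use of convolution guarantees translation equivariance, while the fact that the filters satisfy this steerability constraint guarantees $G$-equivariance.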
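The steerability constraint can be checked numerically in a simplified setting. The sketch below is a plain NumPy illustration, not how escnn builds its filters. It assumes $G = C_4$ acting on a square pixel grid through `np.rot90`, a trivial input type (one scalar channel) and one regular output field (4 channels). Under these assumptions, a filter bank whose $c$-th channel is a base filter rotated $c$ times satisfies the constraint, because rotating the bank only cyclically permutes its channels. Cross-correlating with such a bank is then equivariant: rotating the input rotates every output channel and cyclically shifts the channels. The helper names `cross_correlate` and `lifting_filter_bank` are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def cross_correlate(f: np.ndarray, k: np.ndarray) -> np.ndarray:
    """'Valid' 2D cross-correlation of a single-channel image f with a filter bank k of shape (C, s, s)."""
    s = k.shape[-1]
    windows = sliding_window_view(f, (s, s))  # shape (H-s+1, W-s+1, s, s)
    return np.einsum("ijab,cab->cij", windows, k)


def lifting_filter_bank(k0: np.ndarray) -> np.ndarray:
    """Hypothetical helper: channel c holds the base filter rotated by c * 90 degrees."""
    return np.stack([np.rot90(k0, c) for c in range(4)])


rng = np.random.default_rng(0)
k = lifting_filter_bank(rng.standard_normal((3, 3)))
f = rng.standard_normal((9, 9))

# Steerability: rotating every channel of the bank equals a cyclic permutation of the channels,
# i.e. the permutation matrix of the regular representation acting on the output index.
assert np.allclose(np.rot90(k, axes=(1, 2)), np.roll(k, -1, axis=0))

# Equivariance: correlating a rotated image equals rotating each output channel
# and cyclically shifting the channels by one position.
out = cross_correlate(f, k)
out_from_rotated = cross_correlate(np.rot90(f), k)
assert np.allclose(out_from_rotated, np.roll(np.rot90(out, axes=(1, 2)), 1, axis=0))
print("steerability and equivariance checks passed")
```

This is the same kind of trivial-to-regular map that nn.R2Conv(feat_type_in, feat_type_out, ...) learns for $C_4$, only restricted to grid rotations and without escnn's filter basis.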
QUESTION 9
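Show that if a filter $k: \mathbb{R}^2 \to \mathbb{R}^{c_{\text{out}} \times c_{\text{in}}}$ satisfies the constraint above, the convolution with it is equivariant to $G$, i.e. show that
$$f_{\text{out}} = k \star f_{\text{in}} \implies [g.f_{\text{out}}] = k \star [g.f_{\text{in}}]$$
for all $g \in G$. The action on the features $f_{\text{in}}$ and $f_{\text{out}}$ is the one previously defined, i.e.
$$[g.f_{\text{in}}](x) = \rho_{\text{in}}(g)\, f_{\text{in}}(g^{-1}x)$$
and
$$[g.f_{\text{out}}](x) = \rho_{\text{out}}(g)\, f_{\text{out}}(g^{-1}x),$$
while the convolution is defined as
$$f_{\text{out}}(y) = [k \star f_{\text{in}}](y) = \int_{\mathbb{R}^2} k(x - y)\, f_{\text{in}}(x)\, dx.$$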
ANSWER 9
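Note that, because $g \in G$ is a rotation, it is a linear map with $|\det g| = 1$; hence the change of variables $x = gz$ leaves the integration measure unchanged ($dx = dz$). Using this, the linearity of $g$ and the steerability constraint (applied at the point $z - g^{-1}y$):
$$
\begin{aligned}
[k \star [g.f_{\text{in}}]](y)
&= \int_{\mathbb{R}^2} k(x - y)\, \rho_{\text{in}}(g) f_{\text{in}}(g^{-1}x)\, dx \\
&= \int_{\mathbb{R}^2} k(gz - y)\, \rho_{\text{in}}(g) f_{\text{in}}(z)\, dz \\
&= \int_{\mathbb{R}^2} k\big(g(z - g^{-1}y)\big)\, \rho_{\text{in}}(g) f_{\text{in}}(z)\, dz \\
&= \int_{\mathbb{R}^2} \rho_{\text{out}}(g)\, k(z - g^{-1}y)\, \rho_{\text{in}}(g)^{-1} \rho_{\text{in}}(g) f_{\text{in}}(z)\, dz \\
&= \rho_{\text{out}}(g) \int_{\mathbb{R}^2} k(z - g^{-1}y)\, f_{\text{in}}(z)\, dz \\
&= \rho_{\text{out}}(g)\, f_{\text{out}}(g^{-1}y) = [g.f_{\text{out}}](y).
\end{aligned}
$$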
The steerability constraint restricts the space of possible learnable filters to a smaller space of equivariant filters. Solving this constraint goes beyond the scope of this tutorial; fortunately, the nn.R2Conv module takes care of properly parameterizing the filter $k$ such that it satisfies the constraint.
[63]:
conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=3)
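To get a feeling for how much the steerability constraint shrinks the space of filters, consider the simplest case: trivial input and output types under $C_4$ on the pixel grid, where the constraint reduces to $k(g.x) = k(x)$. The sketch below is only an illustration and does not reproduce how nn.R2Conv builds its filter basis. It assembles the matrix of the group-averaging operator $k \mapsto \frac{1}{4}\sum_{r=0}^{3} \mathrm{rot90}^r(k)$ acting on flattened $s \times s$ filters, and reads off the number of free parameters as its rank. The function name `c4_averaging_matrix` is hypothetical.

```python
import numpy as np


def c4_averaging_matrix(s: int) -> np.ndarray:
    """Matrix of the C4 group-averaging operator acting on flattened s x s filters."""
    dim = s * s
    P = np.zeros((dim, dim))
    for e in range(dim):
        k = np.zeros(dim)
        k[e] = 1.0
        k = k.reshape(s, s)
        # average the e-th basis filter over the 4 rotations of the grid
        P[:, e] = (sum(np.rot90(k, r) for r in range(4)) / 4.0).ravel()
    return P


for s in (3, 5):
    P = c4_averaging_matrix(s)
    assert np.allclose(P @ P, P)  # averaging twice changes nothing: P is a projector
    print(f"{s}x{s} filter: {s * s} unconstrained vs {np.linalg.matrix_rank(P)} invariant parameters")
```

For a 3x3 filter only 3 free parameters remain (one each for the centre, the 4 edge pixels and the 4 corners), and 7 remain for a 5x5 filter.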
Each equivariant module has an input and an output type. As a function (.forward()), it requires its inputs to transform according to its input type and is guaranteed to return feature fields associated with its output type. To prevent the user from accidentally feeding an incorrectly transforming input field into an equivariant module, we perform dynamic type checking. In order to do so, we define geometric tensors as data containers. They wrap a PyTorch torch.Tensor to augment it with an instance of FieldType.
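The idea behind this dynamic type check can be sketched in a few lines of plain Python. The classes below (`FieldTypeSketch`, `GeometricTensorSketch`, `TypedModuleSketch`) are hypothetical stand-ins and not escnn's actual implementation. A field type is a callable that wraps an array, and a module asserts that its input carries the expected type before running, then wraps its result with its output type. The error message simply mirrors the one printed by escnn later in this notebook.

```python
from dataclasses import dataclass
from typing import Callable, Tuple

import numpy as np


@dataclass(frozen=True)
class FieldTypeSketch:
    """Hypothetical stand-in for nn.FieldType: just a tuple of representation names."""
    representations: Tuple[str, ...]

    def __call__(self, array: np.ndarray) -> "GeometricTensorSketch":
        return GeometricTensorSketch(array, self)


@dataclass
class GeometricTensorSketch:
    """Hypothetical stand-in for nn.GeometricTensor: raw data plus its field type."""
    tensor: np.ndarray
    type: FieldTypeSketch


class TypedModuleSketch:
    """Applies fn to the raw data, but only after checking the input type."""

    def __init__(self, in_type: FieldTypeSketch, out_type: FieldTypeSketch,
                 fn: Callable[[np.ndarray], np.ndarray]):
        self.in_type, self.out_type, self.fn = in_type, out_type, fn

    def __call__(self, x: GeometricTensorSketch) -> GeometricTensorSketch:
        assert x.type == self.in_type, "Error! the type of the input does not match the input type of this module"
        return self.out_type(self.fn(x.tensor))


scalar = FieldTypeSketch(("trivial",))
regular3 = FieldTypeSketch(("regular",) * 3)
# placeholder computation: only the type bookkeeping matters here
lift = TypedModuleSketch(scalar, regular3, lambda a: np.repeat(a, 12, axis=1))

y = lift(scalar(np.zeros((4, 1, 8, 8))))
print(y.type == regular3, y.tensor.shape)  # True (4, 12, 8, 8)
try:
    lift(y)  # y is a regular field, not a scalar one
except AssertionError as e:
    print(e)
```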
Let's build a few random 32x32 gray-scale images and wrap them into an nn.GeometricTensor:
[64]:
x = torch.randn(4, 1, 32, 32)
# FieldType is a callable object; its call method can be used to wrap PyTorch tensors into GeometricTensors
x = feat_type_in(x)
assert isinstance(x.tensor, torch.Tensor)
assert isinstance(x, nn.GeometricTensor)
As is usual in PyTorch, an image or feature map is stored in a 4-dimensional array of shape BxCxHxW, where B is the batch size, C is the number of channels and H and W are the spatial dimensions.
We can feed a geometric tensor to an equivariant module as we feed normal tensors in PyTorch’s modules:
[65]:
y = conv(x)
We can verify that the output is indeed associated with the output type of the convolutional layer:
[66]:
assert y.type == feat_type_out
Let's check whether the output transforms as described by the output type when the input transforms according to the input type. The $G$-transformation of a geometric tensor can be conveniently performed with nn.GeometricTensor.transform():
[67]:
# for each group element
for g in G.elements:
    # transform the input with the current group element according to the input type
    x_transformed = x.transform(g)
    # feed the transformed input in the convolutional layer
    y_from_x_transformed = conv(x_transformed)
    # the result should be equivalent to rotating the output produced in the
    # previous block according to the output type
    y_transformed_from_x = y.transform(g)
    assert torch.allclose(y_from_x_transformed.tensor, y_transformed_from_x.tensor, atol=1e-5), g
Any network operation is required to be equivariant. escnn.nn provides a wide range of equivariant network modules which guarantee this behavior.
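Non-Linearities: As an example, we will next apply an equivariant nonlinearity to the output feature field of the convolution. Since the regular representations of a finite group $G$ consist of permutation matrices, any pointwise nonlinearity like ReLU is equivariant. Note that this is not the case for many other choices of representations / field types!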
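The claim about pointwise nonlinearities can be checked with a small NumPy sketch (an illustration, not escnn code). A ReLU commutes with the permutation matrix that represents the 90-degree rotation of $C_4$ in the regular representation, but not with the $2 \times 2$ rotation matrix representing the same element in a frequency-1 irrep.

```python
import numpy as np


def relu(v: np.ndarray) -> np.ndarray:
    return np.maximum(v, 0.0)


# regular representation of the 90-degree rotation in C4: a cyclic permutation of 4 channels
P = np.roll(np.eye(4), 1, axis=0)
# frequency-1 irrep of the same rotation: a 2x2 rotation matrix by 90 degrees
R = np.array([[0.0, -1.0], [1.0, 0.0]])

v4 = np.array([0.5, -1.0, 2.0, -0.3])
v2 = np.array([1.0, -1.0])

print(np.allclose(relu(P @ v4), P @ relu(v4)))  # True: permutations commute with ReLU
print(relu(R @ v2), R @ relu(v2))               # [1. 1.] [0. 1.]: rotations do not
```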
We instantiate an escnn.nn.ReLU, which, as an nn.EquivariantModule, needs to be informed about its input type to be able to perform the type checking. Here we are passing feat_type_out, the output type of the equivariant convolution layer, as input type. It is not necessary to pass an output type to the nonlinearity, since here it is determined by its input type.
[68]:
relu = nn.ReLU(feat_type_out)
z = relu(y)
We can verify the equivariance again:
[69]:
# for each group element
for g in G.elements:
    y_transformed = y.transform(g)
    z_from_y_transformed = relu(y_transformed)
    z_transformed_from_y = z.transform(g)
    assert torch.allclose(z_from_y_transformed.tensor, z_transformed_from_y.tensor, atol=1e-5), g
Deeper Models: In deep learning we usually want to stack multiple layers to build a deep model. As long as each layer is equivariant and consecutive layers are compatible, the equivariance property is preserved by induction.
The compatibility of two consecutive layers requires the output type of the first layer to be equal to the input type of the second layer.
If we feed an input with the wrong type to a module, an error is raised:
[70]:
layer1 = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=3)
layer2 = nn.ReLU(feat_type_in) # the input type of the ReLU should be the output type of the convolution
x = feat_type_in(torch.randn(3, 1, 7, 7))
try:
    y = layer2(layer1(x))
except AssertionError as e:
    print(e)
Error! the type of the input does not match the input type of this module
Simple deeper architectures can be built using a SequentialModule:
[71]:
feat_type_in = nn.FieldType(r2_act, [r2_act.trivial_repr])
feat_type_hid = nn.FieldType(r2_act, 8*[r2_act.regular_repr])
feat_type_out = nn.FieldType(r2_act, 2*[r2_act.regular_repr])
model = nn.SequentialModule(
nn.R2Conv(feat_type_in, feat_type_hid, kernel_size=3),
nn.InnerBatchNorm(feat_type_hid),
nn.ReLU(feat_type_hid, inplace=True),
nn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=3),
nn.InnerBatchNorm(feat_type_hid),
nn.ReLU(feat_type_hid, inplace=True),
nn.R2Conv(feat_type_hid, feat_type_out, kernel_size=3),
).eval()
As every layer is equivariant and consecutive layers are compatible, the whole model is equivariant.
[72]:
x = torch.randn(1, 1, 17, 17)
x = feat_type_in(x)
y = model(x)
# for each group element
for g in G.elements:
    x_transformed = x.transform(g)
    y_from_x_transformed = model(x_transformed)
    y_transformed_from_x = y.transform(g)
    assert torch.allclose(y_from_x_transformed.tensor, y_transformed_from_x.tensor, atol=1e-5), g
Invariant Pooling Layer: Usually, at the end of the model we want to produce a single feature vector to use for classification. To do so, it is common to pool over the spatial dimensions, e.g. via average pooling.
This produces (approximately) translation-invariant feature vectors.
[73]:
# average pooling with window size 11
avgpool = nn.PointwiseAvgPool(feat_type_out, 11)
y = avgpool(model(x))
print(y.shape)
torch.Size([1, 8, 1, 1])
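In our case, the feature vectors $f(x)$ associated to each point $x \in \mathbb{R}^2$ have a well-defined transformation law: they transform according to feat_type_out (here two $C_4$ regular fields). A rotation by an element of $C_4$ therefore cyclically permutes the 4 channels within each regular field, as we can see by transforming the pooled output with each group element: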
[74]:
for g in G.elements:
    print(f'rotation by {g}:', y.transform(g).tensor[0, ...].detach().numpy().squeeze())
rotation by 0[2pi/4]: [0.508 0.562 0.566 0.59 0.227 0.227 0.224 0.234]
rotation by 1[2pi/4]: [0.59 0.508 0.562 0.566 0.234 0.227 0.227 0.224]
rotation by 2[2pi/4]: [0.566 0.59 0.508 0.562 0.224 0.234 0.227 0.227]
rotation by 3[2pi/4]: [0.562 0.566 0.59 0.508 0.227 0.224 0.234 0.227]
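The printed rows can be reproduced by hand. In the minimal NumPy sketch below, a rotation by $r \cdot \frac{2\pi}{4}$ acts on each regular field of $C_4$ as a cyclic shift of its 4 channels by $r$ positions, so two fields correspond to a block-diagonal matrix made of two permutation matrices. The vector is copied from the rotation-by-0 row printed above (rounded to 3 decimals), and `regular_c4_action` is a hypothetical helper.

```python
import numpy as np

# pooled output for rotation by 0, copied from the printed output above
y0 = np.array([0.508, 0.562, 0.566, 0.59, 0.227, 0.227, 0.224, 0.234])


def regular_c4_action(r: int, n_fields: int) -> np.ndarray:
    """Block-diagonal matrix: one 4x4 cyclic permutation (by r steps) per regular field."""
    shift = np.roll(np.eye(4), r, axis=0)
    return np.kron(np.eye(n_fields), shift)


for r in range(4):
    print(f"rotation by {r}[2pi/4]:", regular_c4_action(r, n_fields=2) @ y0)
```

Each printed vector matches the corresponding row produced by y.transform(g) above, which is what the regular representation predicts.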
Many learning tasks require building models which are invariant under rotations. We can compute invariant features from the output of the model using an invariant map. For instance, we can take the maximum value within each regular field. We do so using nn.GroupPooling:
[75]:
invariant_map = nn.GroupPooling(feat_type_out)
y = invariant_map(avgpool(model(x)))
for g in G.elements:
    print(f'rotation by {g}:', y.transform(g).tensor[0, ...].detach().numpy().squeeze())
rotation by 0[2pi/4]: [0.59 0.234]
rotation by 1[2pi/4]: [0.59 0.234]
rotation by 2[2pi/4]: [0.59 0.234]
rotation by 3[2pi/4]: [0.59 0.234]
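Why this map is invariant can also be seen with a minimal NumPy sketch (an illustration, not nn.GroupPooling's implementation). Taking the maximum over the 4 channels of each regular field discards their cyclic order, so no rotation in $C_4$ can change the result. The feature vector is copied from the rotation-by-0 row of the earlier pooled output, and `group_pool` is a hypothetical helper that assumes each field occupies a consecutive block of channels.

```python
import numpy as np

y0 = np.array([0.508, 0.562, 0.566, 0.59, 0.227, 0.227, 0.224, 0.234])


def group_pool(y: np.ndarray, field_size: int = 4) -> np.ndarray:
    """Max over the channels of each regular field (consecutive blocks of field_size channels)."""
    return y.reshape(-1, field_size).max(axis=1)


for r in range(4):
    # rotation by r * 2pi/4: cyclic shift of the 4 channels within each field
    y_rot = np.roll(y0.reshape(2, 4), r, axis=1).ravel()
    print(f"rotation by {r}[2pi/4]:", group_pool(y_rot))
```

Every rotation yields the same two values as the nn.GroupPooling output above.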
[76]:
# for each group element
for g in G.elements:
    # rotate the input image
    x_transformed = x.transform(g)
    y_from_x_transformed = invariant_map(avgpool(model(x_transformed)))
    y_transformed_from_x = y  # no .transform(g) needed since y should be invariant!
    # check that the output did not change
    # note that here we are not rotating the original output y as before
    assert torch.allclose(y_from_x_transformed.tensor, y_transformed_from_x.tensor, atol=1e-6), g
2.3 Steerable CNN with infinite group $G$
We can now repeat the same constructions with $G$ being an infinite group, e.g. the group of all planar rotations $G = SO(2)$.
[77]:
# use N=-1 to indicate all rotations
r2_act = gspaces.rot2dOnR2(N=-1)
r2_act
[77]:
SO(2)_on_R2[(None, -1)]
[78]:
G = r2_act.fibergroup
G
[78]:
SO(2)
[79]:
# For simplicity we take a single-channel gray-scale image in input and we output a single-channel gray-scale image, i.e. we use scalar fields in input and output
feat_type_in = nn.FieldType(r2_act, [G.trivial_representation])
feat_type_out = nn.FieldType(r2_act, [G.trivial_representation])
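As intermediate feature types, we want to use the regular representation again. Because $G$ has an infinite number of elements, its regular representation is infinite-dimensional, so we use the Fourier transform idea introduced earlier: we only keep the Fourier coefficients up to a maximum frequency, i.e. we use the band-limited regular representation, here up to frequency 2: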
rho = G.bl_regular_representation(2)
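A minimal NumPy sketch of this idea, under the assumption that a real band-limited function on $SO(2)$ is stored as the coefficients $[a_0, a_1, b_1, \dots, a_L, b_L]$ of $f(\phi) = a_0 + \sum_{k=1}^{L} a_k \cos(k\phi) + b_k \sin(k\phi)$. This coefficient layout is an illustrative choice, not necessarily the one escnn uses internally. Rotating the function by $\theta$, i.e. $f(\phi) \mapsto f(\phi - \theta)$, acts on the coefficients through a block-diagonal matrix: $1$ for frequency 0 and a $2 \times 2$ rotation by $k\theta$ for frequency $k$. That matrix is a finite-dimensional representation of $SO(2)$, which is why a band-limited regular representation is usable for an infinite group. The helper names are hypothetical.

```python
import numpy as np


def synthesize(coeffs: np.ndarray, phis: np.ndarray) -> np.ndarray:
    """Evaluate f(phi) = a0 + sum_k a_k cos(k phi) + b_k sin(k phi) from [a0, a1, b1, ..., aL, bL]."""
    L = (len(coeffs) - 1) // 2
    f = np.full(phis.shape, coeffs[0], dtype=float)
    for k in range(1, L + 1):
        f += coeffs[2 * k - 1] * np.cos(k * phis) + coeffs[2 * k] * np.sin(k * phis)
    return f


def bl_regular_matrix(theta: float, L: int) -> np.ndarray:
    """Block-diagonal action of a rotation by theta on the coefficients up to frequency L."""
    D = np.zeros((2 * L + 1, 2 * L + 1))
    D[0, 0] = 1.0
    for k in range(1, L + 1):
        c, s = np.cos(k * theta), np.sin(k * theta)
        D[2 * k - 1:2 * k + 1, 2 * k - 1:2 * k + 1] = [[c, -s], [s, c]]
    return D


rng = np.random.default_rng(0)
L, theta = 2, 0.7  # frequency cut-off and an arbitrary rotation angle
coeffs = rng.standard_normal(2 * L + 1)
phis = np.linspace(0, 2 * np.pi, 50, endpoint=False)

# rotating the function == acting on its coefficients with the block-diagonal matrix
assert np.allclose(synthesize(coeffs, phis - theta), synthesize(bl_regular_matrix(theta, L) @ coeffs, phis))
# the matrices compose like rotations, i.e. they form a representation of SO(2)
assert np.allclose(bl_regular_matrix(0.3, L) @ bl_regular_matrix(0.4, L), bl_regular_matrix(0.7, L))
```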
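To apply a non-linearity, e.g. ELU, we can use the Inverse Fourier Transform to sample the function, apply the non-linearity and, finally, compute the Fourier Transform to recover the coefficients. Because $G$ has infinitely many elements, the Fourier Transform requires an integral over $G$, which can be approximated by a sum over a finite number of samples. The more samples one takes, the better the approximation, although this also increases the computational cost.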
Fortunately, the class nn.FourierELU takes care of most of these details. We can just specify which irreps to consider (G.bl_irreps(2) returns the list of irreps up to frequency 2), the number of channels (i.e. copies of the regular representation) and the number N of elements of $G$ on which to sample the function:
[80]:
nonlinearity = nn.FourierELU(r2_act, 16, irreps=G.bl_irreps(2), N=12)
# we do not need to pre-define the feature type: FourierELU will create it internally and we can just access it as
feat_type_hid = nonlinearity.in_type
# note also that its input and output types are the same
assert nonlinearity.in_type == nonlinearity.out_type
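The sample, nonlinearity, Fourier transform pipeline can be sketched in NumPy for a single band-limited function on $SO(2)$. This is an illustration under stated assumptions, not nn.FourierELU's implementation. Coefficients are stored as $[a_0, a_1, b_1, \dots, a_L, b_L]$, the function is sampled on $N$ equally spaced angles, and the Fourier transform is approximated by a least-squares fit of the activated samples. In this sketch, the resulting map commutes exactly with rotations by multiples of $2\pi/N$, which only permute the samples, but only approximately with other rotations. All function names are hypothetical.

```python
import numpy as np


def fourier_basis(phis: np.ndarray, L: int) -> np.ndarray:
    """Columns: 1, cos(phi), sin(phi), ..., cos(L phi), sin(L phi)."""
    cols = [np.ones_like(phis)]
    for k in range(1, L + 1):
        cols += [np.cos(k * phis), np.sin(k * phis)]
    return np.stack(cols, axis=1)


def rotation_matrix(theta: float, L: int) -> np.ndarray:
    """Action of a rotation by theta on the coefficients [a0, a1, b1, ..., aL, bL]."""
    D = np.eye(2 * L + 1)
    for k in range(1, L + 1):
        c, s = np.cos(k * theta), np.sin(k * theta)
        D[2 * k - 1:2 * k + 1, 2 * k - 1:2 * k + 1] = [[c, -s], [s, c]]
    return D


def fourier_elu_sketch(coeffs: np.ndarray, L: int, N: int) -> np.ndarray:
    A = fourier_basis(2 * np.pi * np.arange(N) / N, L)  # inverse FT: sample on N points
    samples = A @ coeffs
    activated = np.where(samples > 0, samples, np.expm1(samples))  # pointwise ELU
    return np.linalg.lstsq(A, activated, rcond=None)[0]  # back to Fourier coefficients


L, N = 2, 12
coeffs = np.array([0.0, 1.0, 0.5, 0.3, -0.2])  # zero mean, so the ELU acts non-linearly somewhere


def equivariance_error(theta: float) -> float:
    lhs = fourier_elu_sketch(rotation_matrix(theta, L) @ coeffs, L, N)
    rhs = rotation_matrix(theta, L) @ fourier_elu_sketch(coeffs, L, N)
    return float(np.linalg.norm(lhs - rhs))


print(equivariance_error(2 * np.pi / N))  # ~0 (floating point): the rotation permutes the samples
print(equivariance_error(0.1))            # small but non-zero: the ELU creates frequencies that alias
```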
Let's build a simple $G = SO(2)$ equivariant model:
[81]:
equivariant_so2_model = nn.SequentialModule(
nn.R2Conv(feat_type_in, feat_type_hid, kernel_size=7),
nn.IIDBatchNorm2d(feat_type_hid),
nonlinearity,
nn.R2Conv(feat_type_hid, feat_type_hid, kernel_size=7),
nn.IIDBatchNorm2d(feat_type_hid),
nonlinearity,
nn.R2Conv(feat_type_hid, feat_type_out, kernel_size=7),
).eval()
and check its equivariance to a few elements of $SO(2)$:
[82]:
x = torch.randn(1, 1, 23, 23)
x = feat_type_in(x)
y = equivariant_so2_model(x)
# check equivariance to N=16 rotations
N = 16
try:
    for i in range(N):
        g = G.element(i*2*np.pi/N)
        x_transformed = x.transform(g)
        y_from_x_transformed = equivariant_so2_model(x_transformed)
        y_transformed_from_x = y.transform(g)
        assert torch.allclose(y_from_x_transformed.tensor, y_transformed_from_x.tensor, atol=1e-3), g
except AssertionError:
    print('Error! The model is not equivariant!')
Error! The model is not equivariant!
QUESTION 10
The model is not perfectly equivariant to $G = SO(2)$! Why is this the case?
ANSWER 10
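The $SO(2)$ group contains all continuous planar rotations, but the images (and the filters) are sampled on a discrete pixel grid. Only rotations by multiples of $\frac{\pi}{2}$ map the grid onto itself; any other rotation requires some form of interpolation, which introduces discretization errors both in the transformed inputs and in the rotated filters appearing in the steerability constraint. Moreover, nn.FourierELU samples the features on only $N=12$ elements of $SO(2)$, so the Fourier transform it computes, and hence the equivariance of the nonlinearity, is only approximate for rotations that are not multiples of $\frac{2\pi}{12}$.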
While the model can not be perfectly equivariant, we can compare it with a conventional CNN baseline. Let’s build a CNN similar to our equivariant model but which is not constrained to be equivariant:
[83]:
conventional_model = torch.nn.Sequential(
torch.nn.Conv2d(feat_type_in.size, feat_type_hid.size, kernel_size=7),
torch.nn.BatchNorm2d(feat_type_hid.size),
torch.nn.ELU(),
torch.nn.Conv2d(feat_type_hid.size, feat_type_hid.size, kernel_size=7),
torch.nn.BatchNorm2d(feat_type_hid.size),
torch.nn.ELU(),
torch.nn.Conv2d(feat_type_hid.size, feat_type_out.size, kernel_size=7),
).eval()
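To compare the two models, we compute their equivariance error for a few elements of $G$. We define the equivariance error of a model $F$ with respect to a group element $g \in G$ and an input $x$ as
$$\epsilon_g(F) = \frac{\|F(g.x) - g.F(x)\|_2}{\|F(x)\|_2}.$$
Note that this is a form of relative error. Let's now compute the equivariance error of the two models: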
[84]:
# let's generate a random image of shape W x W
W = 37
x = torch.randn(1, 1, W, W)
# Because a rotation by an angle that is not a multiple of 90 degrees moves pixels outside the image, we mask out all pixels outside the central disk
# We need to do this both for the input and the output
def build_mask(W):
    center_mask = np.zeros((2, W, W))
    center_mask[1, :, :] = np.arange(0, W) - W // 2
    center_mask[0, :, :] = np.arange(0, W) - W // 2
    center_mask[0, :, :] = center_mask[0, :, :].T
    center_mask = center_mask[0, :, :] ** 2 + center_mask[1, :, :] ** 2 < .9*(W // 2) ** 2
    center_mask = torch.tensor(center_mask.reshape(1, 1, W, W), dtype=torch.float)
    return center_mask

# create the mask for the input
input_center_mask = build_mask(W)
# mask the input image
x = x * input_center_mask
x = feat_type_in(x)
# compute the output of both models
y_equivariant = equivariant_so2_model(x)
y_conventional = feat_type_out(conventional_model(x.tensor))
# create the mask for the output images
output_center_mask = build_mask(y_equivariant.shape[-1])
# We evaluate the equivariance error on N=100 rotations
N = 100
error_equivariant = []
error_conventional = []
# for each of the N+1 angles in [0, 2pi] (both endpoints included)
for i in range(N+1):
    g = G.element(i / N * 2*np.pi)
    # rotate the input
    x_transformed = x.transform(g)
    x_transformed.tensor *= input_center_mask
    # F(g.X) feed the transformed images in both models
    y_from_x_transformed_equivariant = equivariant_so2_model(x_transformed).tensor
    y_from_x_transformed_conventional = conventional_model(x_transformed.tensor)
    # g.F(x) transform the output of both models
    y_transformed_from_x_equivariant = y_equivariant.transform(g)
    y_transformed_from_x_conventional = y_conventional.transform(g)
    # mask all the outputs
    y_from_x_transformed_equivariant = y_from_x_transformed_equivariant * output_center_mask
    y_from_x_transformed_conventional = y_from_x_transformed_conventional * output_center_mask
    y_transformed_from_x_equivariant = y_transformed_from_x_equivariant.tensor * output_center_mask
    y_transformed_from_x_conventional = y_transformed_from_x_conventional.tensor * output_center_mask
    # compute the relative error of both models
    rel_error_equivariant = torch.norm(y_from_x_transformed_equivariant - y_transformed_from_x_equivariant).item() / torch.norm(y_equivariant.tensor).item()
    rel_error_conventional = torch.norm(y_from_x_transformed_conventional - y_transformed_from_x_conventional).item() / torch.norm(y_conventional.tensor).item()
    error_equivariant.append(rel_error_equivariant)
    error_conventional.append(rel_error_conventional)
# plot the error of both models as a function of the rotation angle theta
fig, ax = plt.subplots(figsize=(10, 6))
xs = [i*2*np.pi / N for i in range(N+1)]
plt.plot(xs, error_equivariant, label='SO(2)-Steerable CNN')
plt.plot(xs, error_conventional, label='Conventional CNN')
plt.title('Equivariant vs Conventional CNNs', fontsize=20)
plt.xlabel(r'$g = r_\theta$', fontsize=20)
plt.ylabel('Equivariance Error', fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.legend(fontsize=20)
plt.show()
3. Build and Train Steerable CNNs
Finally, we will proceed with implementing a Steerable CNN and train it on rotated MNIST.
Dataset
We will evaluate the model on the rotated MNIST dataset. First, we download the (non-rotated) MNIST 12k data:
[85]:
# download the dataset
!wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip
# uncompress the zip file
!unzip -n mnist.zip -d mnist
File ‘mnist.zip’ already there; not retrieving.
/bin/bash: unzip: command not found
Then, we build the dataset and some utility functions:
[86]:
from torch.utils.data import Dataset
from torchvision.transforms import RandomRotation
from torchvision.transforms import Pad
from torchvision.transforms import Resize
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from tqdm.auto import tqdm
from PIL import Image
device = 'cuda' if torch.cuda.is_available() else 'cpu'
[87]:
class MnistDataset(Dataset):

    def __init__(self, mode, rotated: bool = True):
        assert mode in ['train', 'test']
        if mode == "train":
            file = "mnist/mnist_train.amat"
        else:
            file = "mnist/mnist_test.amat"
        data = np.loadtxt(file)
        images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        # images are padded to have shape 29x29.
        # this allows to use odd-size filters with stride 2 when downsampling a feature map in the model
        pad = Pad((0, 0, 1, 1), fill=0)
        # to reduce interpolation artifacts (e.g. when testing the model on rotated images),
        # we upsample an image by a factor of 3, rotate it and finally downsample it again
        resize1 = Resize(87)  # to upsample
        resize2 = Resize(29)  # to downsample
        totensor = ToTensor()
        if rotated:
            self.images = torch.empty((images.shape[0], 1, 29, 29))
            for i in tqdm(range(images.shape[0]), leave=False):
                img = images[i]
                img = Image.fromarray(img, mode='F')
                r = (np.random.rand() * 360.)
                self.images[i] = totensor(resize2(resize1(pad(img)).rotate(r, Image.BILINEAR))).reshape(1, 29, 29)
        else:
            self.images = torch.zeros((images.shape[0], 1, 29, 29))
            self.images[:, :, :28, :28] = torch.tensor(images).reshape(-1, 1, 28, 28)
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        return image, label

    def __len__(self):
        return len(self.labels)
[88]:
# Set the random seed for reproducibility
np.random.seed(42)
# build the rotated training and test datasets
mnist_train = MnistDataset(mode='train', rotated=True)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64)
mnist_test = MnistDataset(mode='test', rotated=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64)
# for testing purpose, we also build a version of the test set with *non*-rotated digits
raw_mnist_test = MnistDataset(mode='test', rotated=False)
$SO(2)$ equivariant architecture
We now build an $SO(2)$ equivariant CNN.
Because the inputs are still gray-scale images, the input type of the model is again a scalar field. In the intermediate layers, we will use regular fields, such that the models are equivalent to group-equivariant convolutional neural networks (GCNNs).
The final classification is performed by a fully connected layer.
[89]:
class SO2SteerableCNN(torch.nn.Module):

    def __init__(self, n_classes=10):
        super(SO2SteerableCNN, self).__init__()
        # the model is equivariant under all planar rotations
        self.r2_act = gspaces.rot2dOnR2(N=-1)
        # the group SO(2)
        self.G: SO2 = self.r2_act.fibergroup
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        # We need to mask the input image since the corners are moved outside the grid under rotations
        self.mask = nn.MaskModule(in_type, 29, margin=1)

        # convolution 1
        # first we build the non-linear layer, which also constructs the right feature type
        # we choose 8 feature fields, each transforming under the regular representation of SO(2) up to frequency 3
        # When taking the ELU non-linearity, we sample the feature fields on N=16 points
        activation1 = nn.FourierELU(self.r2_act, 8, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation1.in_type
        self.block1 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation1,
        )

        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer are 16 regular feature fields
        activation2 = nn.FourierELU(self.r2_act, 16, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation2.in_type
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation2
        )
        # to reduce the downsampling artifacts, we use a Gaussian smoothing filter
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )

        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer are 32 regular feature fields
        activation3 = nn.FourierELU(self.r2_act, 32, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation3.in_type
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation3
        )

        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer are 32 regular feature fields
        activation4 = nn.FourierELU(self.r2_act, 32, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation4.in_type
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation4
        )
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )

        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer are 64 regular feature fields
        activation5 = nn.FourierELU(self.r2_act, 64, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation5.in_type
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation5
        )

        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer are 64 regular feature fields
        activation6 = nn.FourierELU(self.r2_act, 64, irreps=self.G.bl_irreps(3), N=16, inplace=True)
        out_type = activation6.in_type
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation6
        )
        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)

        # number of output invariant channels
        c = 64
        # last 1x1 convolution layer, which maps the regular fields to c=64 invariant scalar fields
        # this is essential to provide *invariant* features in the final classification layer
        output_invariant_type = nn.FieldType(self.r2_act, c*[self.r2_act.trivial_repr])
        self.invariant_map = nn.R2Conv(out_type, output_invariant_type, kernel_size=1, bias=False)

        # Fully Connected classifier
        self.fully_net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(c),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(c, n_classes),
        )

    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = self.input_type(input)
        # mask out the corners of the input image
        x = self.mask(x)
        # apply each equivariant block
        # Each layer has an input and an output type
        # A layer takes a GeometricTensor in input.
        # This tensor needs to be associated with the same representation of the layer's input type
        #
        # Each layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        x = self.block5(x)
        x = self.block6(x)
        # pool over the spatial dimensions
        x = self.pool3(x)
        # extract invariant features
        x = self.invariant_map(x)
        # unwrap the output GeometricTensor
        # (take the Pytorch tensor and discard the associated representation)
        x = x.tensor
        # classify with the final fully connected layer
        x = self.fully_net(x.reshape(x.shape[0], -1))
        return x
Equivariance Test before training
Let’s instantiate the model:
[90]:
model = SO2SteerableCNN().to(device)
The model is now randomly initialized. Therefore, we do not expect it to produce the right class probabilities.
However, the model should still produce the same output for rotated versions of the same image. This is true for rotations by multiples of $\frac{\pi}{2}$, but is only approximate for rotations by other angles.
Let's test it on a test image: we feed $N=20$ rotated versions of the first image in the test set and print the output logits of the model for each of them.
[91]:
def test_model_single_image(model: torch.nn.Module, x: torch.Tensor, N: int = 8):
    np.set_printoptions(linewidth=10000)
    x = Image.fromarray(x.cpu().numpy()[0], mode='F')
    # to reduce interpolation artifacts (e.g. when testing the model on rotated images),
    # we upsample an image by a factor of 3, rotate it and finally downsample it again
    resize1 = Resize(87)  # to upsample
    resize2 = Resize(29)  # to downsample
    totensor = ToTensor()
    x = resize1(x)
    # evaluate the `model` on N rotated versions of the input image `x`
    model.eval()
    print()
    print('##########################################################################################')
    header = 'angle | ' + ' '.join(["{:5d}".format(d) for d in range(10)])
    print(header)
    with torch.no_grad():
        for r in range(N):
            x_transformed = totensor(resize2(x.rotate(r*360./N, Image.BILINEAR))).reshape(1, 1, 29, 29)
            x_transformed = x_transformed.to(device)
            y = model(x_transformed)
            y = y.to('cpu').numpy().squeeze()
            angle = r * 360. / N
            print("{:6.1f} : {}".format(angle, y))
    print('##########################################################################################')
    print()
[92]:
# retrieve the first image from the test set
x, y = next(iter(raw_mnist_test))
# evaluate the model
test_model_single_image(model, x, N=20)
##########################################################################################
angle | 0 1 2 3 4 5 6 7 8 9
0.0 : [ 0.106 1.08 -1.623 -0.825 1.574 -0.265 -0.12 1.242 0.219 1.639]
18.0 : [ 0.093 1.087 -1.643 -0.828 1.574 -0.27 -0.117 1.255 0.213 1.631]
36.0 : [ 0.094 1.072 -1.632 -0.833 1.562 -0.272 -0.121 1.257 0.194 1.602]
54.0 : [ 0.091 1.068 -1.615 -0.833 1.568 -0.262 -0.131 1.236 0.209 1.59 ]
72.0 : [ 0.108 1.081 -1.623 -0.829 1.573 -0.261 -0.129 1.227 0.231 1.628]
90.0 : [ 0.106 1.08 -1.623 -0.825 1.574 -0.265 -0.12 1.242 0.219 1.639]
108.0 : [ 0.093 1.087 -1.643 -0.828 1.574 -0.27 -0.117 1.255 0.213 1.631]
126.0 : [ 0.094 1.072 -1.632 -0.833 1.562 -0.272 -0.121 1.257 0.194 1.602]
144.0 : [ 0.091 1.068 -1.615 -0.833 1.568 -0.262 -0.131 1.236 0.209 1.59 ]
162.0 : [ 0.108 1.081 -1.623 -0.829 1.573 -0.261 -0.129 1.227 0.231 1.628]
180.0 : [ 0.106 1.08 -1.623 -0.825 1.574 -0.265 -0.12 1.242 0.219 1.639]
198.0 : [ 0.093 1.087 -1.643 -0.828 1.574 -0.27 -0.117 1.255 0.213 1.631]
216.0 : [ 0.094 1.072 -1.632 -0.833 1.562 -0.272 -0.121 1.257 0.194 1.602]
234.0 : [ 0.091 1.068 -1.615 -0.833 1.568 -0.262 -0.131 1.236 0.209 1.59 ]
252.0 : [ 0.108 1.081 -1.623 -0.829 1.573 -0.261 -0.129 1.227 0.231 1.628]
270.0 : [ 0.106 1.08 -1.623 -0.825 1.574 -0.265 -0.12 1.242 0.219 1.639]
288.0 : [ 0.093 1.087 -1.643 -0.828 1.574 -0.27 -0.117 1.255 0.213 1.631]
306.0 : [ 0.094 1.072 -1.632 -0.833 1.562 -0.272 -0.121 1.257 0.194 1.602]
324.0 : [ 0.091 1.068 -1.615 -0.833 1.568 -0.262 -0.131 1.236 0.209 1.59 ]
342.0 : [ 0.108 1.081 -1.623 -0.829 1.573 -0.261 -0.129 1.227 0.231 1.628]
##########################################################################################
The output of the model is already almost invariant but we observe small fluctuations in the outputs. This is the effect of the discretization artifacts (e.g. the pixel grid can not be perfectly rotated by any angle without interpolation) and can not be completely removed.
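The discretization effect can be illustrated with a minimal NumPy sketch. It assumes nearest-neighbour resampling, a simpler stand-in for the bilinear interpolation used by PIL's rotate in the cells above, and `rotate_nearest` is a hypothetical helper. A rotation by $\frac{\pi}{2}$ only permutes pixels, while a rotation by $\frac{\pi}{4}$ followed by its inverse loses information (for example, the corners of the image).

```python
import numpy as np


def rotate_nearest(img: np.ndarray, theta: float) -> np.ndarray:
    """Rotate a square, odd-sized image counter-clockwise by theta (radians) around its centre,
    using nearest-neighbour resampling; pixels that come from outside the grid are set to 0."""
    n = img.shape[0]
    c = n // 2
    rows, cols = np.mgrid[0:n, 0:n]
    x, y = cols - c, c - rows  # Cartesian coordinates of each target pixel (y pointing up)
    # inverse rotation: look up where each target pixel comes from
    x_src = np.cos(theta) * x + np.sin(theta) * y
    y_src = -np.sin(theta) * x + np.cos(theta) * y
    src_rows = np.rint(c - y_src).astype(int)
    src_cols = np.rint(x_src + c).astype(int)
    valid = (src_rows >= 0) & (src_rows < n) & (src_cols >= 0) & (src_cols < n)
    out = np.zeros_like(img)
    out[valid] = img[src_rows[valid], src_cols[valid]]
    return out


rng = np.random.default_rng(0)
img = rng.random((29, 29))

# a rotation by 90 degrees maps the grid onto itself: no information is lost
print(np.array_equal(rotate_nearest(img, np.pi / 2), np.rot90(img)))  # True
# a rotation by 45 degrees and back does not recover the image
round_trip = rotate_nearest(rotate_nearest(img, np.pi / 4), -np.pi / 4)
print(np.abs(round_trip - img).max() > 0)  # True
```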
Training the model
Let's train the model now. The procedure is the same as for a normal PyTorch architecture:
[93]:
# build the training and test function
def test(model: torch.nn.Module):
    # test over the full rotated test set
    total = 0
    correct = 0
    with torch.no_grad():
        model.eval()
        for i, (x, t) in enumerate(test_loader):
            x = x.to(device)
            t = t.to(device)
            y = model(x)
            _, prediction = torch.max(y.data, 1)
            total += t.shape[0]
            correct += (prediction == t).sum().item()
    return correct/total*100.


def train(model: torch.nn.Module, lr=1e-4, wd=1e-4, checkpoint_path: str = None):
    if checkpoint_path is not None:
        checkpoint_path = os.path.join(CHECKPOINT_PATH, checkpoint_path)
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        # map_location allows loading a checkpoint saved on GPU on a CPU-only machine
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        model.eval()
        return
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    for epoch in tqdm(range(21)):
        model.train()
        for i, (x, t) in enumerate(train_loader):
            optimizer.zero_grad()
            x = x.to(device)
            t = t.to(device)
            y = model(x)
            loss = loss_function(y, t)
            loss.backward()
            optimizer.step()
            del x, y, t, loss
        if epoch % 10 == 0:
            accuracy = test(model)
            print(f"epoch {epoch} | test accuracy: {accuracy}")
    if checkpoint_path is not None:
        torch.save(model.state_dict(), checkpoint_path)
Finally, train the $SO(2)$ equivariant model:
[94]:
# set the seed manually for reproducibility
torch.manual_seed(42)
model = SO2SteerableCNN().to(device)
train(model, checkpoint_path="steerable_so2-pretrained.ckpt")
accuracy = test(model)
print(f"Test accuracy: {accuracy}")
Test accuracy: 94.98400000000001
[95]:
def test_model_rotations(model: torch.nn.Module, N: int = 24, M: int = 2000, checkpoint_path: str = None):
    # evaluate the `model` on N rotated versions of the first M images in the test set
    if checkpoint_path is not None:
        checkpoint_path = os.path.join(CHECKPOINT_PATH, checkpoint_path)
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        accuracies = np.load(checkpoint_path)
        return accuracies.tolist()
    model.eval()
    # to reduce interpolation artifacts (e.g. when testing the model on rotated images),
    # we upsample an image by a factor of 3, rotate it and finally downsample it again
    resize1 = Resize(87)  # to upsample
    resize2 = Resize(29)  # to downsample
    totensor = ToTensor()
    accuracies = []
    with torch.no_grad():
        model.eval()
        for r in tqdm(range(N)):
            total = 0
            correct = 0
            for i in range(M):
                x, t = raw_mnist_test[i]
                x = Image.fromarray(x.numpy()[0], mode='F')
                x = totensor(resize2(resize1(x).rotate(r*360./N, Image.BILINEAR))).reshape(1, 1, 29, 29).to(device)
                y = model(x)
                _, prediction = torch.max(y.data, 1)
                total += 1
                correct += (prediction == t).sum().item()
            accuracies.append(correct/total*100.)
    if checkpoint_path is not None:
        np.save(checkpoint_path, np.array(accuracies))
    return accuracies
[96]:
accs_so2 = test_model_rotations(model, 16, 10000, checkpoint_path="steerable_so2-accuracies.npy")
[97]:
# plot the accuracy as a function of the rotation angle theta applied to the test set
fig, ax = plt.subplots(figsize=(10, 6))
N = 16
xs = [i*2*np.pi / N for i in range(N+1)]
plt.plot(xs, accs_so2 + [accs_so2[0]])
plt.title('SO(2)-Steerable CNN', fontsize=20)
plt.xlabel(r'Test rotation $\theta \in [0, 2\pi)$', fontsize=20)
plt.ylabel('Accuracy', fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.show()
Even after training, the model is not perfectly $SO(2)$-equivariant: its accuracy still varies with the rotation angle $\theta$ of the test images.
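One source of this residual dependence is the pixel grid itself. A rotation by a multiple of $\pi/2$ maps pixel centres onto pixel centres, so it only permutes pixels. Any other angle moves them off the grid, so the rotated image has to be interpolated; this is also why `test_model_rotations` upsamples before rotating. The numpy sketch below checks this on the coordinates of a small grid. `rotate_grid_points` and `stays_on_grid` are illustrative helpers, not part of the tutorial code.

```python
import numpy as np


def rotate_grid_points(coords: np.ndarray, theta: float) -> np.ndarray:
    """Rotate 2D points (one per row) by the angle theta (radians) about the origin."""
    c, s = np.cos(theta), np.sin(theta)
    rotation = np.array([[c, -s], [s, c]])
    return coords @ rotation.T


def stays_on_grid(theta: float, radius: int = 3, tol: float = 1e-9) -> bool:
    """Check whether every integer point of a (2*radius+1)^2 grid lands on an integer point after rotation."""
    xs = np.arange(-radius, radius + 1)
    coords = np.array([(x, y) for x in xs for y in xs], dtype=float)
    rotated = rotate_grid_points(coords, theta)
    return bool(np.all(np.abs(rotated - np.round(rotated)) < tol))


# multiples of 90 degrees permute the grid exactly; other angles require interpolation
for degrees in (90, 180, 270, 45, 30):
    print(degrees, stays_on_grid(np.deg2rad(degrees)))
```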
$C_4$ equivariant architecture
For comparison, let’s build a similar architecture equivariant only to the $N=4$ rotations by multiples of $\pi/2$, i.e. to the cyclic group $C_4$:
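Each `regular_repr` feature field of $C_4$ used in this architecture has 4 channels, one per group element (the rotations by $0, \pi/2, \pi, 3\pi/2$). Rotating the input cyclically permutes these channels. The numpy sketch below builds these permutation matrices explicitly. `regular_rep` is an illustrative helper, and escnn may order the channels or choose the basis differently.

```python
import numpy as np

N = 4  # order of the cyclic group C_4


def regular_rep(k: int, n: int = N) -> np.ndarray:
    """Permutation matrix of the rotation by k * 2*pi/n in the regular representation of C_n.

    Column j has a single 1 in row (j + k) mod n, so the matrix cyclically shifts
    a vector whose entries are indexed by the group elements."""
    rho = np.zeros((n, n))
    for j in range(n):
        rho[(j + k) % n, j] = 1.0
    return rho


# the representation is a group homomorphism: rho(a) rho(b) = rho((a + b) mod N)
for a in range(N):
    for b in range(N):
        assert np.allclose(regular_rep(a) @ regular_rep(b), regular_rep((a + b) % N))

# acting on a regular feature field = cyclically shifting its N channels
f = np.array([0.3, -1.2, 2.0, 0.5])
assert np.allclose(regular_rep(1) @ f, np.roll(f, 1))

# the average over the N channels does not change under any rotation in C_N
assert all(np.isclose((regular_rep(k) @ f).mean(), f.mean()) for k in range(N))
```

The last check shows why averaging over the channels of a regular field yields a rotation-invariant (trivial) feature.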
class CNSteerableCNN(torch.nn.Module):
    def __init__(self, n_classes=10):
        super(CNSteerableCNN, self).__init__()
        # the model is equivariant to rotations by multiples of 2pi/N
        self.r2_act = gspaces.rot2dOnR2(N=4)
        # the input image is a scalar field, corresponding to the trivial representation
        in_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])
        # we store the input type for wrapping the images into a geometric tensor during the forward pass
        self.input_type = in_type
        # we need to mask the input image since the corners are moved outside the grid under rotations
        self.mask = nn.MaskModule(in_type, 29, margin=1)
        # convolution 1
        # first we build the non-linear layer, which also constructs the right feature type
        # we choose 8 feature fields, each transforming under the regular representation of C_4
        activation1 = nn.ELU(nn.FieldType(self.r2_act, 8*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation1.in_type
        self.block1 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=7, padding=1, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation1,
        )
        # convolution 2
        # the old output type is the input type to the next layer
        in_type = self.block1.out_type
        # the output type of the second convolution layer consists of 16 regular feature fields
        activation2 = nn.ELU(nn.FieldType(self.r2_act, 16*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation2.in_type
        self.block2 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation2
        )
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        # convolution 3
        # the old output type is the input type to the next layer
        in_type = self.block2.out_type
        # the output type of the third convolution layer consists of 32 regular feature fields
        activation3 = nn.ELU(nn.FieldType(self.r2_act, 32*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation3.in_type
        self.block3 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation3
        )
        # convolution 4
        # the old output type is the input type to the next layer
        in_type = self.block3.out_type
        # the output type of the fourth convolution layer consists of 32 regular feature fields
        activation4 = nn.ELU(nn.FieldType(self.r2_act, 32*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation4.in_type
        self.block4 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation4
        )
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=2)
        )
        # convolution 5
        # the old output type is the input type to the next layer
        in_type = self.block4.out_type
        # the output type of the fifth convolution layer consists of 64 regular feature fields
        activation5 = nn.ELU(nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation5.in_type
        self.block5 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=2, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation5
        )
        # convolution 6
        # the old output type is the input type to the next layer
        in_type = self.block5.out_type
        # the output type of the sixth convolution layer consists of 64 regular feature fields
        activation6 = nn.ELU(nn.FieldType(self.r2_act, 64*[self.r2_act.regular_repr]), inplace=True)
        out_type = activation6.in_type
        self.block6 = nn.SequentialModule(
            nn.R2Conv(in_type, out_type, kernel_size=5, padding=1, bias=False),
            nn.IIDBatchNorm2d(out_type),
            activation6
        )
        self.pool3 = nn.PointwiseAvgPoolAntialiased(out_type, sigma=0.66, stride=1, padding=0)
        # number of output invariant channels
        c = 64
        output_invariant_type = nn.FieldType(self.r2_act, c*[self.r2_act.trivial_repr])
        self.invariant_map = nn.R2Conv(out_type, output_invariant_type, kernel_size=1, bias=False)
        # fully connected classifier
        self.fully_net = torch.nn.Sequential(
            torch.nn.BatchNorm1d(c),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(c, n_classes),
        )

    def forward(self, input: torch.Tensor):
        # wrap the input tensor in a GeometricTensor
        # (associate it with the input type)
        x = self.input_type(input)
        # mask out the corners of the input image
        x = self.mask(x)
        # apply each equivariant block
        # Each layer has an input and an output type.
        # A layer takes a GeometricTensor as input.
        # This tensor needs to be associated with the same representation as the layer's input type.
        #
        # Each layer outputs a new GeometricTensor, associated with the layer's output type.
        # As a result, consecutive layers need to have matching input/output types.
        x = self.block1(x)
        x = self.block2(x)
        x = self.pool1(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.pool2(x)
        x = self.block5(x)
        x = self.block6(x)
        # pool over the spatial dimensions
        x = self.pool3(x)
        # extract invariant features
        x = self.invariant_map(x)
        # unwrap the output GeometricTensor
        # (take the PyTorch tensor and discard the associated representation)
        x = x.tensor
        # classify with the final fully connected layer
        x = self.fully_net(x.reshape(x.shape[0], -1))
        return x
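The classifier calls `x.reshape(x.shape[0], -1)` and starts with `BatchNorm1d(c)`. It therefore only works if the spatial dimensions have shrunk to 1×1 before `invariant_map` is applied. That is why the first and last convolutions use `padding=1` and `pool3` uses `padding=0`. The layer arithmetic can be traced with the usual convolution size formula. The 5×5 window (padding 2 by default) used below for the antialiased pooling layers is an assumption chosen to be consistent with the 1×1 output. It is not a statement about escnn's implementation.

```python
def conv_out(size: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output size of a convolution or pooling along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# ASSUMPTION: the antialiased pooling layers are modelled as a 5x5 window with padding 2
# when no padding is given; this is an illustration, not escnn's actual code.
POOL_KERNEL, POOL_PADDING = 5, 2

# (name, kernel_size, padding, stride) following the CNSteerableCNN definition
layers = [
    ("block1", 7, 1, 1),
    ("block2", 5, 2, 1),
    ("pool1", POOL_KERNEL, POOL_PADDING, 2),
    ("block3", 5, 2, 1),
    ("block4", 5, 2, 1),
    ("pool2", POOL_KERNEL, POOL_PADDING, 2),
    ("block5", 5, 2, 1),
    ("block6", 5, 1, 1),
    ("pool3", POOL_KERNEL, 0, 1),
]

size = 29  # the (masked) input images are 29x29
for name, kernel, padding, stride in layers:
    size = conv_out(size, kernel, padding, stride)
    print(f"{name}: {size}x{size}")
```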
Instantiate, train and evaluate the $C_4$-steerable CNN:
torch.manual_seed(42)
model_c4 = CNSteerableCNN().to(device)
train(model_c4, checkpoint_path="steerable_c4-pretrained.ckpt")
accuracy = test(model_c4)
print(f"Test accuracy: {accuracy}")
accs_c4 = test_model_rotations(model_c4, 16, 10000, checkpoint_path="steerable_c4-accuracies.npy")
Test accuracy: 93.84
Finally, let’s compare the performance of both models on the rotated test sets:
# plot the accuracy as a function of the rotation angle theta applied to the test set
fig, ax = plt.subplots(figsize=(10, 6))
N=16
xs = [i*2*np.pi / N for i in range(N+1)]
plt.plot(xs, accs_so2 + [accs_so2[0]], label=r'$SO(2)$-Steerable CNN')
plt.plot(xs, accs_c4 + [accs_c4[0]], label=r'$C_4$-Steerable CNN')
plt.title(r'$C_4$ vs $SO(2)$ Steerable CNNs', fontsize=20)
plt.xlabel(r'Test rotation ($\theta \in [0, 2\pi)$)', fontsize=20)
plt.ylabel('Accuracy', fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.legend(fontsize=20)
plt.show()
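While perfect equivariance to $SO(2)$ is not achievable because of the discretization, the plot shows how strongly the accuracy of each model depends on the rotation of the test set.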
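To compare the two curves with a few numbers rather than by eye, each list returned by `test_model_rotations` can be summarized. `rotation_robustness` is a hypothetical helper added for illustration. The accuracies in the example are made up and are not results of this tutorial.

```python
import statistics
from typing import Dict, List


def rotation_robustness(accuracies: List[float]) -> Dict[str, float]:
    """Summarize accuracies (in %) measured on rotated copies of the test set."""
    return {
        "mean": statistics.mean(accuracies),
        "worst": min(accuracies),
        "best": max(accuracies),
        # 0 if the accuracy does not depend on the rotation at all
        "spread": max(accuracies) - min(accuracies),
    }


toy_accs = [95.0, 93.5, 94.0, 92.5]  # made-up numbers, for illustration only
print(rotation_robustness(toy_accs))
```

Applied to `accs_so2` and `accs_c4`, the `spread` and `worst` entries give a simple numerical measure of how much each model's accuracy depends on the rotation angle.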
Conclusion
In this tutorial, you first learnt about group representation theory and the Fourier transform over compact groups. These are the mathematical tools used to formalize Steerable CNNs.
In the second part of this tutorial, you learnt about steerable feature fields and steerable CNNs. In particular, the previously defined Fourier transform allowed us to build a steerable CNN which is equivalent to a Group-Convolutional Neural Network (GCNN) equivariant to translations and to the continuous rotation group $SO(2)$.
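The role of the Fourier transform is easiest to see for the cyclic group $C_N$, where it reduces to the ordinary discrete Fourier transform. The numpy sketch below checks two properties:

- Rotating a function on $C_N$, i.e. cyclically shifting it, multiplies each Fourier coefficient by a phase.
- Keeping only the low frequencies commutes with rotations.

This second property is what lets a steerable CNN store a finite number of frequencies instead of one value per group element. The sketch and its helper `band_limit` are illustrative; they are not code from the tutorial.

```python
import numpy as np

N = 8  # order of the cyclic group C_N (rotations by multiples of 2*pi/N)
rng = np.random.default_rng(0)
f = rng.normal(size=N)  # a feature over the N group elements (one regular field)

# the Fourier transform over C_N is the ordinary discrete Fourier transform
f_hat = np.fft.fft(f)

# rotating by g steps acts on f as a cyclic shift (the regular representation) ...
g = 3
shifted_hat = np.fft.fft(np.roll(f, g))

# ... and on each Fourier coefficient as multiplication by a phase (a 1D irrep of C_N)
k = np.arange(N)
phases = np.exp(-2j * np.pi * k * g / N)
assert np.allclose(shifted_hat, phases * f_hat)


def band_limit(signal: np.ndarray, max_freq: int) -> np.ndarray:
    """Keep only the frequencies |k| <= max_freq of a real signal over C_N."""
    n = len(signal)
    freqs = np.fft.fftfreq(n, d=1.0 / n)  # integer frequencies 0, 1, ..., -2, -1
    coeffs = np.fft.fft(signal) * (np.abs(freqs) <= max_freq)
    return np.fft.ifft(coeffs).real


# truncating to the low frequencies commutes with rotations
assert np.allclose(band_limit(np.roll(f, g), 2), np.roll(band_limit(f, 2), g))
```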
In our steerable CNNs, we mostly leveraged the regular representation of the group.