这是用户在 2024-6-16 22:50 为 https://zhuanlan.zhihu.com/p/703560963 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?
Dummy Diff Raster 最直白的可微渲染

Dummy Diff Raster 最直白的可微渲染

6 人赞同了该文章
发布于 2024-06-15 15:40・IP 属地安徽 ,编辑于 2024-06-15 18:03・IP 属地安徽

本文用一个最简单的“反色”操作,介绍可微渲染器的工作原理和最简单的CUDA代码实现,并利用pybind11封装一个python接口,再用torch.autograd.Function封装成一个可微的torch.nn.Module,从而可以直接在pytorch中使用。

原理

不是很严谨地说,函数的可微分意味着我们可以在任意一个状态点(x,y)计算出其导数dydx\frac{dy}{dx}(多元函数则是当前点的梯度)。如果有了梯度,我们就可以使用梯度下降来完成最优化的过程。

假定我们初始值为x0x_0,优化的第i步,对应的函数值yi=f(xi)y_i=f(x_i),现在我们想要找到一个x˙\dot{x}将函数值无限靠近目标值yty_t,使用梯度下降的方法,我们可以得到当前函数的导数dydx\frac{dy}{dx},然后

  • 定义损失函数loss: dL=12(yiyt)2dL=\frac{1}{2}(y_i-y_t)^2
  • 计算梯度gradient: dLdxi=dLdyidyidxi\frac{dL}{dx_i}=\frac{dL}{dy_i}\frac{dy_i}{dx_i}
  • 定义学习率α\alpha
  • 梯度下降: xi+1=xiαdLdxix_{i+1}=x_i-\alpha\frac{dL}{dx_i}
  • 最终收敛到L0L\approx 0的时候,xx˙x\approx \dot{x}

渲染过程同样也可以看作一个函数y=f(x)y=f(x),其中x指的是场景定义,相机参数,光照参数等,而y指的是渲染出来的图像Iw,hI_{w,h}一共有w×hw\times h个像素,每一个像素都有RGB三个浮点数。

通常来说,渲染的过程是不可微的,我们很难直接找到一个梯度函数,让计算机拿这个梯度,对比当前的画面和期望的画面,直接把场景优化到我们希望的结果。这个过程原本其实是创作者们在自己的脑海中完成的。创作者脑海中先想象出来一个场景画面,然后打开建模软件开始建模,画贴图,直到最终建模软件渲染出来的结果和自己脑海中想象的结果靠近。

但最近逐渐开始流行一些渲染方法让这个函数变得可微,这样理论上,我们就可以自动化上述过程,只要创作者把画面画出来,计算机可以自动优化出希望的场景模型,相机,光照等,一步到位,不需要再有漫长的调整工作。这一类渲染方法称为可微渲染(Differential Rendering)。

本文为了展示可微渲染的流程,因此选用了一种非常trivial的“渲染方法”,或者更像是一个图像处理器,它的输入是一张图片,输出是这张图片的反色。此时

  • x=Isx=I_s,一张w×hw\times h的原图
  • y=Ity=I_t,一张w×hw\times h的图片,其每一个像素It[i,j]=1.0Is[i,j]I_t[i,j]=1.0-I_s[i,j]
  • 这个“渲染”显然是可微的dIt[i,j]dIs[i,j]=1\frac{d I_t[i,j]}{d I_s[i,j]}=-1

如图所示,这是我们初始情况,我拿浙大的校徽作为优化的目标IrefI_{ref}

我们的初始图片Is0I_{s0}是中间那一张纯黑的图片,经过“渲染”(在这里其实是反色),会得到一张纯白的图片I_{t0}。现在我们的目标是,让计算机知道如何通过调整I_s的值,让最终目标I_t接近I_{ref}

为此我们定义一个损失函数,它是当前目标图片和参考图片像素之间差的平方。

\mathcal{L}=\frac{1}{2}\sum\limits_{i,j}\limits^{} (I_t[i,j]-I_{ref}[i,j])^2

这当然是一个多元函数,它的梯度是全导数。对于loss的微小变化d\mathcal{L}和每一个像素之间的微小变化 dI_t[i,j]之间的关系为

d\mathcal{L}=\frac{\partial \mathcal{L}}{\partial I_t[i,j]}dI_t[i,j]

上文所述,感谢我们的“可微渲染器”,我们可以得到dI_t[i,j]=\frac{dI_t[i,j]}{dI_s[i,j]}dI_s[i,j]

这样我们就可以梯度传递,建立loss的微小变化和x微小变化之间的关系:

\frac{\partial \mathcal{L}}{\partial I_s[i,j]}=\frac{\partial \mathcal{L}}{\partial I_t[i,j]}\frac{dI_t[i,j]}{dI_s[i,j]}

当前的“梯度”也被这样定义了,同样根据上面所介绍的“梯度下降”方法,就可以进行优化:

最终得到这样的结果

可以看到我们的最终“渲染”出来的成品图已经成功达到效果了。

CUDA实现

首先我们定义一个类 DummyDiffRender,它的成员函数m_w,m_h用来记录图片分辨率,正向实现

__global__ void forward_kernel(int W, int H, float* d_spix, float* d_tpix) {
	int x = blockIdx.x * blockDim.x + threadIdx.x;
	int y = blockIdx.y * blockDim.y + threadIdx.y;
	if (x >= W || y >= H) return;
	int i = y * W + x;
	for (int c = 0; c < 3; ++c) {
		d_tpix[i + c * W * H] = 1.0f - d_spix[i + c * W * H];
	}
}

void DummyDiffRender::forward(float* d_source_pix, int h, int w, float* d_target_pix) noexcept {
	dim3 block(32, 32);
	dim3 grid((w + block.x - 1) / block.x, (h + block.y - 1) / block.y);
	forward_kernel<<<grid, block>>>(w, h, d_source_pix, d_target_pix);
}

反向过程

__global__ void backward_kernel(int W, int H, float* dL_dtpix, float* dL_dspix) {
	int x = blockIdx.x * blockDim.x + threadIdx.x;
	int y = blockIdx.y * blockDim.y + threadIdx.y;
	if (x >= W || y >= H) return;

	int i = y * W + x;
	for (int c = 0; c < 3; ++c) {
		// printf("dL_dtpix[%d] = %f\n", i + c * W * H, dL_dtpix[i + c * W * H]);
		dL_dspix[i + c * W * H] = -dL_dtpix[i + c * W * H];
	}
}

void DummyDiffRender::backward(float* d_dL_d_target_pix, float* d_dL_d_source_pix) noexcept {
	dim3 block(32, 32);
	dim3 grid((m_w + block.x - 1) / block.x, (m_h + block.y - 1) / block.y);
	backward_kernel<<<grid, block>>>(m_w, m_h, d_dL_d_target_pix, d_dL_d_source_pix);
}

然后我们使用pybind11将forward和backward方法绑定到python输出,然后利用pytorch中现成的梯度下降工具来进行优化

Binding:

...
pybind11::class_<DummyDiffRender>(m, "DummyDiffRenderApp")
    .def(pybind11::init<>())
    .def("forward", &DummyDiffRender::forward_py)
    .def("backward", &DummyDiffRender::backward_py);

pytorch封装

封装一个python的torch.autograd.Functionnn.Module

class _DummyDiffRender(torch.autograd.Function):
    @staticmethod 
    def forward(ctx, source_img, height, width, app):
        result_img = torch.zeros((3, height, width), dtype=torch.float32).cuda()
        app.forward(source_img.contiguous().data_ptr(),
                    height, width, 
                    result_img.contiguous().data_ptr())
        ctx.app = app
        ctx.height = height
        ctx.width = width
        
        return result_img
    
    @staticmethod
    def backward(ctx, dL_dtpix):
        app = ctx.app
        dL_dsource_img = torch.zeros((3, ctx.height, ctx.width), dtype=torch.float32).cuda()
        app.backward(dL_dtpix.contiguous().data_ptr(),
                     dL_dsource_img.contiguous().data_ptr())
        # print(dL_dsource_img)
        return dL_dsource_img, None, None, None 

class DummyDiffRender(nn.Module):
    def __init__(self):
        super(DummyDiffRender, self).__init__()
        self.app = DummyDiffRenderApp()
        
    def forward(self, source_img, height, width):
        return _DummyDiffRender.apply(source_img, height, width, self.app)

这样我们就可以如同任何pytorch module一样直接使用了

    dummy_diff_render = DummyDiffRender()
    height = 100
    width = 100

    source_img = torch.zeros((3, height, width), dtype=torch.float32).cuda()
    source_img[0, :, :] = 1 # set initial image to red

    # target image is blue+green = cyan
    target_img = dummy_diff_render.forward(source_img, height, width)
    target_img_np = target_img.cpu().detach().numpy().transpose(1, 2, 0)
    plt.imshow(target_img_np)
    plt.show()

    source_img = torch.zeros((3, height, width), dtype=torch.float32).cuda()
    source_img[1, :, :] = 1
    source_img.requires_grad = True

    N_ITER = 300
    N_LOG = 50

    optim = torch.optim.Adam([source_img], lr=0.01)

    for i in range(N_ITER):
        result_img = dummy_diff_render.forward(source_img, height, width)
        loss = torch.nn.functional.mse_loss(target_img, result_img)
        loss.backward()

        with torch.no_grad():
            optim.step()
            optim.zero_grad()
            if i % N_LOG == 0:
                print(f'Iter {i}, Loss {loss.item()}')
                result_img_np = result_img.cpu().detach().numpy().transpose(1, 2, 0)
                plt.imshow(result_img_np)
                plt.show()

Reference

  • [[Dummy Diff Render|feature.diff-render.dummy-diff-render]]

本文使用 WPL/s 发布 @GitHub

发布于 2024-06-15 15:40・IP 属地安徽 ,编辑于 2024-06-15 18:03・IP 属地安徽
欢迎参与讨论

还没有评论,发表第一个评论吧

文章被以下专栏收录

想来知乎工作?请发送邮件到 jobs@zhihu.com