Speeding up derivatives of a vector function in PyTorch

Problem description (votes: 0, answers: 1)

I need to compute the first- and second-order derivatives of a vector-valued function output in [batch x nvarout] (nvarout = 150, say) with respect to its inputs x in [batch x nvarin] (nvarin = 2, say). I managed to do this with the following code:

def continuous_diff(x, y):
    torch.set_grad_enabled(True)
    x.requires_grad_(True)
    # x in [N, nvarin]
    # y in [N], one component of the output
    # dy_dx in [N, nvarin]
    dy_dx = torch.autograd.grad(
        y, x, torch.ones_like(y),
        retain_graph=True, create_graph=True)[0]
    return dy_dx

output_grad2 = []
for k in range(output.shape[1]):
    y = output[:, k]
    dx = continuous_diff(x, y)
    # hardcoded for nvarin = 2 here
    dxx = continuous_diff(x, dx[:, 0])
    dyy = continuous_diff(x, dx[:, 1])
    grad2 = torch.cat([dxx, dyy], dim=-1)
    output_grad2.append(grad2)
output_grad2 = torch.stack(output_grad2, dim=-1)

Is there a way to speed up this computation?

A similar question was posted here, but no solution has been proposed in the 2 years since.

performance pytorch vectorization derivative autograd
1 Answer

0 votes

Sure! The current implementation loops over the output dimensions and computes the gradients one at a time. You can speed this up by vectorizing the computation with PyTorch's batched-gradient machinery.
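The key building block (assuming PyTorch >= 1.11) is the experimental is_grads_batched flag of torch.autograd.grad, which uses vmap to evaluate a whole batch of vector-Jacobian products in a single backward pass instead of one pass per seed vector. A minimal sketch with a toy function (the shapes here are illustrative, not your model):

import torch

x = torch.randn(5, 2, requires_grad=True)
y = (x ** 2).sum(dim=-1, keepdim=True).expand(-1, 3)  # toy y: [N=5, nvarout=3]

# Three cotangent (seed) vectors -> three VJPs in one vmapped backward
# pass, rather than three separate autograd.grad calls.
seeds = torch.eye(3).view(3, 1, 3).expand(-1, 5, -1)  # [B=3, N=5, nvarout=3]
vjps = torch.autograd.grad(y, x, seeds, is_grads_batched=True)[0]  # [3, 5, 2]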

Here is a modified version of your code that computes the first- and second-order derivatives of the vector function more efficiently:

import torch

def batched_vjp(y, x, seeds, create_graph=False):
    # y: [N, nvarout] (or any shape), x: [N, nvarin]
    # seeds: [B, *y.shape], one cotangent vector per backward pass.
    # is_grads_batched (experimental, PyTorch >= 1.11) vmaps the B
    # vector-Jacobian products into a single backward pass.
    dy_dx = torch.autograd.grad(
        y, x, seeds,
        retain_graph=True, create_graph=create_graph,
        is_grads_batched=True)
    return dy_dx[0]  # [B, N, nvarin]

def compute_gradients(x, output):
    N, nvarin = x.shape
    nvarout = output.shape[-1]

    # First-order derivatives (Jacobian): seed every output component
    # at once instead of looping over them in Python.
    eye = torch.eye(nvarout, dtype=x.dtype, device=x.device)
    seeds = eye.view(nvarout, 1, nvarout).expand(-1, N, -1)
    jac = batched_vjp(output, x, seeds, create_graph=True)  # [nvarout, N, nvarin]

    # Second-order derivatives (Hessian): differentiate column k of the
    # Jacobian for all output components at once; the remaining Python
    # loop runs only nvarin (= 2) times, not nvarout (= 150).
    col_seeds = eye.view(nvarout, nvarout, 1).expand(-1, -1, N)
    hess_cols = []
    for k in range(nvarin):
        # jac[..., k] is [nvarout, N]; each batched seed picks out one
        # output component across all N samples.
        hess_cols.append(batched_vjp(jac[..., k], x, col_seeds))  # [nvarout, N, nvarin]
    hess = torch.stack(hess_cols, dim=-1)  # [nvarout, N, nvarin, nvarin]

    jacobian = jac.permute(1, 2, 0)     # [N, nvarin, nvarout]
    hessian = hess.permute(1, 2, 0, 3)  # [N, nvarin, nvarout, nvarin]
    return jacobian, hessian

You can then use these functions to compute the gradients:

x = torch.randn(150, 2, requires_grad=True)  # [N, nvarin]
# output must be computed from x so that an autograd graph exists; a
# small MLP is used here as a stand-in for whatever produces output.
net = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 150))
output = net(x)  # [N, nvarout]
jacobian, hessian = compute_gradients(x, output)
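As an alternative, if you have access to the function itself rather than only a precomputed output tensor, the torch.func transforms (PyTorch >= 2.0) express the same computation even more directly. A sketch, assuming a model net (as in the hypothetical example above) that maps a single row of shape [nvarin] to a single row of shape [nvarout]:

from torch.func import hessian, jacrev, vmap

f = lambda xi: net(xi)                  # xi: [nvarin] -> [nvarout]
per_sample_jac = vmap(jacrev(f))(x)     # [N, nvarout, nvarin]
per_sample_hess = vmap(hessian(f))(x)   # [N, nvarout, nvarin, nvarin]

Here vmap batches the per-sample Jacobian and Hessian computations, so no Python loop remains at all.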

This computes the Jacobian and the Hessian more efficiently by batching all of the per-output backward passes into single autograd.grad calls, so the Python loop over the 150 output components disappears and only the loop over the 2 input variables remains.
