I need to compute the first and second derivatives of a vector-valued function output (batch x nvarout, e.g. nvarout = 150) with respect to the input x (batch x nvarin, e.g. nvarin = 2). I managed to do this with the following code:
def continuous_diff(x, y):
    torch.set_grad_enabled(True)
    x.requires_grad_(True)
    # x in [N, nvarin]
    # y in [N, nvarout]
    # dy_dx in [N, nvarin]
    dy_dx = torch.autograd.grad(
        y, x, torch.ones_like(y),
        retain_graph=True, create_graph=True,
    )[0]
    return dy_dx
and
output_grad2 = []
for k in range(output.shape[1]):
    y = output[:, k]
    dx = continuous_diff(x, y)
    # hardcoded for nvarin = 2 here
    dxx = continuous_diff(x, dx[:, 0])
    dyy = continuous_diff(x, dx[:, 1])
    grad2 = torch.concatenate([dxx, dyy], dim=-1)
    output_grad2.append(grad2)
output_grad2 = torch.stack(output_grad2, dim=-1)
Is there a way to speed up this computation?
A similar question was posted here, but no solution has been proposed in the two years since.
Sure! Your current implementation loops over the output dimensions and computes the gradients one column at a time, which launches one backward pass per output column. We can speed this up by batching those backward passes with PyTorch's is_grads_batched support in torch.autograd.grad.
Here is a modified version of your code that computes the first and second derivatives of the vector function more efficiently:
import torch

def batched_vjp(y, x, v, create_graph=False):
    # Batched vector-Jacobian product: v carries an extra leading batch
    # dimension over grad_outputs, so a single vmapped backward pass computes
    # one VJP per basis vector instead of one per Python loop iteration.
    # y: [N, ...]; x: [N, nvarin]; v: [B, *y.shape] -> result: [B, N, nvarin]
    return torch.autograd.grad(
        y, x, v, retain_graph=True, create_graph=create_graph,
        is_grads_batched=True)[0]

def compute_gradients(x, output):
    N, nvarin = x.shape
    nvarout = output.shape[-1]

    # First-order derivatives (Jacobian): one batched VJP with all nvarout
    # canonical basis vectors replaces the loop over output columns.
    eye_out = torch.eye(nvarout, dtype=x.dtype, device=x.device)
    basis_out = eye_out.view(nvarout, 1, nvarout).expand(-1, N, -1)
    jacobian = batched_vjp(output, x, basis_out, create_graph=True)
    # jacobian[k, n, i] = d output[n, k] / d x[n, i]; shape [nvarout, N, nvarin]

    # Second-order derivatives (Hessian): differentiate each Jacobian
    # column once more, again batched over the nvarout outputs.
    basis_k = eye_out.unsqueeze(-1).expand(-1, -1, N)  # [nvarout, nvarout, N]
    hessian = torch.stack(
        [batched_vjp(jacobian[..., j], x, basis_k) for j in range(nvarin)],
        dim=-2)
    # hessian[k, n, j, i] = d^2 output[n, k] / d x[n, j] d x[n, i]

    # Put the batch dimension first, matching the layout of x and output:
    # jacobian [N, nvarout, nvarin], hessian [N, nvarout, nvarin, nvarin].
    return jacobian.permute(1, 0, 2), hessian.permute(1, 0, 2, 3)
You can then use this function to compute the gradients:
# output must be produced from x by a differentiable computation so that
# autograd can trace it; a small MLP here as a stand-in for your model.
x = torch.randn(150, 2, requires_grad=True)          # [N, nvarin]
model = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 150))
output = model(x)                                    # [N, nvarout]
jacobian, hessian = compute_gradients(x, output)
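As a quick sanity check, one column of the batched Jacobian should match a single plain autograd.grad call like the one in your continuous_diff:

# Compare column 0 of the batched Jacobian with one ordinary VJP.
y0 = output[:, 0]
dx0 = torch.autograd.grad(y0, x, torch.ones_like(y0), retain_graph=True)[0]
print(torch.allclose(jacobian[:, 0, :], dx0))  # expect: True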
This computes the Jacobian and Hessian with a handful of batched backward passes instead of one backward pass per output column, which should be substantially faster for large nvarout. Note that is_grads_batched vectorizes the backward pass with vmap (available since PyTorch 1.11), so every op in your model needs a batching rule; if one is missing you will get an error rather than a silent fallback.
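If the forward pass is available as a callable (as with the model above) rather than only as a precomputed output tensor, torch.func in PyTorch 2.x computes the same quantities even more directly. A minimal sketch, reusing the hypothetical model from the usage example:

from torch.func import hessian, jacrev, vmap

def f(xi):
    # Per-sample function: [nvarin] -> [nvarout].
    return model(xi)

# vmap lifts the per-sample Jacobian/Hessian over the batch dimension.
jac = vmap(jacrev(f))(x)        # [N, nvarout, nvarin]
hess = vmap(hessian(f))(x)      # [N, nvarout, nvarin, nvarin]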