I'm trying to optimize a computation in PyTorch: first identify the unique elements of a tensor, apply an expensive function (e.g. torch.exp) only to those unique elements, map the results back to the original tensor's shape, and then compute gradients with respect to the original tensor.
The motivation is to avoid redundant evaluations of the expensive function on duplicate values in the input tensor, which should improve performance significantly.
Here is a snippet demonstrating my approach, but it raises an error because unique is not differentiable:
import torch

inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)  # rounding creates duplicate values
inputs.requires_grad_(True)

unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
print(f"There are {unique_inputs.numel()} unique elements.")

unique_exp = torch.exp(unique_inputs)
full_exp = unique_exp[inverse_indices]  # map results back to the original shape
torch.autograd.grad(full_exp[0], inputs)  # <-- Error here
Is there another way to do this?
You can do this with a custom backward function:
import torch
import torch.nn as nn

class UniqueForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
        # how many times each unique value occurs in the input
        unique_counts = torch.bincount(inverse_indices)
        ctx.save_for_backward(inverse_indices, unique_counts)
        return unique_inputs, inverse_indices

    @staticmethod
    def backward(ctx, grad_unique, grad_inverse):
        inverse_indices, unique_counts = ctx.saved_tensors
        # scatter each unique element's accumulated gradient back to every
        # position that held that value, splitting it evenly among duplicates
        grad_inputs = grad_unique[inverse_indices]
        grad_inputs = grad_inputs / unique_counts[inverse_indices]
        # inverse_indices is an integer tensor, so it receives no gradient
        return grad_inputs

class EfficientFunction(nn.Module):
    def __init__(self, function):
        # function should be a callable element-wise function
        super().__init__()
        self.function = function
        self.unique_forward = UniqueForward.apply

    def forward(self, inputs):
        unique_inputs, inverse_indices = self.unique_forward(inputs)
        unique_results = self.function(unique_inputs)  # evaluated once per unique value
        full_results = unique_results[inverse_indices]  # broadcast back to the input shape
        return full_results
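One thing worth noting about the backward pass: dividing by unique_counts splits each unique element's accumulated gradient evenly across its duplicate positions. That reproduces the naive per-element gradient exactly whenever the upstream gradient is constant within each duplicate group (as with a sum or mean loss); for position-dependent losses it yields the group average instead. A minimal sketch of that behavior, using a hypothetical weighted loss (the tensor values and weights here are made up for illustration):

import torch

# tiny input with one duplicated value
x = torch.tensor([0.5, 0.5, 1.0], requires_grad=True)
weights = torch.tensor([1.0, 3.0, 2.0])  # non-uniform upstream gradient

efficient = EfficientFunction(torch.exp)
(efficient(x) * weights).sum().backward()
print(x.grad)   # the two 0.5-positions get the group average: both 2*exp(0.5)

x2 = x.clone().detach().requires_grad_(True)
(torch.exp(x2) * weights).sum().backward()
print(x2.grad)  # naive grads: 1*exp(0.5), 3*exp(0.5), 2*exp(1.0)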
The EfficientFunction module should work in any case where 1) inputs contains duplicate values and 2) the function in question applies an element-wise operation.
Example:
def expensive_function(x):
    # use exp as an example of an expensive function
    return torch.exp(x)

efficient_function = EfficientFunction(expensive_function)

# test the efficient version
inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)  # rounding creates duplicates
inputs.requires_grad_(True)
outputs = efficient_function(inputs)
loss = outputs.mean()
loss.backward()

# test the naive version
inputs2 = inputs.clone().detach().requires_grad_(True)
naive_results = expensive_function(inputs2)
naive_loss = naive_results.mean()
naive_loss.backward()

# compare gradients
print(torch.allclose(inputs.grad, inputs2.grad))  # expect: True
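To check the performance claim, here is a rough timing sketch. Assumptions: a large input with only ~101 distinct values after rounding, and timing via torch.utils.benchmark; the actual speedup depends on how expensive the function is relative to the torch.unique call itself.

import torch
from torch.utils import benchmark

# large input, but few distinct values after rounding
big_inputs = torch.round(torch.rand(1_000_000), decimals=2)

t_naive = benchmark.Timer(
    stmt="expensive_function(x)",
    globals={"expensive_function": expensive_function, "x": big_inputs},
)
t_efficient = benchmark.Timer(
    stmt="efficient_function(x)",
    globals={"efficient_function": efficient_function, "x": big_inputs},
)
print(t_naive.timeit(100))
print(t_efficient.timeit(100))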