如何使用torch.unique过滤重复值,计算一个昂贵的函数,将其映射回来,然后计算梯度?

问题描述 投票:0回答:1

我正在尝试优化 PyTorch 中的计算,首先识别张量的唯一元素,仅将昂贵的函数(例如 torch.exp)应用于这些唯一元素,然后将结果映射回原始张量的形状计算相对于原始张量的梯度。

我的动机是避免对输入张量中的重复值进行昂贵函数的冗余计算,这将显着提高性能。

这是演示我的方法的代码片段,但它会导致错误,因为 unique 不可微分:

import torch

inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)
inputs.requires_grad_(True)

unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
print(f"There are {unique_inputs.numel()} unique elements.")

unique_exp = torch.exp(unique_inputs)
full_exp = unique_exp[inverse_indices]

torch.autograd.grad(full_exp[0], inputs) # <-- Error here

还有其他方法可以做到这一点吗?

pytorch torch automatic-differentiation
1个回答
0
投票

您可以使用自定义

backward
函数来完成此操作:

import torch
import torch.nn as nn

class UniqueForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs):
        unique_inputs, inverse_indices = torch.unique(inputs, return_inverse=True)
        unique_counts = torch.bincount(inverse_indices)
        ctx.save_for_backward(inverse_indices, unique_counts)
        ctx.input_shape = inputs.shape
        return unique_inputs, inverse_indices

    @staticmethod
    def backward(ctx, grad_unique, grad_inverse):
        inverse_indices, unique_counts = ctx.saved_tensors
        grad_inputs = grad_unique[inverse_indices]
        grad_inputs = grad_inputs / unique_counts[inverse_indices]
        return grad_inputs

class EfficientFunction(nn.Module):
    def __init__(self, function):
        # function should be a callable element-wise function
        super().__init__()
        self.function = function
        self.unique_forward = UniqueForward.apply
        
    def forward(self, inputs):
        unique_inputs, inverse_indices = self.unique_forward(inputs)
        unique_results = self.function(unique_inputs)
        full_results = unique_results[inverse_indices]
        return full_results

EfficientFunction
模块应该适用于以下任何情况:1)
inputs
具有重复值并且 2) 相关的
function
应用逐元素操作。

示例:

def expensive_function(x):
    # use exp as example of expensive function
    return torch.exp(x)

efficient_function = EfficientFunction(expensive_function)

# test efficient version
inputs = torch.rand(100)
inputs = torch.round(inputs, decimals=2)
inputs.requires_grad_(True)

outputs = efficient_exp(inputs)
loss = outputs.mean()
loss.backward()

# test naive version
inputs2 = inputs.clone().detach().requires_grad_(True)
naive_results = expensive_function(inputs2)
naive_loss = naive_results.mean()
naive_loss.backward()

# compare gradients 
torch.allclose(inputs.grad, inputs2.grad)
© www.soinside.com 2019 - 2024. All rights reserved.