argmax() 函数是否有可微的近似值

问题描述 投票:0回答:1

输出一个向量,对应输入向量中最大值的位置接近1,其他位置接近0。这个函数需要可微。

例如:

import torch
a = torch.Tensor([2,3,4,5,4,3,2], require_grad=True)
b = f(a)
>>> b
tensor([1e-10, 1e-10, 1e-10, 0.9999999, 1e-10, 1e-10, 1e-10], grad_fn=<...>)

有没有实用的方法?

pytorch torch argmax
1个回答
0
投票

有几件事需要注意:

  1. 您提供
    integers
    作为
    torch.Tensor
    的输入,并且整数不可微分。它应该是浮点数、双精度数或复数。
  2. 其次,
    torch.Tensor
    是一个基类,它创建一个包含单个dtype元素的空张量;它没有
    requires_grad
    参数。对此您将有两种解决方案:
    • 您可以使用
      torch.tensor([1., 2.], requires_grad=True)
    • 或者如果你想检查张量是否有梯度,你可以使用
      Tensor.requires_grad

针对您的问题,我认为您可以通过两种方式实现这一目标:

  1. 直接使用 torch.argmax 函数:
    >>> import torch
    >>> a = torch.tensor([2.,3.,4.,5.,4.,3.,2.], requires_grad=True)
    >>> a.dtype
    torch.float32
    >>> b = torch.argmax(a)
    >>> b
    tensor(3)
    >>> b.grad
    >>> 
    
    它运行良好并且可微(所有 PyTorch 函数都是可微的)。
  2. 创建您自己的自定义函数
    >>> import torch
    >>> def f(x, epi=1e-6):
    ...     max, _ = torch.max(x, dim=0)
    ...     num = torch.exp(x - max)
    ...     den = torch.sum(torch.exp(x - max)) + epi
    ...     return num/den
    ... 
    >>> a = torch.tensor([2., 3., 4., 5., 4., 3., 2.], requires_grad=True)
    >>> b = f(a)
    >>> b
    tensor([0.0236, 0.0643, 0.1747, 0.4748, 0.1747, 0.0643, 0.0236],
           grad_fn=<DivBackward0>)
    

希望对您有帮助。谢谢!

© www.soinside.com 2019 - 2024. All rights reserved.