使用 torch.where 对张量进行阈值处理是否会将其张量从计算图中分离出来?

问题描述 投票:0回答:1

我正在 PyTorch 中编写一个自定义损失函数,用于多类语义分割。该函数的一部分是对张量中选择的通道进行阈值处理,这些通道用 tracker_index 表示。

作为计算图一部分的函数的最后一部分是channel_tensor,如果我注释掉应用 torch.where 的行,一切都会顺利进行。我尝试将 1 和 0 设置为 float32 张量,并确保它们与 channel_tensor 位于同一设备上,这使我相信八个阈值是不可微分的,因此不能成为损失函数或 torch.where 的一部分总是将张量从计算图中分离出来。请指教。

channel_tensor =torch.select(
     segmentation_output,
     dim=-3,
     index=tracker_index
)
channels[tracker_index]= torch.where(channel_tensor > self.threshold, torch.tensor(1, device=channel_tensor.device, dtype=torch.float32), torch.tensor(0, device=channel_tensor.device, dtype=torch.float32))
python machine-learning deep-learning pytorch neural-network
1个回答
0
投票

不,但是...

torch.where(...)
不会从计算图中分离任何内容。

torch.where(cond, a, b)
a
具有相同的梯度,其中
cond
True
,并且与
b
相同,其中
cond
False

(所以本质上,如果

c = torch.where(cond, a, b)
c.grad
就是
torch.where(cond, a.grad, b.grad)

在你的情况下,

a
b
是常数,所以所有这些梯度都是0,这有效地从图表中删除了结果。

你说你的操作是“阈值化”,但是那不是你正在做的!

阈值将保持该值,除非它高于(或低于)某个阈值。您所做的是将低于阈值的值设置为

0
,将高于阈值的值设置为 1,这是一个 Heaviside 阶跃函数。它几乎在任何地方都是可微的,但是定义时其梯度始终为 0(因此无法用于优化目的)

© www.soinside.com 2019 - 2024. All rights reserved.