pytorch中的mask_fill在这里做什么

问题描述 投票:0回答:1

我有一个名为“k1”的张量,其形状为 3,1,1,9,还有 p1 张量,其形状为 3,7,9,9,我想知道下面的行是做什么的?

p1 = p1 .masked_fill(k1== 0, float("-1e30"))
python pytorch
1个回答
2
投票

正如文档页面所描述的:

Tensor.masked_fill(mask, value)

用值填充
self
张量的元素,其中
mask
True
mask
的形状必须可以与底层张量的形状一起广播。

在您的情况下,它将在

p1
等于 0 的位置将
float("-1e30")
的值放入
k1
中。由于
k1
具有单一维度,因此其形状将被广播为
p1
的形状。

© www.soinside.com 2019 - 2024. All rights reserved.