我有一个像这样的火炬张量:
a=[1, 234, 54, 6543, 55, 776]
以及其他张量,如下所示:
b=[234, 54]
c=[55, 776]
我想创建一个新的掩模张量,如果有另一个张量(
a
或b
)等于它,则c
的值将为真。a_masked =[False, True, True, False, True, True]
# The first two True values correspond to tensor `b` while the last two True values
correspond to tensor `c`.
我见过其他方法来检查完整张量是否包含在另一个张量中,但这里不是这种情况。
有没有一种火炬方式可以有效地做到这一点?
谢谢!
根据 PyTorch 论坛 here 上的答案,看起来您只需要一个显式的 for 循环,例如,
import torch
a = torch.tensor([1, 234, 54, 6543, 55, 776])
b = torch.tensor([234, 54])
a_masked = sum(a == i for i in b).bool()
print(a_masked)
tensor([False, True, True, False, False, False])