给定
before
和 after
张量,我想在不使用循环的情况下用 before
替换另一个张量 A
中 after
的所有实例。
示例:
before = torch.Tensor([2,4,5])
after = torch.Tensor([20,40,50])
A = torch.Tensor([1,2,3,4,5,6])
result = replace(A, before, after)
结果应该是
torch.Tensor([1,20,3,40,50,6])
。
这不是最有效的,但你可以这样做:
In [37]: mask = (A.repeat( 3,1 ).t() == before).t().sum(0).to(torch.bool)
In [38]: torch.where(~mask, A, after[mask.to(torch.int).cumsum(0)-1] )
Out[38]: tensor([ 1, 20, 3, 40, 50, 6])
在这种情况下,您将具体化每个可能的值来检查并对匹配项求和以查看是否存在。它不是循环,因此也许可以利用 GPU,但它本身并不是更好的复杂性。
您真正想要的效率是集合操作,但我想不出 pytorch 中的任何操作可以为您提供所需的索引。您可以尝试使用
torch.unique(... return_counts=True)
找到一些替代方案,或者搜索有关 pytorch 张量交集的其他问题,看看是否有办法更有效地构造掩模。