带有布尔型numpy数组的Pytorch掩模张量

问题描述 投票:0回答:1

我有一个名为84x84target pytorch张量。我需要使用由84x84True组成的False布尔值numpy数组对其进行遮罩。

[执行target = target[mask]时,出现错误TypeError: can't convert np.ndarray of type numpy.bool_. The only supported types are: double, float, float16, int64, int32, and uint8.

令人惊讶的是,只有在GPU上运行时,才会出现此错误。在CPU上运行时,一切正常。我该如何解决?

python numpy pytorch torch
1个回答
0
投票

我认为这些类型有些混乱。但这有效。

import torch
tensor = torch.randn(84,84)
c = torch.randn(tensor.size()).bool()
c[1, 2:5] = False
x = tensor[c].size()

为了测试,我创建了具有随机值的张量。之后,将3个元素设置为False。在最后一步中,我看得到由84 ^ 2-3产生的7053。

希望有帮助的方式。

© www.soinside.com 2019 - 2024. All rights reserved.