是否有一种很好的方法来检查numpy数组和割炬张量是否指向相同的基础数据?

问题描述 投票:2回答:1

我想检查numpy数组和割炬张量是否指向相同的基础内存。到目前为止,我想出了一个简单的检查方法,但是看起来并不优雅。

import numpy as np
import torch

# example 
a = np.random.randn(3,3)
b = torch.from_numpy(a)

assert a.__array_interface__['data'][0] == b.data_ptr()

有更好的方法吗?另外,如果使用此断言,是否可能发生某些潜在的未定义/错误行为?

预先感谢您的回答:)

python numpy memory-management pytorch
1个回答
1
投票

这是访问和比较指针的完全有效的方法。数组接口旨在允许共享数据缓冲区,因此它将具有正确的指针。话虽如此,如果您更喜欢一个不太冗长的解决方案,也可以像这样直接获取它:

import numpy as np
import torch
​
# example 
a = np.random.randn(3,3)
b = torch.from_numpy(a)
​
print(a.ctypes.data)
print(b.data_ptr())
140413464706720
140413464706720
© www.soinside.com 2019 - 2024. All rights reserved.