Numba 中的稀疏矩阵

问题描述 投票:0回答:3

我希望使用 Numba (http://numba.pydata.org/) 加速我的机器学习算法(用 Python 编写)。请注意,该算法将稀疏矩阵作为输入数据。在我的纯Python实现中,我使用了Scipy中的csr_matrix和相关类,但显然它与Numba的JIT编译器不兼容。

我还创建了自己的自定义类来实现稀疏矩阵(基本上是(索引,值)对的列表),但它又与 Numba 不兼容(即,我收到一些奇怪的错误消息,说它不无法识别扩展类型)

是否有一种替代的、简单的方法来仅使用与 Numba 兼容的 numpy(不诉诸 SciPy)来实现稀疏矩阵?任何示例代码将不胜感激。谢谢!

python numpy scipy anaconda numba
3个回答
7
投票

如果您所要做的就是迭代 CSR 矩阵的值,则可以将属性 data、indptr 和索引传递给函数而不是 CSR 矩阵对象。

from scipy import sparse
from numba import njit

@njit
def print_csr(A, iA, jA):
    for row in range(len(iA)-1):
        for i in range(iA[row], iA[row+1]):
            print(row, jA[i], A[i])

A = sparse.csr_matrix([[1, 2, 0], [0, 0, 3], [4, 0, 5]])
print_csr(A.data, A.indptr, A.indices)

4
投票

您可以将稀疏矩阵的数据作为纯 numpy 或 python 访问。 例如

M=sparse.csr_matrix([[1,0,0],[1,0,1],[1,1,1]])
ML = M.tolil()

for d,r in enumerate(zip(ML.data,ML.rows))
    # d,r are lists
    dr = np.array([d,r])
    print dr

产生:

[[1]
 [0]]
[[1 1]
 [0 2]]
[[1 1 1]
 [0 1 2]]

numba 当然可以处理使用这些数组的代码,当然,前提是它不期望每行具有相同大小的数组。


lil
格式存储值2个对象数据类型数组,其中数据和索引按行存储列表。


0
投票

我知道这是一篇旧帖子,但我,从未来来看,这个问题在 numba 中仍然没有得到解决。

因为我需要它,所以我为此创建了一个库,任何需要它的人都可以使用https://github.com/PessoaP/smn/

© www.soinside.com 2019 - 2024. All rights reserved.