使用 numpy 数组参数从 njitted 函数调用 Numba cfunc

问题描述 投票:0回答:1

我试图在一个用 `njit` 编译的函数中调用一个 `cfunc` 函数,但在 Numba 的 nopython 模式下,数组的 `ctypes` 属性没有 `data_as()` 方法,因此无法把数组转换为指向 double 的指针(`double *`)。谁能帮我弄清楚如何让它发挥作用吗?

import ctypes
import numpy as np
import numba as nb

@nb.cfunc(
    nb.types.void(
        nb.types.CPointer(nb.types.double),  # xn_: destination buffer
        nb.types.CPointer(nb.types.double),  # x_: source buffer
        nb.types.intc,                       # idx
        nb.types.intc,                       # n
        nb.types.intc,                       # m1
        nb.types.intc,                       # m2
    )
)
def get_param2(xn_, x_, idx, n, m1, m2):
    """Copy one ``(m1, m2)`` slice of the source buffer into ``xn_``.

    ``x_`` is viewed through ``nb.carray`` as an ``(n, m1, m2)`` array and
    ``xn_`` as an ``(m1, m2)`` array. Slice ``idx`` of the source is written
    into the destination in place; nothing is returned.

    An ``idx`` greater than or equal to ``n`` is clamped to the last slice
    (``n - 1``).
    """
    source = nb.carray(x_, (n, m1, m2))
    target = nb.carray(xn_, (m1, m2))
    # Indices past the end fall back to the final slice.
    row = idx if idx < n else n - 1
    target[:, :] = source[row]


def test_get_param(): # this one works
    """Call ``get_param2`` from plain Python using ctypes double pointers.

    Builds a ``(100, 2, 3)`` source array and a ``(2, 3)`` destination,
    asks for slice 40, and asserts the destination equals ``source[40]``.
    """
    double_ptr = ctypes.POINTER(ctypes.c_double)

    source = np.zeros((100, 2, 3))
    target = np.empty((2, 3))
    n, m1, m2 = source.shape

    get_param2(
        target.ctypes.data_as(double_ptr),
        source.ctypes.data_as(double_ptr),
        40,
        n,
        m1,
        m2,
    )

    assert np.array_equal(source[40], target)

@nb.njit
def test_get_param_njit(): 
    # Attempts to call get_param2 from nopython-compiled code, obtaining the
    # double pointers through `ndarray.ctypes.data_as(...)`.
    #
    # This one fails with the error:
    #     `Unknown attribute 'data_as' of type ArrayCTypes(dtype=float64, ndim=2)`
    # i.e. the object Numba exposes for `Ai.ctypes` has no `data_as` attribute.

    A = np.zeros((100, 2, 3))
    Ai = np.empty((2, 3))
    get_param2(
        Ai.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        A.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        40,
        *A.shape,
    )

    assert np.array_equal(A[40], Ai)
python numba
1个回答
0
投票

您只需直接传入数组的 `.ctypes` 属性即可(无需调用 `data_as()`):

@nb.njit
def test_get_param_njit(): 
    """Call ``get_param2`` from nopython code by passing ``ndarray.ctypes``.

    The arrays' ``.ctypes`` attributes are passed directly for the two
    pointer parameters. The integer arguments are converted to ``int32``
    for the cfunc signature. Asserts that the destination ends up equal
    to slice 40 of the source.
    """

    A = np.zeros((100, 2, 3))
    Ai = np.empty((2, 3))
    # Take the dimensions from the array itself so they cannot drift from
    # the shape used to allocate it.
    n, m1, m2 = A.shape
    get_param2(
        Ai.ctypes,
        A.ctypes,
        np.int32(40), # Convert to int32 for your signature
        np.int32(n), np.int32(m1), np.int32(m2),
    )

    assert np.array_equal(A[40], Ai)
© www.soinside.com 2019 - 2024. All rights reserved.