我试图在一个 `njit` 编译的函数中调用一个 `cfunc` 函数,但在 Numba 中,数组的 `.ctypes` 对象没有 `data_as()` 方法,因此无法像在纯 Python 中那样把数组转换为 `double*` 指针。谁能帮我弄清楚如何让它正常工作?
import ctypes
import numpy as np
import numba as nb
@nb.cfunc(nb.types.void(
    nb.types.CPointer(nb.types.double),  # xn_: output buffer, m1 * m2 doubles
    nb.types.CPointer(nb.types.double),  # x_: input buffer, n * m1 * m2 doubles
    nb.types.intc,  # idx: which (m1, m2) slice of the input to copy
    nb.types.intc,  # n: number of slices in the input
    nb.types.intc,  # m1: rows per slice
    nb.types.intc,  # m2: columns per slice
))
def get_param2(xn_, x_, idx, n, m1, m2):
    """Copy slice ``idx`` of an (n, m1, m2) double buffer into an (m1, m2) buffer.

    Both raw pointers are wrapped with ``nb.carray`` as C-ordered arrays of the
    given shapes; no data is copied until the final slice assignment.

    An ``idx`` at or beyond ``n`` is clamped to the last slice (``n - 1``).
    NOTE(review): ``idx`` is only clamped from above; a negative ``idx`` is
    used as-is for indexing -- confirm callers never pass one.
    """
    src = nb.carray(x_, (n, m1, m2))
    dst = nb.carray(xn_, (m1, m2))
    last = n - 1
    dst[:, :] = src[min(idx, last)]
def test_get_param():  # this one works
    """Call the ``get_param2`` cfunc from plain Python with ctypes pointers.

    ``A`` holds distinct values in every slot. With an all-zero ``A``, a missed
    or misplaced copy could still pass, because freshly allocated ``np.empty``
    memory may already be zero.
    """
    double_p = ctypes.POINTER(ctypes.c_double)
    A = np.arange(100 * 2 * 3, dtype=np.float64).reshape(100, 2, 3)
    Ai = np.empty((2, 3))

    # In-range index: the requested slice is copied.
    get_param2(
        Ai.ctypes.data_as(double_p),
        A.ctypes.data_as(double_p),
        40,
        *A.shape,
    )
    assert np.array_equal(A[40], Ai)

    # Index past the end: get_param2 clamps it to the last slice.
    get_param2(
        Ai.ctypes.data_as(double_p),
        A.ctypes.data_as(double_p),
        A.shape[0] + 5,
        *A.shape,
    )
    assert np.array_equal(A[-1], Ai)
@nb.njit
def test_get_param_njit():
    """Same calls as ``test_get_param``, but compiled with ``@nb.njit``.

    Kept as a reproducer of the compilation error below.
    """
    # this one fails with the error:
    # `Unknown attribute 'data_as' of type ArrayCTypes(dtype=float64, ndim=2)`
    # Why: inside njit code, ``arr.ctypes`` has Numba's ArrayCTypes type, which
    # (per the error above) has no ``data_as`` attribute, so compilation fails
    # before ``get_param2`` is ever called.
    A = np.zeros((100, 2, 3))
    Ai = np.empty((2, 3))
    get_param2(
        Ai.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        A.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        40,
        *A.shape,
    )
    assert np.array_equal(A[40], Ai)
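在 `njit` 函数内部,您只需直接传入数组的 `.ctypes` 属性即可:Numba 会把它自动转换为元素类型相同的指针(这里是 `CPointer(double)`,即 `double*`),因此完全不需要 `data_as()`。注意整型参数也应显式转换为与签名中 `intc` 对应的 `np.int32`: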
@nb.njit
def test_get_param_njit():
    """Call the ``get_param2`` cfunc from njit code by passing ``arr.ctypes``.

    Inside njit code, ``arr.ctypes`` converts implicitly to a ``CPointer`` of
    the array's dtype, so it can be passed where the cfunc expects
    ``CPointer(double)``; ``data_as()`` is neither needed nor available.
    """
    # Distinct values (100 * 2 * 3 of them) so a missed copy cannot pass by
    # accident against freshly allocated ``np.empty`` memory.
    A = np.arange(600.0).reshape((100, 2, 3))
    Ai = np.empty((2, 3))
    # Take the dimensions from A instead of repeating literals; the cfunc
    # signature declares them as intc, hence the np.int32 casts.
    n, m1, m2 = A.shape

    # In-range index: the requested slice is copied.
    get_param2(
        Ai.ctypes,
        A.ctypes,
        np.int32(40),
        np.int32(n), np.int32(m1), np.int32(m2),
    )
    assert np.array_equal(A[40], Ai)

    # Index past the end: get_param2 clamps it to the last slice.
    get_param2(
        Ai.ctypes,
        A.ctypes,
        np.int32(n + 5),
        np.int32(n), np.int32(m1), np.int32(m2),
    )
    assert np.array_equal(A[n - 1], Ai)