我想向接受
np.float32
数组或 np.float64
数组并返回相同类型的函数添加类型提示:
from typing import overload, Union
import numpy as np
import numpy.typing as npt
NPArray_FLOAT32 = npt.NDArray[np.float32]
NPArray_FLOAT64 = npt.NDArray[np.float64]
NPArray_FLOAT32_64 = Union[NPArray_FLOAT32, NPArray_FLOAT64]
@overload
def foo(xa: NPArray_FLOAT32, xb: NPArray_FLOAT32) -> NPArray_FLOAT32: ...
@overload
def foo(xa: NPArray_FLOAT64, xb: NPArray_FLOAT64) -> NPArray_FLOAT64: ...
def foo(xa: NPArray_FLOAT32_64, xb: NPArray_FLOAT32_64) -> NPArray_FLOAT32_64:
# ...
但是,这会导致以下错误
mypy
mypy [overload-overlap]: Overloaded function signatures 1 and 2 overlap with incompatible return types.
正确的做法是什么?这几乎看起来像是
mypy
的错误,因为 np.float32
不与 np.float64
重叠。
解决方案是使用TypeVar:
from typing import TypeVar, overload
import numpy as np
from numpy.typing import NDArray
T = TypeVar("T", np.float32, np.float64)
def add_32_64(xa: NDArray[T], xb: NDArray[T]) -> NDArray[T]:
return xa + xb
# no issues when calling `add_32_64` with narrower types
def add_64(xa: NDArray[np.float64], xb: NDArray[np.float64]) -> NDArray[np.float64]:
return add_32_64(xa, xb)
由于它从 nanobind 公开了 C++ 函数,因此这与函数定义非常匹配,因为这些函数也在 C++ 中定义为:
// NDArray defined earlier
template <typename T>
NDArray<T, 2>
foo(const NDArray<const T, 2> xa, const NDArray<const T, 2> xb) {
感谢如何输入提示通用numpy数组?(感谢Nick ODell)和insync解决了。感谢您的帮助!