Numpy 类型重载

问题描述 投票:0回答:1

我想向接受

np.float32
数组或
np.float64
数组并返回相同类型的函数添加类型提示:

from typing import overload, Union

import numpy as np
import numpy.typing as npt

NPArray_FLOAT32 = npt.NDArray[np.float32]
NPArray_FLOAT64 = npt.NDArray[np.float64]
NPArray_FLOAT32_64 = Union[NPArray_FLOAT32, NPArray_FLOAT64]

@overload
def foo(xa: NPArray_FLOAT32, xb: NPArray_FLOAT32) -> NPArray_FLOAT32: ...


@overload
def foo(xa: NPArray_FLOAT64, xb: NPArray_FLOAT64) -> NPArray_FLOAT64: ...


def foo(xa: NPArray_FLOAT32_64, xb: NPArray_FLOAT32_64) -> NPArray_FLOAT32_64:
    # ...

但是,这会导致以下错误

mypy

mypy [overload-overlap]: Overloaded function signatures 1 and 2 overlap with incompatible return types.

正确的做法是什么?这几乎看起来像是

mypy
的错误,因为
np.float32
不与
np.float64
重叠。

python numpy mypy python-typing
1个回答
0
投票

解决方案是使用TypeVar:

from typing import TypeVar, overload
import numpy as np
from numpy.typing import NDArray

T = TypeVar("T", np.float32, np.float64)


def add_32_64(xa: NDArray[T], xb: NDArray[T]) -> NDArray[T]:
    return xa + xb

# no issues when calling `add_32_64` with narrower types 
def add_64(xa: NDArray[np.float64], xb: NDArray[np.float64]) -> NDArray[np.float64]:
    return add_32_64(xa, xb)

由于它从 nanobind 公开了 C++ 函数,因此这与函数定义非常匹配,因为这些函数也在 C++ 中定义为:

// NDArray defined earlier

template <typename T>
NDArray<T, 2>
foo(const NDArray<const T, 2> xa, const NDArray<const T, 2> xb) {
信用

感谢如何输入提示通用numpy数组?(感谢Nick ODell)和insync解决了。感谢您的帮助!

© www.soinside.com 2019 - 2024. All rights reserved.