What is the correct type hint for a numpy array of unsigned integers?


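I want to annotate some functions that take numpy arrays of unsigned integers, perform some operations on them, and return numpy arrays of the same dtype. But I'm struggling to annotate them. I know how to use a TypeVar to annotate arrays of arbitrary dtype, but these functions make no sense for floating-point values, so I want to restrict the inputs to unsigned integers. I also know how to restrict the input to one very specific dtype, for example with npt.NDArray[np.ushort], but that is too specific. I'm looking for an annotation somewhere in between. However, every idea I had (see the code below) produces errors when I check it with mypy (see the errors below).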

I came up with the following TypeVars:

from typing import TypeVar
from typing import Tuple


import numpy as np
import numpy.typing as npt


# Works, but is too unspecific.
T = TypeVar("T", bound=np.generic)


def subsample_array_T(
    x: npt.NDArray[T], grid: Tuple[slice, ...]
) -> npt.NDArray[T]:
    """Takes arrays of any dtype and returns array of same dtype."""
    return x[grid]


# Works, but is too specific.
def subsample_array_ushort(
    x: npt.NDArray[np.ushort], grid: Tuple[slice, ...]
) -> npt.NDArray[np.ushort]:
    """Takes only arrays of dtype np.ushort and returns array of same dtype."""
    return x[grid]


# Test 1
# error: Missing type parameters for generic type "unsignedinteger"
S = TypeVar("S", bound=np.unsignedinteger)


def subsample_unsigned_S(
    x: npt.NDArray[S], grid: Tuple[slice, ...]
) -> npt.NDArray[S]:
    """Take only arrays of unsigned dtype and return the same dtype."""
    return x[grid]


# Test 2
# error: Type argument "U" of "NDArray" must be a subtype of "generic"
U = TypeVar(
    "U",
    npt.NDArray[np.ubyte],
    npt.NDArray[np.ushort],
    npt.NDArray[np.uintc],
    npt.NDArray[np.uint],
    npt.NDArray[np.ulonglong],
)


def subsample_unsigned_U(
    x: npt.NDArray[U], grid: Tuple[slice, ...]
) -> npt.NDArray[U]:
    """Take only arrays of unsigned dtype and return the same dtype."""
    return x[grid]


# Test 3
# error: TypeVar cannot have both values and an upper bound
V = TypeVar(
    "V",
    npt.NDArray[np.ubyte],
    npt.NDArray[np.ushort],
    npt.NDArray[np.uintc],
    npt.NDArray[np.uint],
    npt.NDArray[np.ulonglong],
    bound=np.generic,
)


# subsequent errors from error above...
def subsample_unsigned_V(
    x: npt.NDArray[V], grid: Tuple[slice, ...]
) -> npt.NDArray[V]:
    """Take only arrays of unsigned dtype and return the same dtype."""
    return x[grid]

But all of them produce errors in mypy:

numpy_unsigned_typing.py:30:24: error: Missing type parameters for generic type "unsignedinteger"  [type-arg]
    S = TypeVar("S", bound=np.unsignedinteger)
                           ^
numpy_unsigned_typing.py: note: In function "subsample_unsigned_U":
numpy_unsigned_typing.py:53: error: Type argument "U" of "NDArray" must be a subtype of "generic"  [type-var]
        x: npt.NDArray[U], grid: Tuple[slice, ...]
        ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
numpy_unsigned_typing.py: note: At top level:
numpy_unsigned_typing.py:61:1: error: TypeVar cannot have both values and an upper bound  [misc]
    V = TypeVar(
    ^
numpy_unsigned_typing.py: note: In function "subsample_unsigned_V":
numpy_unsigned_typing.py:74:20: error: Variable "numpy_unsigned_typing.V" is not valid as a type  [valid-type]
        x: npt.NDArray[V], grid: Tuple[slice, ...]
                       ^
numpy_unsigned_typing.py:74:20: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
numpy_unsigned_typing.py:75:18: error: Variable "numpy_unsigned_typing.V" is not valid as a type  [valid-type]
    ) -> npt.NDArray[V]:
                     ^
numpy_unsigned_typing.py:75:18: note: See https://mypy.readthedocs.io/en/stable/common_issues.html#variables-vs-type-aliases
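The Test 2 error says that the argument of NDArray must be a subtype of np.generic, i.e. a scalar type, whereas U was constrained to whole NDArray types. A minimal sketch, assuming the constraints are moved down to the scalar level (the fixed-width names np.uint8 to np.uint64 are my substitution for the C-named aliases in the question), which as far as I can tell mypy accepts and which keeps the argument and return dtype linked. No bound= is needed, which also sidesteps the Test 3 error, because each constraint already satisfies NDArray's own bound:

```python
from typing import Tuple, TypeVar

import numpy as np
import numpy.typing as npt

# Constraints are scalar types (subclasses of np.generic), not NDArray types,
# so NDArray[U] below is well-formed. Fixed-width names are an assumption;
# they cover the same 8/16/32/64-bit unsigned dtypes.
U = TypeVar("U", np.uint8, np.uint16, np.uint32, np.uint64)


def subsample_unsigned_U(
    x: npt.NDArray[U], grid: Tuple[slice, ...]
) -> npt.NDArray[U]:
    """Take only arrays of unsigned dtype and return the same dtype."""
    return x[grid]


small = np.arange(6, dtype=np.uint8).reshape(2, 3)
big = np.arange(6, dtype=np.uint64).reshape(2, 3)
print(subsample_unsigned_U(small, (slice(None), slice(0, 2))).dtype)  # uint8
print(subsample_unsigned_U(big, (slice(1, None),)).dtype)             # uint64
```

The trade-off is that a value-constrained TypeVar only accepts exactly the listed types, so any unsigned scalar type not in the list would be rejected.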

I thought about using a Union of the unsigned dtypes, in the form of:

UnsignedIntegerArray = Union[
    npt.NDArray[np.ubyte],
    npt.NDArray[np.ushort],
    npt.NDArray[np.uintc],
    npt.NDArray[np.uint],
    npt.NDArray[np.ulonglong],
]

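But then there is no connection between the type annotation of the argument and that of the return value. Likewise, simply putting a Union around the dtypes inside npt.NDArray doesn't work either, as discussed here.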
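To make the missing connection concrete: with the UnsignedIntegerArray alias, nothing ties the member of the Union that is returned to the member that was passed in. In the sketch below (the function name subsample_union is hypothetical), the function deliberately returns a different dtype, and the Union annotations on both sides still permit that:

```python
from typing import Tuple, Union

import numpy as np
import numpy.typing as npt

UnsignedIntegerArray = Union[
    npt.NDArray[np.ubyte],
    npt.NDArray[np.ushort],
    npt.NDArray[np.uintc],
    npt.NDArray[np.uint],
    npt.NDArray[np.ulonglong],
]


def subsample_union(
    x: UnsignedIntegerArray, grid: Tuple[slice, ...]
) -> UnsignedIntegerArray:
    # Deliberately changes the dtype: the return annotation only says
    # "some member of the Union", not "the same member as x".
    return x[grid].astype(np.ubyte)


arr = np.arange(10, dtype=np.ushort)
out = subsample_union(arr, (slice(0, 5),))
print(arr.dtype, out.dtype)  # uint16 uint8
```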

I guess the most promising approach would be to take Test 1 and add a second TypeVar that describes the bits. At least, the UnsignedIntegerArray union above resolves to
ndarray[Any, dtype[unsignedinteger[_8Bit]]] | ndarray[Any, dtype[unsignedinteger[_16Bit]]] | ndarray[Any, dtype[unsignedinteger[_32Bit]]] | ndarray[Any, dtype[unsignedinteger[_64Bit]]] | ndarray[Any, dtype[unsignedinteger[_64Bit]]]
So I would like to do something like:

T = TypeVar("T", np._8Bit, np._16Bit, np._32Bit, np._64Bit)
S = TypeVar("S", bound=np.unsignedinteger[T])

But this gives me errors like
Name "np._8Bit" is not defined
and I don't know how to resolve np._8Bit.

1 Answer

Your attempt with a type variable bounded by np.unsignedinteger (Test 1) was almost correct. You were just missing the type argument.
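The bound covers all of the dtypes you listed because it mirrors NumPy's class hierarchy: every unsigned scalar type from the question is a subclass of np.unsignedinteger, while signed and floating-point types are not. A quick runtime check of that hierarchy:

```python
import numpy as np

# The unsigned scalar types used in the question.
unsigned_types = [np.ubyte, np.ushort, np.uintc, np.uint, np.ulonglong]

# Each of them is a subclass of np.unsignedinteger ...
print(all(issubclass(t, np.unsignedinteger) for t in unsigned_types))  # True

# ... while signed integer and floating-point scalar types are not.
print(issubclass(np.intc, np.unsignedinteger))     # False
print(issubclass(np.float64, np.unsignedinteger))  # False
```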

This works:

from typing import TypeVar

import numpy as np
from numpy.typing import NBitBase, NDArray


T = TypeVar("T", bound=np.unsignedinteger[NBitBase])


def subsample_array(
    arr: NDArray[T],
    grid: tuple[slice, ...],
) -> NDArray[T]:
    return arr[grid]


a: NDArray[np.ushort]
out = subsample_array(a, (slice(1), ))
reveal_type(a)    # note: Revealed type is "numpy.ndarray[Any, numpy.dtype[numpy.unsignedinteger[Any]]]"
reveal_type(out)  # note: Revealed type is "numpy.ndarray[Any, numpy.dtype[numpy.unsignedinteger[Any]]]"
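Note that the annotation is only enforced by mypy; at runtime nothing stops a float array from being passed. A sketch of the same function exercised with each fixed-width unsigned dtype, with an extra runtime guard via np.issubdtype that is my own optional addition (not part of the answer above). It assumes NumPy 1.22 or later, where scalar types such as np.unsignedinteger can be subscripted at runtime:

```python
from __future__ import annotations

from typing import TypeVar

import numpy as np
from numpy.typing import NBitBase, NDArray

T = TypeVar("T", bound=np.unsignedinteger[NBitBase])


def subsample_array(arr: NDArray[T], grid: tuple[slice, ...]) -> NDArray[T]:
    # Optional runtime guard (an addition): mypy checks the annotation
    # statically, this check rejects non-unsigned dtypes when the code runs.
    if not np.issubdtype(arr.dtype, np.unsignedinteger):
        raise TypeError(f"expected an unsigned integer array, got {arr.dtype}")
    return arr[grid]


for dt in (np.uint8, np.uint16, np.uint32, np.uint64):
    a = np.arange(12, dtype=dt).reshape(3, 4)
    # Prints uint8, uint16, uint32, uint64 on successive lines.
    print(subsample_array(a, (slice(None, None, 2), slice(1, 3))).dtype)
```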

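Unfortunately, since the Python type system does not (yet) support higher-kinded type variables, there is no way to make the type generic in terms of the bit length.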
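If keeping the bit width in the return type matters more than brevity, one possible workaround is typing.overload with one overload per fixed-width unsigned type. This is a different technique from the TypeVar above, sketched here under assumptions: it trades a single generic signature for an explicit list, and arrays annotated with the platform-dependent C-named aliases (np.ushort and friends) may not select a specific overload unless NumPy's mypy plugin is enabled. The function name is reused from the answer for illustration:

```python
from __future__ import annotations

from typing import Any, overload

import numpy as np
from numpy.typing import NDArray


# One overload per fixed bit width: mypy picks the matching signature,
# so the bit width of the input is carried over to the return type.
@overload
def subsample_array(
    arr: NDArray[np.uint8], grid: tuple[slice, ...]
) -> NDArray[np.uint8]: ...
@overload
def subsample_array(
    arr: NDArray[np.uint16], grid: tuple[slice, ...]
) -> NDArray[np.uint16]: ...
@overload
def subsample_array(
    arr: NDArray[np.uint32], grid: tuple[slice, ...]
) -> NDArray[np.uint32]: ...
@overload
def subsample_array(
    arr: NDArray[np.uint64], grid: tuple[slice, ...]
) -> NDArray[np.uint64]: ...
def subsample_array(
    arr: NDArray[np.unsignedinteger[Any]], grid: tuple[slice, ...]
) -> NDArray[np.unsignedinteger[Any]]:
    # Single runtime implementation shared by all overloads.
    return arr[grid]


x = np.arange(8, dtype=np.uint16)
print(subsample_array(x, (slice(0, 4),)).dtype)  # uint16
```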

© www.soinside.com 2019 - 2024. All rights reserved.