接受 numpy 数组的类型提示函数

问题描述 投票:0回答:2

我无法理解为什么 mypy 在以下示例中会抛出错误。

import numpy as np
from typing import Sequence


def compute(x: Sequence[float]) -> bool:
    # some computation that does not modify x
    ...


compute(np.linspace(0, 1, 10))

Mypy错误:

Argument 1 to "compute" has incompatible type "ndarray[Any, dtype[floating[Any]]]"; expected "Sequence[float]"  [arg-type]

特别是,由于

typing.Sequence
需要可迭代性、可逆性和索引,我认为 numpy 数组也应该是
Sequence
。这与 numpy 数组是可变的,而
Sequence
类型是不可变的有关吗?我注意到当我将
Sequence
更改为
Iterable
时,问题就解决了。但我需要能够在
x
中索引
compute

那么,什么是类型提示

compute
函数的最佳方式,以便它可以接受具有可迭代性和索引的对象?

python numpy types mypy
2个回答
0
投票

尝试使用 np.ndarray 作为 x 的类型提示。我检查了

type(np.linspace(0, 1, 10))
,它返回了 np.ndarray


0
投票

np.ndarray
不是
Sequence
(与相应的协议不兼容),因为
np.ndarray
没有实现
.count
.index
方法,而 are required for
collections.abc.Sequence
, see the table of methods in link .

这里有一个相关问题.

为了使类似序列的东西与

np.ndarray
兼容,我建议定义你的协议:

from __future__ import annotations

from collections.abc import Collection, Reversible, Iterator
from typing import Protocol, TypeVar, overload

_T_co = TypeVar('_T_co', covariant=True)

class WeakSequence(Collection[_T_co], Reversible[_T_co], Protocol[_T_co]):
    @overload
    def __getitem__(self, index: int) -> _T_co: ...
    @overload
    def __getitem__(self, index: slice) -> WeakSequence[_T_co]: ...
    def __contains__(self, value: object) -> bool: ...
    def __iter__(self) -> Iterator[_T_co]: ...
    def __reversed__(self) -> Iterator[_T_co]: ...

这与

Sequence
定义 in typeshed 基本相同,除了我使用
Protocol
而不是
Generic
(以避免添加实现 - 存根不必这样做,另外这使得结构子类型化更加明显)并省略
.index
.count
方法。

现在

np.ndarray
应该与
WeakSequence
以及
list
tuple
兼容,您可以在注释中使用
WeakSequence
而不是
collections.abc.Sequence

© www.soinside.com 2019 - 2024. All rights reserved.