如何使用绑定类型变量和静态类型组合类型提示以获得最大的灵活性?

问题描述 投票:0回答:1

我想为一个简单的函数添加类型提示。由于它内部仅使用 numpy 调用,因此它的输入非常灵活。基本上,它接受所有类似数组的对象,其中有

numpy.typing.ArrayLike
类型。

但是定义返回类型并不那么简单。对于某些输入类型(例如列表),numpy 函数将返回值转换为 numpy 数组,这意味着我可以使用

-> np.NDArray

其他一些输入类型,例如

pandas.DataFrame
s,我在代码中大量使用。这通常是使用绑定到输入类型的 TypeVar 的一个很好的理由。

如何保持 numpy 的灵活性,同时提供有意义的类型提示,例如

mypy

注:示例方法用于计算声压级。

两个代码片段都是完全可行的代码,方法仅在类型提示上有所不同。

它们会导致静态类型检查器出现不同的错误,展示了每种方法的局限性。

版本1:

def get_decibels1(p2: npt.ArrayLike, clip_db: float=0) -> npt.NDArray:
    return (10 * np.log10(np.divide(p2, 4e-10)))

df = pd.DataFrame([[4, 5, 6], [7, 8, 9]])
get_decibels1(df).columns
# --- Causes mypy Error:
# error: "ndarray[Any, dtype[Any]]" has no attribute "columns"  [attr-defined]

版本2:

T = TypeVar('T', bound=npt.ArrayLike)
def get_decibels2(p2: T, clip_db: float=0) -> T:
    return (10 * np.log10(np.divide(p2, REF_P2)))

ls = [4.0, 5, 6]
get_decibels2(ls).shape
# --- Causes mypy error:
#  error: "list[float]" has no attribute "shape"  [attr-defined]

如何明智地结合这两种方法?


更新:

我想我也许可以用

@overload
来解决这个问题。但这似乎也不起作用,因为签名重叠

T = TypeVar('T', bound=Union[pd.DataFrame, pd.Series])
@overload
def get_decibels(p2: T) -> T: ...

@overload
def get_decibels(p2: npt.ArrayLike) -> npt.NDArray: ...

def get_decibels(p2: npt.ArrayLike):
    return (10 * np.log10(np.divide(p2, 4e-10)))
# --- Causes mypy error:
#  error: Overloaded function signatures 1 and 2 overlap with incompatible return types  [overload-overlap]

我的印象是 mypy 只是选择第一个匹配的签名,这样就可以解决问题。关于如何解决这个问题有什么想法吗?

python pandas numpy type-hinting
1个回答
0
投票

就个人而言,我宁愿让函数通过应用

NDArray
numpy.array
在所有情况下返回单一类型 (
numpy.asarray
)。此外,我相信
np.divide(a, b) 
a/b
是等价的,最后,当你可以扩展表格时为什么要覆盖
pandas.DataFrame
呢? (如果您有记忆问题,那么我理解,您可能想查看 polars 库)。还有,不需要
clip_db
吗?

虽然这是个人品味,但我的建议是一个始终返回 numpy 数组的函数:

import numpy as np
from numpy import typing as npt
import pandas as pd

def get_decibels(p2: npt.ArrayLike) -> npt.NDArray:
    p2 = np.asarray(p2)
    return 10 * np.log10(p2/4e-10)

def get_decibels_clip(p2: npt.ArrayLike, clip_db: float=0) -> npt.NDArray:
    p2 = np.asarray(p2)
    return np.maximum(10 * np.log10(p2/4e-10), clip_db)  # is this what was asked? 

df = pd.DataFrame([[4, 5, 6], [7, 8, 9]])
df[["db1", "db2", "db3"]] = get_decibels(df)
print(df) 
© www.soinside.com 2019 - 2024. All rights reserved.