我想为一个简单的函数添加类型提示。由于它内部仅使用 numpy 调用,因此它的输入非常灵活。基本上,它接受所有类似数组的对象,其中有
numpy.typing.ArrayLike
类型。
但是定义返回类型并不那么简单。对于某些输入类型(例如列表),numpy 函数将返回值转换为 numpy 数组,这意味着我可以使用
-> np.NDArray
。
其他一些输入类型,例如
pandas.DataFrame
s,我在代码中大量使用。这通常是使用绑定到输入类型的 TypeVar 的一个很好的理由。
如何保持 numpy 的灵活性,同时提供有意义的类型提示,例如
mypy
?
注:示例方法用于计算声压级。
两个代码片段都是完全可行的代码,方法仅在类型提示上有所不同。
它们会导致静态类型检查器出现不同的错误,展示了每种方法的局限性。
版本1:
def get_decibels1(p2: npt.ArrayLike, clip_db: float=0) -> npt.NDArray:
return (10 * np.log10(np.divide(p2, 4e-10)))
df = pd.DataFrame([[4, 5, 6], [7, 8, 9]])
get_decibels1(df).columns
# --- Causes mypy Error:
# error: "ndarray[Any, dtype[Any]]" has no attribute "columns" [attr-defined]
版本2:
T = TypeVar('T', bound=npt.ArrayLike)
def get_decibels2(p2: T, clip_db: float=0) -> T:
return (10 * np.log10(np.divide(p2, REF_P2)))
ls = [4.0, 5, 6]
get_decibels2(ls).shape
# --- Causes mypy error:
# error: "list[float]" has no attribute "shape" [attr-defined]
如何明智地结合这两种方法?
更新:
我想我也许可以用
@overload
来解决这个问题。但这似乎也不起作用,因为签名重叠。
T = TypeVar('T', bound=Union[pd.DataFrame, pd.Series])
@overload
def get_decibels(p2: T) -> T: ...
@overload
def get_decibels(p2: npt.ArrayLike) -> npt.NDArray: ...
def get_decibels(p2: npt.ArrayLike):
return (10 * np.log10(np.divide(p2, 4e-10)))
# --- Causes mypy error:
# error: Overloaded function signatures 1 and 2 overlap with incompatible return types [overload-overlap]
我的印象是 mypy 只是选择第一个匹配的签名,这样就可以解决问题。关于如何解决这个问题有什么想法吗?
就个人而言,我宁愿让函数通过应用
NDArray
或 numpy.array
在所有情况下返回单一类型 (numpy.asarray
)。此外,我相信 np.divide(a, b)
和 a/b
是等价的,最后,当你可以扩展表格时为什么要覆盖 pandas.DataFrame
呢? (如果您有记忆问题,那么我理解,您可能想查看 polars 库)。还有,不需要clip_db
吗?
虽然这是个人品味,但我的建议是一个始终返回 numpy 数组的函数:
import numpy as np
from numpy import typing as npt
import pandas as pd
def get_decibels(p2: npt.ArrayLike) -> npt.NDArray:
p2 = np.asarray(p2)
return 10 * np.log10(p2/4e-10)
def get_decibels_clip(p2: npt.ArrayLike, clip_db: float=0) -> npt.NDArray:
p2 = np.asarray(p2)
return np.maximum(10 * np.log10(p2/4e-10), clip_db) # is this what was asked?
df = pd.DataFrame([[4, 5, 6], [7, 8, 9]])
df[["db1", "db2", "db3"]] = get_decibels(df)
print(df)