我有以下简单的功能:
def f1(y_true, y_pred):
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
根据 scikit-learn 文档,
f1_score
的参数可以有以下类型:
y_true
:一维数组,或标签指示符数组/稀疏矩阵y_pred
:一维数组,或标签指示符数组/稀疏矩阵输出的类型为:
如何向此函数添加类型提示,以便 mypy 不会抱怨?
我尝试了以下变体:
Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])
def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
但这给出了错误。
这是我用来避免类似 mypy 问题的方法。它利用了 1.20 中引入的 numpytyping。
ArrayLike
类型覆盖了 List[float]
,因此无需担心显式覆盖它。
使用 numpy v1.23.1 运行 mypy v0.971 没有显示任何问题。
from typing import List, Dict
import numpy as np
import numpy.typing as npt
import sklearn.metrics
def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)
assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)
而不是
Array1D = NewType("Array1D", Union[np.ndarray, List[np.float64]])
你可以使用
Array1D = Union[np.ndarray, List[np.float64]]