如何向 scikit-learn 函数添加类型提示?

问题描述 投票:0回答:2

我有以下简单的功能:

def f1(y_true, y_pred):
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

根据 scikit-learn 文档,

f1_score
的参数可以有以下类型:

  • y_true
    :一维数组,或标签指示符数组/稀疏矩阵
  • y_pred
    :一维数组,或标签指示符数组/稀疏矩阵

输出的类型为:

  • 浮点或浮点数组,形状 = [n_unique_labels]

如何向此函数添加类型提示,以便 mypy 不会抱怨?

我尝试了以下变体:

Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])

def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

但这给出了错误。

python scikit-learn mypy python-typing
2个回答
2
投票

这是我用来避免类似 mypy 问题的方法。它利用了 1.20 中引入的 numpytyping

ArrayLike
类型覆盖了
List[float]
,因此无需担心显式覆盖它。

使用 numpy v1.23.1 运行 mypy v0.971 没有显示任何问题。

from typing import List, Dict
import numpy as np
import numpy.typing as npt
import sklearn.metrics


def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)

assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)

0
投票

而不是

Array1D = NewType("Array1D", Union[np.ndarray, List[np.float64]])

你可以使用

Array1D = Union[np.ndarray, List[np.float64]]
© www.soinside.com 2019 - 2024. All rights reserved.