我有以下简单的函数:

def f1(y_true, y_pred):
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

根据SCRICKIT-LEARN文档,f1_score的参数可以有以下类型:

  • y_true:1D数组,或标签指示器数组/稀疏矩阵
  • y_pred:1D数组,或标签指示器数组/稀疏矩阵

并且输出的类型为:

  • 浮点或浮点数组,Shape=[n_Unique_Labels]

如何向该函数添加类型提示,以便mypy不会出现错误?

我try 了以下不同的版本:

Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])

def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

但这造成了错误.

推荐答案

这是我用来避免类似的输入问题的方法.它利用了1.20中引入的numpy typing.ArrayLike类型涵盖List[float],因此不必担心显式涵盖它.

在上面运行带有NumPy v1.23.1的mypy v0.971没有显示任何问题.

from typing import List, Dict
import numpy as np
import numpy.typing as npt
import sklearn.metrics


def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}

y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)

assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)

Python-3.x相关问答推荐

Python多处理池:缺少一个进程

谁能解释一下这个带邮编的多功能环路?

如何强调您正在寻求以 pandas 数据帧的另一列为条件的差异?

将值从函数传递到标签

Python,Web 从交互式图表中抓取数据

如何在 histplot 中标记核密度估计

如何使用 django rest 框架在 self forienkey 中删除多达 n 种类型的数据?

使用条件参数为 super() 调用 __init__

没有可重定向到的 URL.提供一个 url 或在模型上定义一个 get_absolute_url 方法

python3:字节与字节数组,并转换为字符串和从字符串转换

使用 python 正则表达式匹配日期

numpy.ndarray 与 pandas.DataFrame

在 Ubuntu 上为 Python3 安装 mod_wsgi

Python3 - 如何从现有抽象类定义抽象子类?

python中的订单字典索引

使用完整路径激活 conda 环境

Python 3.4 多处理队列比 Pipe 快,出乎意料

带有 Emacs 的 Python 3

Python 无法处理以 0 开头的数字字符串.为什么?

如何使用 Celery 和 Django 将任务路由到不同的队列