如何编写与 Python、Numpy 或 Pandas 参数一起使用并返回相同类型的通用 Python 函数

问题描述 投票:0回答:1

编写可与 float、Numpy 或 Pandas 数据类型一起使用并且始终返回与给定参数相同类型的 python 函数的最佳方法是什么。 问题是,计算包括一个或多个浮点值。

例如玩具示例:

def mycalc(x, a=1.0, b=1.0):
    return a * x + b

(我在这里大大简化了问题,因为我理想情况下希望有多个输入参数,例如

x
,但是您可以假设该函数是矢量化的,因为它可以与 Numpy 数组参数和 Pandas 系列一起使用).

对于 Numpy 数组和 Pandas 系列,这工作得很好,因为数据类型是由输入参数决定的。

import numpy as np
x = np.array([1, 2, 3], dtype="float32")
print(mycalc(x).dtype)  # float32
import pandas as pd
x = pd.Series([1.0, 2.0, 3.0], dtype="float32")
print(mycalc(x).dtype)  # float32

但是当使用较低精度的 numpy 浮点数时,类型会“提升”为 float64,大概是由于公式中的 float 参数:

x = np.float32(1.0)
print(mycalc(x).dtype)  # float64

理想情况下,我希望该函数能够使用 Python 浮点数、numpy 标量、numpy 数组、Pandas 系列、Jax 数组,如果可能的话甚至可以使用 Sympy 符号变量。

但我不想用太多额外的语句来处理每种情况,从而使函数变得混乱。

我尝试过这个,它可以与 Numpy 标量一起使用,但当您提供数组或系列时会中断:

def mycalc(x, a=1.0, b=1.0):
    a = type(x)(a)
    b = type(x)(b)
    return a * x + b

assert isinstance(mycalc(1.0), float)
assert isinstance(mycalc(np.float32(1.0)), np.float32)
mycalc(np.array([1, 2, 3], dtype="float32"))  # raises TypeError: expected a sequence of integers or a single integer, got '1.0'

另外,这里有一个答案类似的问题,它使用装饰器函数来复制输入参数,这是一个好主意,但这只是为了将函数从 Numpy 数组扩展到 Pandas 系列和不适用于 Python 浮点数或 Numpy 标量。

import functools

def apply_to_pandas(func):
    @functools.wraps(func)
    def wrapper_func(x, *args, **kwargs):
        if isinstance(x, (np.ndarray, list)):
            out = func(x, *args, **kwargs)
        else:
            out = x.copy(deep=False)
            out[:] = np.apply_along_axis(func, 0, x, *args, **kwargs)
        return out
    return wrapper_func

@apply_to_pandas
def mycalc(x, a=1.0, b=1.0):
    return a * x + b

mycalc(1.0) # TypeError: copy() got an unexpected keyword argument 'deep'

这是一个类似的问题,但与数值类型或计算无关:

python pandas numpy types
1个回答
0
投票

我的意思并不是说这是一个权威的答案,而是也许值得思考一下,看看它是否有助于你走得更远。如果您尝试依赖“r”dunder 方法的更高级类型实现,这些方法似乎更加细致,并且执行了类似这样的操作:

import numpy as np
import pandas as pd

def mycalc(x, a=1, b=1):
    foo = x * a + b
    return foo if type(foo) == type(x) else type(x)(foo)

print(type(mycalc(1)))
print(type(mycalc(1.0)))
print(type(mycalc(np.float32(1.0))))
print(type(mycalc(np.array([1, 2, 3], dtype="float32"))))
print(type(mycalc(pd.Series([1, 2, 3], dtype="float64"))))

这似乎回馈:

<class 'int'>
<class 'float'>
<class 'numpy.float32'>
<class 'numpy.ndarray'>
<class 'numpy.float64'>
© www.soinside.com 2019 - 2024. All rights reserved.