编写可与 float、Numpy 或 Pandas 数据类型一起使用并且始终返回与给定参数相同类型的 python 函数的最佳方法是什么。 问题是,计算包括一个或多个浮点值。
例如玩具示例:
def mycalc(x, a=1.0, b=1.0):
return a * x + b
(我在这里大大简化了问题,因为我理想情况下希望有多个输入参数,例如
x
,但是您可以假设该函数是矢量化的,因为它可以与 Numpy 数组参数和 Pandas 系列一起使用).
对于 Numpy 数组和 Pandas 系列,这工作得很好,因为数据类型是由输入参数决定的。
import numpy as np
x = np.array([1, 2, 3], dtype="float32")
print(mycalc(x).dtype) # float32
import pandas as pd
x = pd.Series([1.0, 2.0, 3.0], dtype="float32")
print(mycalc(x).dtype) # float32
但是当使用较低精度的 numpy 浮点数时,类型会“提升”为 float64,大概是由于公式中的 float 参数:
x = np.float32(1.0)
print(mycalc(x).dtype) # float64
理想情况下,我希望该函数能够使用 Python 浮点数、numpy 标量、numpy 数组、Pandas 系列、Jax 数组,如果可能的话甚至可以使用 Sympy 符号变量。
但我不想用太多额外的语句来处理每种情况,从而使函数变得混乱。
我尝试过这个,它可以与 Numpy 标量一起使用,但当您提供数组或系列时会中断:
def mycalc(x, a=1.0, b=1.0):
a = type(x)(a)
b = type(x)(b)
return a * x + b
assert isinstance(mycalc(1.0), float)
assert isinstance(mycalc(np.float32(1.0)), np.float32)
mycalc(np.array([1, 2, 3], dtype="float32")) # raises TypeError: expected a sequence of integers or a single integer, got '1.0'
另外,这里有一个答案到类似的问题,它使用装饰器函数来复制输入参数,这是一个好主意,但这只是为了将函数从 Numpy 数组扩展到 Pandas 系列和不适用于 Python 浮点数或 Numpy 标量。
import functools
def apply_to_pandas(func):
@functools.wraps(func)
def wrapper_func(x, *args, **kwargs):
if isinstance(x, (np.ndarray, list)):
out = func(x, *args, **kwargs)
else:
out = x.copy(deep=False)
out[:] = np.apply_along_axis(func, 0, x, *args, **kwargs)
return out
return wrapper_func
@apply_to_pandas
def mycalc(x, a=1.0, b=1.0):
return a * x + b
mycalc(1.0) # TypeError: copy() got an unexpected keyword argument 'deep'
这是一个类似的问题,但与数值类型或计算无关:
我的意思并不是说这是一个权威的答案,而是也许值得思考一下,看看它是否有助于你走得更远。如果您尝试依赖“r”dunder 方法的更高级类型实现,这些方法似乎更加细致,并且执行了类似这样的操作:
import numpy as np
import pandas as pd
def mycalc(x, a=1, b=1):
foo = x * a + b
return foo if type(foo) == type(x) else type(x)(foo)
print(type(mycalc(1)))
print(type(mycalc(1.0)))
print(type(mycalc(np.float32(1.0))))
print(type(mycalc(np.array([1, 2, 3], dtype="float32"))))
print(type(mycalc(pd.Series([1, 2, 3], dtype="float64"))))
这似乎回馈:
<class 'int'>
<class 'float'>
<class 'numpy.float32'>
<class 'numpy.ndarray'>
<class 'numpy.float64'>