我正在尝试找到一种方法来编写 PySpark UDF,它可以支持任何输入类型并根据输入类型返回类型。例如,假设我想创建一个简单的钳位函数,它只是将数值钳位在下限和上限之间。可以这样写:
def clamp(value, low, high):
return max(low, min(value, high))
这个函数应该能够支持任何数字类型。但是
pyspark.sql.functions.udf
函数允许您仅指定单个返回类型,如果类型不匹配,则 udf
返回 NULL
。我尝试使用 overload
模块中的 typing
,但我无法让它工作。我知道这一定是可能的,因为像 pyspark.sql.functions.sum
这样的函数适用于任何数字类型,但我似乎无法复制它。任何帮助将不胜感激。
经过一系列实验,我为任何可能感兴趣的人创建了一个粗略的解决方案。它并不能完全满足我的要求,但它已经很接近了,比我聪明的人可能能够对其进行改进。基本上,我创建了一个装饰器,它包装函数并在调用函数时创建 UDF,而不是创建函数。这是代码:
class variable_udf:
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, return_type=StringType(), **kwargs):
if return_type not in self.cache:
self.cache[return_type] = F.udf(self.func, return_type)
return self.cache[return_type](*args, **kwargs)
然后你可以创建你的函数并像这样调用它:
@variable_udf
def clamp(value, low, high):
return max(low, min(value, high))
df = df.select(clamp("col1", "col2", "col3", return_type=DoubleType()).alias("result"))
我还没有做过任何性能测试或任何东西,所以我不确定它的性能如何,但我希望缓存的使用至少有帮助。由于某种原因,这个特定的示例还要求所有列具有相同的类型。如果
low
和 high
是 int,但返回类型是 double,则结果为 null。不知道为什么会这样。