这是一个索引选择函数
def np_ifelse(
x: np.ndarray[float] | float,
ind: bool | np.ndarray[bool],
v1: float,
v2: float
) -> float | np.ndarray[float]:
if isinstance(x, np.ndarray):
y = np.full_like(x, v2)
y[ind] = v1
return y
else:
return v1 if ind else v2
如果输入是标量,它是两个可能结果之间的简单 if-else 语句,从而产生标量输出
如果输入是数组,则 if-else 语句将单独应用于每个数组元素,结果是一个数组。
问题:这段代码对我来说看起来有点难看。是否已经有一个 numpy 函数可以更优雅地完成此操作,而无需显式类型检查?
由于
x
和 ind
的类型相同,因此可以跳过检查 x
的类型,直接将 ind
与 np.where
一起使用:
示例:
import numpy as np
# scalar input
x_scalar = 5.0
ind_scalar = True
v1 = 10.0
v2 = 20.0
result_scalar = np.where(ind_scalar, v1, v2)
print(result_scalar)
# array input
x_array = np.array([1.0, 2.0, 3.0, 4.0])
ind_array = np.array([True, False, True, False])
result_array = np.where(ind_array, v1, v2)
print(result_array)
输出:
10.0
[10. 20. 10. 20.]