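I am interested in finding the lengths of sequences of 1s along a single axis of a multidimensional array.
For a 1D array I have already found a solution using the answers to this older question. E.g. [0,1,0,0,1,1,1,0,1,1] --> [nan,1,nan,nan,3,nan,nan,nan,2,nan]
For a 3D array I could of course write a loop, but I would rather not. (Background: climate science; looping over all lat/lon grid cells would make this very slow.)
I am trying to find a solution in line with the 1D one. Help based on the 1D code would be appreciated, but completely different solutions are of course welcome as well.
For reference, here is my working 1D solution: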
import xarray as xr
import numpy as np

def run_lengths(da):
    n = len(da)
    y = da.values[1:] != da.values[:-1]  # True where the value changes
    i = np.append(np.where(y), n - 1)    # index of the last element of each run
    z = np.diff(np.append(-1, i))        # length of each run
    p = np.cumsum(np.append(0, z))[:-1]  # index of the first element of each run
    runs = np.where(da[i] == 1)[0]       # keep only the runs of 1s
    runs_len = z[runs]  # length of sequence
    time_val = da.time[p[runs]]  # date of first day in sequence
    da_runs = xr.DataArray(runs_len, coords={'time': time_val})
    _, da_runs = xr.align(da, da_runs, join='outer')  # make sure we have full time axis
    return da_runs

da = xr.DataArray(
    np.array([[[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]],
              [[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]]]),
    coords={'lat': [0,1], 'lon': [0,1,2], 'time': [0,1,2,3,4,5]},
)
da_runs = run_lengths(da[0,1])
print(da_runs)

<xarray.DataArray (time: 6)>
array([ 1., nan,  2., nan, nan,  1.])
Coordinates:
  * time     (time) int64 0 1 2 3 4 5
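To make the index arithmetic easier to follow, here is a minimal NumPy-only sketch of the same idea applied to the example from the top of the question. The function name `run_starts_1d` is made up for illustration, xarray is left out on purpose, and the input is assumed to be non-empty.

```python
import numpy as np

def run_starts_1d(values):
    """Run length at the first index of every run of 1s, NaN elsewhere.

    Hypothetical helper for illustration; assumes a non-empty 1D input.
    """
    values = np.asarray(values)
    n = len(values)
    change = values[1:] != values[:-1]               # True where the value changes
    last = np.append(np.flatnonzero(change), n - 1)  # last index of each run
    lengths = np.diff(np.append(-1, last))           # length of each run
    first = last - lengths + 1                       # first index of each run
    is_one = values[last] == 1                       # which runs consist of 1s
    out = np.full(n, np.nan)
    out[first[is_one]] = lengths[is_one]
    return out

# Runs of 1s start at indices 1, 4 and 8 with lengths 1, 3 and 2.
print(run_starts_1d([0, 1, 0, 0, 1, 1, 1, 0, 1, 1]))
```

Writing the lengths straight into a NaN-filled array takes the place of the `xr.align` step, which is only needed when the result has to carry the time coordinate.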
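This is my attempt at a 3D version. I am stuck on how to move the valid entries in i to the front, i.e. how to remove the NaNs from i. (And perhaps there is more to it than that?)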
def run_lengths_3D(da):
    n = len(da.time)
    y = da.values[:,:,1:] != da.values[:,:,:-1]
    y = xr.DataArray(y, coords={'lat': da.lat, 'lon': da.lon, 'time': da.time[0:-1]})
    i = y.where(y) * xr.DataArray(np.arange(0, len(da.time[0:-1])), coords={'time': y.time}) - 1
For this task you can try numba, for example:
import numba
import numpy as np

@numba.njit
def calculate(arr):
    # uint16 can hold run lengths up to 65535
    out = np.empty_like(arr, dtype="uint16")
    lat, lon, time = arr.shape
    for i in range(lat):
        for j in range(lon):
            curr = 0  # length of the run of 1s currently being scanned
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        # the run just ended: store its length at its first index
                        out[i, j, k - curr] = curr
                        curr = 0
                else:
                    curr += 1
            if curr > 0:
                # a run that reaches the end of the time axis
                out[i, j, -curr] = curr
    return out

arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ]
)
print(calculate(arr))
Prints:
[[[0 2 0 0 0 0]
  [1 0 2 0 0 1]
  [4 0 0 0 0 1]]

 [[0 2 0 0 0 0]
  [1 0 2 0 0 1]
  [4 0 0 0 0 1]]]
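The question's expected output marks positions that do not start a run with NaN rather than 0. If that format is needed, the integer result could be converted afterwards. The following is a minimal sketch, assuming a float64 result is acceptable; the array literal is simply the output printed above, copied in so that the snippet runs without numba.

```python
import numpy as np

# Output of calculate(arr) as printed above, copied here as a literal.
out = np.array(
    [
        [[0, 2, 0, 0, 0, 0], [1, 0, 2, 0, 0, 1], [4, 0, 0, 0, 0, 1]],
        [[0, 2, 0, 0, 0, 0], [1, 0, 2, 0, 0, 1], [4, 0, 0, 0, 0, 1]],
    ],
    dtype="uint16",
)

# Integer arrays cannot hold NaN, so cast to float before masking the zeros.
run_lengths = np.where(out > 0, out.astype("float64"), np.nan)

# Grid cell [0, 1] is the one used in the question's 1D example.
print(run_lengths[0, 1])
```

The values for grid cell `[0, 1]` are then 1, NaN, 2, NaN, NaN, 1, the same as in the 1D result shown in the question.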
Benchmark using timeit, plus a parallel version:
from timeit import timeit

import numba
import numpy as np

@numba.njit
def calculate(arr):
    out = np.empty_like(arr, dtype="uint16")
    lat, lon, time = arr.shape
    for i in range(lat):
        for j in range(lon):
            curr = 0
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        out[i, j, k - curr] = curr
                        curr = 0
                else:
                    curr += 1
            if curr > 0:
                out[i, j, -curr] = curr
    return out

@numba.njit(parallel=True)
def calculate_parallel(arr):
    out = np.empty_like(arr, dtype="uint16")
    lat, lon, time = arr.shape
    for i in range(lat):
        for j in numba.prange(lon):
            curr = 0
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        out[i, j, k - curr] = curr
                        curr = 0
                else:
                    curr += 1
            if curr > 0:
                out[i, j, -curr] = curr
    return out

arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ],
    dtype="uint8",
)

# compile calculate()/calculate_parallel()
assert np.allclose(calculate(arr), calculate_parallel(arr))

np.random.seed(42)
arr = np.random.randint(low=0, high=2, size=(256, 512, 3650), dtype="uint8")

t_serial = timeit("calculate(arr)", number=1, globals=globals())
t_parallel = timeit("calculate_parallel(arr)", number=1, globals=globals())

print(f"{t_serial * 1_000_000:.2f} usec")
print(f"{t_parallel * 1_000_000:.2f} usec")
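On my machine (AMD 5700x) this prints the following, i.e. roughly 1.58 s for the serial version and 0.32 s for the parallel one: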
1575227.47 usec
320453.57 usec
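If numba is not an option, the same result can probably be obtained with plain NumPy and no Python-level loop. This is only a sketch and not part of the benchmark above: the name `run_lengths_numpy` is made up, the runs are assumed to lie along the last axis, and the function allocates several temporaries of the input's size, so I make no claim about how it compares in speed or memory.

```python
import numpy as np

def run_lengths_numpy(arr):
    """Length of each run of nonzero values, stored at the run's first index.

    Hypothetical loop-free alternative; works along the last axis.
    """
    a = np.asarray(arr) != 0
    # Pad the last axis with a 0 on both sides so every run has a start and an end.
    pad = np.zeros(a.shape[:-1] + (1,), dtype=bool)
    padded = np.concatenate([pad, a, pad], axis=-1).astype(np.int8)
    d = np.diff(padded, axis=-1)   # +1 where a run starts, -1 one past its end
    starts = np.nonzero(d == 1)    # index tuples in row-major order
    ends = np.nonzero(d == -1)     # same order, so starts and ends pair up
    out = np.zeros(a.shape, dtype=np.int64)
    out[starts] = ends[-1] - starts[-1]
    return out

arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ]
)
print(run_lengths_numpy(arr))
```

The pairing relies on `np.nonzero` returning indices in row-major order: within each grid cell the starts and ends alternate along the time axis, so the k-th start and the k-th end always belong to the same run.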