Finding the length of runs of identical values along one axis of a 3D array (related to run-length encoding)

Question

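I am interested in finding the lengths of sequences of 1s along a single axis of a multidimensional array.

For a 1D array, I have already found a solution using the answer to this older question. For example: [0,1,0,0,1,1,1,0,1,1] --> [nan,1,nan,nan,3,nan,nan,nan,2,nan]

For a 3D array I could of course write a loop, but I would rather not. (Background: climate science. Looping over all lat/lon grid cells would make this very slow.)

I am looking for a solution along the lines of the 1D one. Help building on the 1D code would be appreciated, but completely different solutions are of course welcome too.

For reference, here is my working 1D solution: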

import xarray as xr
import numpy as np

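def run_lengths(da):
    n = len(da)
    # True wherever the value changes between consecutive time steps
    y = da.values[1:] != da.values[:-1]
    # index of the last element of each run (the final element always closes a run)
    i = np.append(np.where(y), n - 1)
    # length of each run
    z = np.diff(np.append(-1, i))
    # index of the first element of each run
    p = np.cumsum(np.append(0, z))[:-1]
    # keep only the runs whose value is 1
    runs = np.where(da[i] == 1)[0]
    runs_len = z[runs]           # length of sequence
    time_val = da.time[p[runs]]  # date of first day in sequence
    da_runs = xr.DataArray(runs_len, coords={'time': time_val})
    _, da_runs = xr.align(da, da_runs, join='outer')  # make sure we have full time axis
    return da_runs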

da = xr.DataArray(np.array([[[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]],[[0,1,1,0,0,0],[1,0,1,1,0,1],[1,1,1,1,0,1]]]),coords={'lat':[0,1],'lon':[0,1,2],'time':[0,1,2,3,4,5]})
da_runs = run_lengths(da[0,1])
print(da_runs)

<xarray.DataArray (time: 6)>
array([ 1., nan,  2., nan, nan,  1.])
Coordinates:
  * time     (time) int64 0 1 2 3 4 5
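
To make the intermediate steps of the 1D approach concrete, here is a minimal NumPy-only sketch of the same idea. The function name run_lengths_1d and its variable names are made up for illustration, and it assumes a non-empty 1D input. It drops the xarray time axis and returns a plain array with NaN wherever no run of 1s starts.

```python
import numpy as np


def run_lengths_1d(values):
    """Run length at the first index of each run of 1s, NaN elsewhere.

    Hypothetical helper for illustration; assumes a non-empty 1D input.
    """
    values = np.asarray(values)
    n = len(values)
    change = values[1:] != values[:-1]               # True where the value changes
    ends = np.append(np.flatnonzero(change), n - 1)  # last index of each run
    lengths = np.diff(np.append(-1, ends))           # length of each run
    starts = ends - lengths + 1                      # first index of each run
    is_one = values[ends] == 1                       # runs made of 1s
    out = np.full(n, np.nan)
    out[starts[is_one]] = lengths[is_one]
    return out


print(run_lengths_1d([0, 1, 0, 0, 1, 1, 1, 0, 1, 1]))
```

For the example input from the top of the question, the runs of 1s start at indices 1, 4 and 8 and have lengths 1, 3 and 2, so those three positions hold the lengths and every other position is NaN.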

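This is my attempt in 3D. I am stuck on how to move the valid entries of i to the front, or equivalently how to remove the NaNs from i. (And maybe there is more to it than that?)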

def run_lengths_3D(da):
    n = len(da.time)
    y = da.values[:,:,1:] != da.values[:,:,:-1]
    y = xr.DataArray(y,coords={'lat':da.lat,'lon':da.lon,'time':da.time[0:-1]}) 
    i = y.where(y)*xr.DataArray(np.arange(0,len(da.time[0:-1])),coords={'time':y.time}) -1

Tags: python numpy python-xarray run-length-encoding

1 Answer

For this task you can try using numba, for example:

import numba
import numpy as np


@numba.njit
def calculate(arr):
    out = np.empty_like(arr, dtype="uint16")

    lat, lon, time = arr.shape

    for i in range(lat):
        for j in range(lon):
            curr = 0
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        out[i, j, k - curr] = curr
                    curr = 0
                else:
                    curr += 1
            if curr > 0:
                out[i, j, -curr] = curr

    return out

arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ]
)
print(calculate(arr))

Prints:

[[[0 2 0 0 0 0]
  [1 0 2 0 0 1]
  [4 0 0 0 0 1]]

 [[0 2 0 0 0 0]
  [1 0 2 0 0 1]
  [4 0 0 0 0 1]]]
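
Note that this output marks "no run starts here" with 0, whereas the question's 1D solution uses NaN. A minimal sketch of the conversion, assuming a float result is acceptable, could look like the following. The helper name zeros_to_nan is made up for illustration, and the array out is simply the printed result above typed in by hand.

```python
import numpy as np

# Output of calculate() for the small example above (0 = no run starts here).
out = np.array(
    [
        [[0, 2, 0, 0, 0, 0], [1, 0, 2, 0, 0, 1], [4, 0, 0, 0, 0, 1]],
        [[0, 2, 0, 0, 0, 0], [1, 0, 2, 0, 0, 1], [4, 0, 0, 0, 0, 1]],
    ],
    dtype="uint16",
)


def zeros_to_nan(run_starts):
    """Hypothetical helper: float copy with NaN wherever no run starts."""
    as_float = run_starts.astype("float64")  # NaN needs a float dtype
    as_float[run_starts == 0] = np.nan
    return as_float


runs = zeros_to_nan(out)
print(runs[0, 1])  # same values as the 1D result for da[0, 1] in the question
```

If the input came from an xarray DataArray da, the float array could presumably be wrapped again with xr.DataArray(runs, coords=da.coords, dims=da.dims). That step is not tested here.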

Benchmark using timeit, plus a parallel version:

from timeit import timeit

import numba
import numpy as np


@numba.njit
def calculate(arr):
    out = np.empty_like(arr, dtype="uint16")

    lat, lon, time = arr.shape

    for i in range(lat):
        for j in range(lon):
            curr = 0
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        out[i, j, k - curr] = curr
                    curr = 0
                else:
                    curr += 1
            if curr > 0:
                out[i, j, -curr] = curr

    return out


@numba.njit(parallel=True)
def calculate_parallel(arr):
    out = np.empty_like(arr, dtype="uint16")

    lat, lon, time = arr.shape

    for i in range(lat):
        for j in numba.prange(lon):
            curr = 0
            for k in range(time):
                out[i, j, k] = 0
                if arr[i, j, k] == 0:
                    if curr > 0:
                        out[i, j, k - curr] = curr
                    curr = 0
                else:
                    curr += 1
            if curr > 0:
                out[i, j, -curr] = curr

    return out


arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ],
    dtype="uint8",
)

# compile calculate()/calculate_parallel()
assert np.allclose(calculate(arr), calculate_parallel(arr))

np.random.seed(42)
arr = np.random.randint(low=0, high=2, size=(256, 512, 3650), dtype="uint8")

t_serial = timeit("calculate(arr)", number=1, globals=globals())
t_parallel = timeit("calculate_parallel(arr)", number=1, globals=globals())

print(f"{t_serial * 1_000_000:.2f} usec")
print(f"{t_parallel * 1_000_000:.2f} usec")

Prints on my machine (AMD 5700X):

1575227.47 usec
320453.57 usec
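
Since the question asked for a solution without explicit loops, here is an alternative pure-NumPy sketch. It is not part of the answer above and has not been benchmarked against the numba versions. The function name run_lengths_nd is made up for illustration. The idea is to pad the last axis with zeros, take the difference along it so that each run yields a +1 at its start and a -1 just past its end, and pair the starts with the ends. This relies on np.nonzero returning indices in row-major order, so the k-th start and the k-th end always belong to the same run.

```python
import numpy as np


def run_lengths_nd(arr):
    """Length of each run of non-zero values along the last axis,
    stored at the first index of the run; 0 elsewhere.

    Hypothetical helper for illustration only.
    """
    ones = (np.asarray(arr) != 0).astype(np.int8)
    # Pad one zero on each side of the last axis so that runs touching
    # the edges still get a start and an end marker.
    pad = [(0, 0)] * (ones.ndim - 1) + [(1, 1)]
    d = np.diff(np.pad(ones, pad), axis=-1)
    starts = np.nonzero(d == 1)   # first index of each run
    ends = np.nonzero(d == -1)    # one past the last index of each run
    out = np.zeros(ones.shape, dtype=np.int64)
    out[starts] = ends[-1] - starts[-1]
    return out


arr = np.array(
    [
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
        [[0, 1, 1, 0, 0, 0], [1, 0, 1, 1, 0, 1], [1, 1, 1, 1, 0, 1]],
    ]
)
print(run_lengths_nd(arr))
```

On the small example this gives the same values as calculate(). For arrays as large as the one in the benchmark, the padded copy, the difference array and the index arrays are all temporaries of comparable size, so memory use may be the deciding factor.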