np.arange with array input


I want to create several ranges at once, where the length of each range is stored in an array.

Example:

lengths = np.array([1, 5, 10])

Creating the ranges:

ranges = np.arange(lengths)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Alternatively, I would also be fine with a solution based on np.linspace, passing the array of step counts as the num argument.
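For the np.linspace variant, note that np.linspace accepts array-valued start and stop, but num must be a single integer, so one call per length is still required. The minimal sketch below assumes, purely for illustration, that every range spans the interval [0, 1] and that lengths gives the number of points in each range.

```python
import numpy as np

lengths = np.array([1, 5, 10])

# Assumption for illustration: each range spans [0, 1].
# num must be a scalar integer, so np.linspace is called once per length.
ranges = [np.linspace(0.0, 1.0, num=int(k)) for k in lengths]

for r in ranges:
    # Number of points, first value, last value.
    print(r.size, r[0], r[-1])
```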

Thanks.

arrays numpy range
2 Answers

0 votes

Maybe all you need is a list comprehension?

import numpy as np

lengths = np.array([1, 5, 10])
ranges = [np.arange(length) for length in lengths]
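A quick check of what the list comprehension produces: a Python list holding one array per length. If a single flat array is wanted instead, np.concatenate can join them, which is the same construction the other answer uses as its loop baseline.

```python
import numpy as np

lengths = np.array([1, 5, 10])

# One np.arange call per requested length; the result is a list of arrays.
ranges = [np.arange(length) for length in lengths]

for r in ranges:
    print(r)

# Optional: join the ranges into one flat array.
flat = np.concatenate(ranges)
print(flat.size)  # 16, i.e. lengths.sum()
```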

0 votes

I need this quite often.

A fully vectorized option:

import numpy as np

IntArray = np.ndarray


def alt_cumsum(a: np.ndarray) -> np.ndarray:
    """Alternative cumsum: start at 0, omit the last value (exclusive cumsum)."""
    out = np.empty(a.size, a.dtype)
    if a.size == 0:
        return out  # nothing to fill; out[0] would raise IndexError
    out[0] = 0
    np.cumsum(a[:-1], out=out[1:])
    return out


def ragged_arange(n: IntArray) -> IntArray:
    """Equal to: np.concatenate([np.arange(e) for e in n])"""
    size = int(n.sum())
    # Running index 0..size-1 minus the start offset of each element's segment.
    return alt_cumsum(np.ones(size, dtype=int)) - np.repeat(alt_cumsum(n), n)
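To see why ragged_arange works, the sketch below prints the intermediate arrays for the question's lengths [1, 5, 10]. It restates alt_cumsum so it runs on its own. The idea: a running index 0..size-1, minus the start offset of the segment each element belongs to, gives each element's position within its own range.

```python
import numpy as np


def alt_cumsum(a: np.ndarray) -> np.ndarray:
    """Exclusive cumulative sum: starts at 0 and omits the final total."""
    out = np.empty(a.size, a.dtype)
    out[0] = 0
    np.cumsum(a[:-1], out=out[1:])
    return out


n = np.array([1, 5, 10])
size = int(n.sum())

position = alt_cumsum(np.ones(size, dtype=int))  # 0, 1, ..., size - 1
starts = alt_cumsum(n)                           # offset where each range begins
offsets = np.repeat(starts, n)                   # each start repeated n[i] times
result = position - offsets                      # index within its own range

print(starts)   # [0 1 6]
print(offsets)  # [0 1 1 1 1 1 6 6 6 6 6 6 6 6 6 6]
print(result)   # [0 0 1 2 3 4 0 1 2 3 4 5 6 7 8 9]
```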

For benchmarking:

import numba as nb
import perfplot


def loop_ragged_arange(n: IntArray) -> IntArray:
    return np.concatenate([np.arange(e) for e in n])


@nb.njit
def numba_ragged_arange(n: IntArray) -> IntArray:
    size = int(n.sum())
    out = np.empty(size, dtype=np.int64)
    index = 0
    for end in n:
        for i in range(end):
            out[index] = i
            index += 1
    return out


# Run once for JIT and check answers
a = np.random.choice(np.arange(10), size=100)
r0 = ragged_arange(a)
r1 = loop_ragged_arange(a)
r2 = numba_ragged_arange(a)
assert np.array_equal(r0, r1)
assert np.array_equal(r0, r2)


b = perfplot.bench(
    setup=lambda n: np.random.choice(np.arange(10), size=n),
    kernels=[loop_ragged_arange, ragged_arange, numba_ragged_arange],
    labels=["loop", "vectorized", "numba"],
    n_range=[2**k for k in range(21)],
    xlabel="len(a)",
)
b.save("out.png")

[Benchmark plot: runtime of the loop, vectorized and numba versions versus len(a).]

For n = 1,000,000 (one million), I get the following timings:

In [3]: %timeit ragged_arange(a)
46.4 ms ± 430 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [4]: %timeit loop_ragged_arange(a)
684 ms ± 5.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit numba_ragged_arange(a)
13.4 ms ± 291 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

The numba version allocates only the output array, whereas the numpy version needs several temporary arrays.
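One possible way to reduce the temporaries in the numpy version is sketched below. It is a hypothetical variant, not part of the answer and not benchmarked here: it builds the running index with np.arange(size) instead of an exclusive cumsum over a ones array, and subtracts the repeated segment starts in place, so it skips the ones array and the separate subtraction result. It still allocates the starts array and the repeated offsets.

```python
import numpy as np


def ragged_arange_fewer_temps(n: np.ndarray) -> np.ndarray:
    """Hypothetical variant of ragged_arange with fewer intermediate arrays.

    Intended to match np.concatenate([np.arange(e) for e in n]) for a 1-D
    array of non-negative integers.
    """
    n = np.asarray(n, dtype=np.int64)
    size = int(n.sum())
    # Start offset of each segment: 0, n[0], n[0] + n[1], ...
    starts = np.zeros(n.size, dtype=np.int64)
    np.cumsum(n[:-1], out=starts[1:])
    # Running index 0..size-1, then subtract each element's segment start in place.
    out = np.arange(size, dtype=np.int64)
    out -= np.repeat(starts, n)
    return out


# Compare against the loop version on random lengths.
rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=1000)
assert np.array_equal(
    ragged_arange_fewer_temps(a), np.concatenate([np.arange(e) for e in a])
)
```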

© www.soinside.com 2019 - 2024. All rights reserved.