Does a recursive segment tree need more space than an iterative segment tree?

I am learning about the segment tree data structure.

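I have seen several iterative segment trees that use only 2n space, so I tried to use the same build method in a segment tree with a recursive update and sumRange. Is this not allowed? Why can an iterative segment tree be stored in 2n, while a recursive segment tree needs 4n? Or is there an implementation flaw in my non-working tree?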
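The iterative 2n trees mentioned here typically follow a bottom-up scheme like the one below. This is a minimal sketch that assumes sum queries over inclusive ranges; the class IterativeSumTree and the method sum_range are hypothetical names, not taken from the question. Leaves sit at tree[n] to tree[2n - 1], node i combines nodes 2i and 2i + 1, and the query climbs upward from both ends of the range. Because the query never assumes that an internal node covers a contiguous, midpoint-split interval, 2n slots suffice for any n.

```python
from typing import List


class IterativeSumTree:
    """Bottom-up segment tree stored in exactly 2n slots (hypothetical sketch).

    Leaves live at tree[n:2n]; node i has children 2i and 2i + 1; tree[0] is unused.
    """

    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * (2 * self.n)
        self.tree[self.n:] = nums                      # copy the leaves
        for i in range(self.n - 1, 0, -1):             # fill internal nodes bottom-up
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, index: int, val: int) -> None:
        i = index + self.n
        self.tree[i] = val
        while i > 1:                                   # recompute every ancestor up to the root
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def sum_range(self, left: int, right: int) -> int:
        """Sum of nums[left..right], inclusive."""
        total = 0
        lo, hi = left + self.n, right + self.n + 1     # half-open interval [lo, hi) over leaf slots
        while lo < hi:
            if lo & 1:                                 # lo is a right child: take it and step past it
                total += self.tree[lo]
                lo += 1
            if hi & 1:                                 # hi is a right child: take its left sibling
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total
```

Both update and sum_range touch O(log n) nodes.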

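For my 2n tree I am using a 1-indexed tree, so nothing is stored in tree[0]. This means the root is at tree[1]. I make the recursive calls with an initial range of 1 to n - 1, which I am not sure about. When I let it go to self.n or start from 0, I get different wrong answers. Passing in index + 1, left + 1, or right + 1 also gives different wrong answers.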

Here is my implementation:

from typing import List

class NumArray:
    # Classic Segment Tree

    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * self.n * 2
        self.build(nums)

    def build(self, nums):
        # leaves
        for i in range(self.n):
            self.tree[i + self.n] = nums[i]

        # internal
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[i * 2] + self.tree[i * 2 + 1]

    def merge(self, left, right):
        return left + right

    def _update(self, tree_idx, seg_left, seg_right, i, val):
        # leaf
        if seg_left == seg_right:
            self.tree[tree_idx] = val
            return

        mid = (seg_left + seg_right) // 2
        if i > mid:
            self._update(tree_idx * 2 + 1, mid + 1, seg_right, i, val)
        else:
            self._update(tree_idx * 2, seg_left, mid, i, val)

        self.tree[tree_idx] = self.merge(self.tree[tree_idx * 2], self.tree[tree_idx * 2 + 1])

    def update(self, index: int, val: int) -> None:
        self._update(1, 1, self.n - 1, index, val)

    def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
        # segment out of query bounds
        if seg_left > query_right or seg_right < query_left:
            return 0

        # segment fully in bounds
        if seg_left >= query_left and seg_right <= query_right:
            return self.tree[tree_idx]

        # segment partially in bounds
        mid = (seg_left + seg_right) // 2

        # this is not necessary for correctness, but helps with efficiency (we only go down 1 path if 2 is unnecessary)
        if query_left > mid:
            return self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)
        elif query_right <= mid:
            return self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)

        left_sum = self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)
        right_sum = self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)

        return self.merge(left_sum, right_sum)

    def sumRange(self, left: int, right: int) -> int:
        return self._sumRange(1, 1, self.n - 1, left, right)
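One way to see why this mix fails: unless n is a power of two, the bottom-up build and a recursive midpoint split assign different positions of nums to the same node index. The hypothetical helpers below (bottom_up_cover and recursive_cover, not part of the original code) list which positions each index holds under each scheme. They assume the recursion starts at node 1 with the full range [0, n - 1]. Note also that the code above starts the recursion at [1, n - 1], so none of its recursive leaves ever corresponds to nums[0].

```python
def bottom_up_cover(n):
    """Positions of nums summed into each tree index by the iterative 2n build."""
    cover = {n + i: [i] for i in range(n)}          # leaves at indices n .. 2n-1
    for j in range(n - 1, 0, -1):                   # same loop as build() above
        cover[j] = cover[2 * j] + cover[2 * j + 1]
    return cover


def recursive_cover(n):
    """Positions each tree index is expected to hold by a recursive midpoint split of [0, n-1]."""
    cover = {}

    def walk(node, lo, hi):
        cover[node] = list(range(lo, hi + 1))
        if lo < hi:
            mid = (lo + hi) // 2
            walk(2 * node, lo, mid)
            walk(2 * node + 1, mid + 1, hi)

    walk(1, 0, n - 1)
    return cover


for n in (3, 4):
    bu, rec = bottom_up_cover(n), recursive_cover(n)
    for node in sorted(rec):
        print(n, node, sorted(bu.get(node, [])), rec[node])
```

For n = 3, index 2 holds nums[1] + nums[2] in the bottom-up array, while the recursion expects nums[0] + nums[1]. Index 3 is the leaf nums[0] rather than nums[2]. For n = 4 the two layouts agree at every index.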

I do have a fully recursive, 0-indexed version, but it uses twice the space:

from typing import List

class NumArray:
    # Classic Segment Tree
    # 0-indexed recursive

    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * self.n * 4
        self.build(nums, 0, 0, self.n - 1)

    def build(self, nums, tree_idx, left, right):
        # leaf
        if left == right:
            self.tree[tree_idx] = nums[left]
            return

        mid = (left + right) // 2
        self.build(nums, tree_idx * 2 + 1, left, mid)
        self.build(nums, tree_idx * 2 + 2, mid + 1, right)

        self.tree[tree_idx] = self.tree[tree_idx * 2 + 1] + self.tree[tree_idx * 2 + 2]

    def merge(self, left, right):
        return left + right

    def _update(self, tree_idx, seg_left, seg_right, i, val):
        # leaf
        if seg_left == seg_right:
            self.tree[tree_idx] = val
            return

        mid = (seg_left + seg_right) // 2
        if i > mid:
            self._update(tree_idx * 2 + 2, mid + 1, seg_right, i, val)
        else:
            self._update(tree_idx * 2 + 1, seg_left, mid, i, val)

        self.tree[tree_idx] = self.merge(self.tree[tree_idx * 2 + 1], self.tree[tree_idx * 2 + 2])

    def update(self, index: int, val: int) -> None:
        self._update(0, 0, self.n - 1, index, val)

    def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
        # segment out of query bounds
        if seg_left > query_right or seg_right < query_left:
            return 0

        # segment fully in bounds
        if seg_left >= query_left and seg_right <= query_right:
            return self.tree[tree_idx]

        # segment partially in bounds
        mid = (seg_left + seg_right) // 2

        # this is not necessary for correctness, but helps with efficiency (we only go down 1 path if 2 is unnecessary)
        if query_left > mid:
            return self._sumRange(tree_idx * 2 + 2, mid + 1, seg_right, query_left, query_right)
        elif query_right <= mid:
            return self._sumRange(tree_idx * 2 + 1, seg_left, mid, query_left, query_right)

        left_sum = self._sumRange(tree_idx * 2 + 1, seg_left, mid, query_left, query_right)
        right_sum = self._sumRange(tree_idx * 2 + 2, mid + 1, seg_right, query_left, query_right)

        return self.merge(left_sum, right_sum)

    def sumRange(self, left: int, right: int) -> int:
        return self._sumRange(0, 0, self.n - 1, left, right)

This website verifies whether a segment tree implementation is correct.

I also know that recursion uses more call-stack space. That is not what I am asking about.

python algorithm data-structures segment-tree
1 Answer

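In fact, it is easy to modify the standard recursive segment tree to use 2N nodes instead of 4N. Allocating an array of 4N nodes is common because it is simple and requires little thought, but it can leave a lot of space unused.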
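To make the 4N figure concrete, assume the usual heap numbering (root 1, children 2i and 2i + 1) and a midpoint split of [0, N - 1]. Under that layout the recursion touches exactly 2N - 1 nodes, but their indices can run past 2N - 1 when N is not a power of two. The hypothetical helper below (heap_layout_stats, not from the answer) reports both numbers. The largest index always stays below 4N, which is why 4N is the safe, simple allocation.

```python
def heap_layout_stats(n):
    """Return (nodes_used, largest_index) for a recursive segment tree over [0, n-1]
    stored with 1-based heap indexing (children of node i are 2i and 2i + 1)."""
    used = 0
    largest = 0
    stack = [(1, 0, n - 1)]                       # explicit stack instead of recursion
    while stack:
        node, lo, hi = stack.pop()
        used += 1
        largest = max(largest, node)
        if lo < hi:
            mid = (lo + hi) // 2
            stack.append((2 * node, lo, mid))
            stack.append((2 * node + 1, mid + 1, hi))
    return used, largest


for n in (4, 5, 6, 1000):
    used, largest = heap_layout_stats(n)
    print(f"n={n}: nodes used={used}, largest index={largest}, array size needed={largest + 1}")
```

For N = 6 this gives 11 nodes but a largest index of 13, so a 1-indexed array of size 2N = 12 is already too small.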

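Note that a segment tree is a full binary tree with N leaves, so it must have exactly N - 1 internal nodes (this is easy to show by induction). Therefore we only need 2N - 1 nodes in total, and we can take advantage of this by changing the order in which the nodes are stored.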
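The induction can be written out as follows, with $\ell(T)$ the number of leaves and $m(T)$ the number of internal nodes of a full binary tree $T$ (every internal node has exactly two children):

```latex
\textbf{Claim.} For every full binary tree $T$: $m(T) = \ell(T) - 1$.

\textbf{Base.} A single leaf has $\ell = 1$ and $m = 0 = \ell - 1$.

\textbf{Step.} Otherwise the root has two full subtrees $L$ and $R$, so
\[
  \ell(T) = \ell(L) + \ell(R), \qquad m(T) = m(L) + m(R) + 1 .
\]
By the induction hypothesis,
\[
  m(T) = (\ell(L) - 1) + (\ell(R) - 1) + 1 = \ell(L) + \ell(R) - 1 = \ell(T) - 1 .
\]
Hence a segment tree with $N$ leaves has $N - 1$ internal nodes and $2N - 1$ nodes in total.
```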

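After any node, we can store its entire left subtree and then its entire right subtree in consecutive positions of the array representing the segment tree. For the node at index i, its left child immediately follows it at index i + 1. Its right child immediately follows the last element of the left subtree. The left subtree is also a full binary tree representing the range [left, mid], so it has mid - left + 1 leaves and therefore 2 * (mid - left + 1) - 1 nodes in total. Since the left subtree starts at i + 1, the right child is at index (i + 1) + 2 * (mid - left + 1) - 1 = i + 2 * (mid - left + 1). With this in mind, the rest of the segment tree operations remain unchanged.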
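This ordering is a pre-order (depth-first) layout. As a quick sanity check, the hypothetical helper below (preorder_indices, not part of the answer's code) walks the tree using exactly these index rules. It confirms that the 2N - 1 nodes land on indices 0 to 2N - 2 with no gaps and no collisions.

```python
def preorder_indices(n):
    """Map each segment (lo, hi) of a midpoint-split tree over [0, n-1] to its array index
    under the layout: left child at i + 1, right child at i + 2 * (mid - lo + 1)."""
    index_of = {}
    stack = [(0, 0, n - 1)]                                     # root at index 0
    while stack:
        i, lo, hi = stack.pop()
        index_of[(lo, hi)] = i                                  # each segment appears once in the tree
        if lo < hi:
            mid = (lo + hi) // 2
            stack.append((i + 1, lo, mid))                      # left child right after its parent
            stack.append((i + 2 * (mid - lo + 1), mid + 1, hi)) # right child after the whole left subtree
    return index_of


layout = preorder_indices(6)
for (lo, hi), i in sorted(layout.items(), key=lambda kv: kv[1]):
    print(i, [lo, hi])
```

For N = 6 the right half [3, 5] sits at index 6, just after the five nodes of the left subtree [0, 2] at indices 1 to 5.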

from typing import List

class NumArray:
    def __init__(self, nums: List[int]):
        self.n = len(nums)
        self.tree = [0] * 2 * len(nums)
        self.build(nums, 0, 0, self.n - 1)

    def merge(self, left_val, right_val):
        return left_val + right_val

    def build(self, nums, tree_idx, left, right):
        if left == right:
            self.tree[tree_idx] = nums[left]
        else:
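            mid = (left + right) >> 1
            # the left subtree [left, mid] starts at tree_idx + 1 and holds 2 * (mid - left + 1) - 1 nodes,
            # so the right child comes immediately after it
            self.tree[tree_idx] = self.merge(self.build(nums, tree_idx + 1, left, mid),
                self.build(nums, tree_idx + 2 * (mid - left + 1), mid + 1, right))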
        return self.tree[tree_idx]

    def _update(self, tree_idx, left, right, i, val):
        if left == right:
            self.tree[tree_idx] = val
        else:
            mid = (left + right) >> 1
            if i > mid:
                self._update(tree_idx + 2 * (mid - left + 1), mid + 1, right, i, val)
            else:
                self._update(tree_idx + 1, left, mid, i, val)
            self.tree[tree_idx] = self.merge(self.tree[tree_idx + 1], self.tree[tree_idx + 2 * (mid - left + 1)])

    def update(self, index: int, val: int) -> None:
        return self._update(0, 0, self.n - 1, index, val)

    def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
        if query_left > query_right:
            return 0
        if seg_left == query_left and seg_right == query_right:
            return self.tree[tree_idx]
        mid = (seg_left + seg_right) >> 1
        # clamp the query to each half; a half the query misses gets query_left > query_right and returns 0
        return self.merge(self._sumRange(tree_idx + 1, seg_left, mid, query_left, min(mid, query_right)),
            self._sumRange(tree_idx + 2 * (mid - seg_left + 1), mid + 1, seg_right, max(mid + 1, query_left), query_right))

    def sumRange(self, left: int, right: int) -> int:
        return self._sumRange(0, 0, self.n - 1, left, right)
© www.soinside.com 2019 - 2024. All rights reserved.