我正在学习线段树数据结构。
我见过几个只使用 2n 空间的迭代线段树。因此,我尝试在具有递归更新和 sumRange 的线段树中使用相同的构建方法。这是不允许的吗?为什么迭代的seg树可以存储在2n中,而递归的seg树需要4n?或者我的非工作树中是否存在实施缺陷?
对于我的 2n 树,我使用的是 1 索引树,因此
tree[0]
中没有存储任何内容。这意味着根位于tree[1]
。我使用初始范围 1 到 n - 1 进行递归调用,对此我不确定。当我让它转到 self.n 或从 0 开始时,我会得到不同的错误答案。如果我传入索引+1、左+1或右+1,我也会得到不同的错误答案
这是我的实现:
class NumArray:
# Classic Segment Tree
def __init__(self, nums: List[int]):
self.n = len(nums)
self.tree = [0] * self.n * 2
self.build(nums)
def build(self, nums):
# leaves
for i in range(self.n):
self.tree[i + self.n] = nums[i]
# internal
for i in range(self.n - 1, 0, -1):
self.tree[i] = self.tree[i * 2] + self.tree[i * 2 + 1]
def merge(self, left, right):
return left + right
def _update(self, tree_idx, seg_left, seg_right, i, val):
# leaf
if seg_left == seg_right:
self.tree[tree_idx] = val
return
mid = (seg_left + seg_right) // 2
if i > mid:
self._update(tree_idx * 2 + 1, mid + 1, seg_right, i, val)
else:
self._update(tree_idx * 2, seg_left, mid, i, val)
self.tree[tree_idx] = self.merge(self.tree[tree_idx * 2], self.tree[tree_idx * 2 + 1])
def update(self, index: int, val: int) -> None:
self._update(1, 1, self.n - 1, index, val)
def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
# segment out of query bounds
if seg_left > query_right or seg_right < query_left:
return 0
# segment fully in bounds
if seg_left >= query_left and seg_right <= query_right:
return self.tree[tree_idx]
# segment partially in bounds
mid = (seg_left + seg_right) // 2
# this is not necessary for correctness, but helps with efficiency (we only go down 1 path if 2 is unnecessary)
if query_left > mid:
return self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)
elif query_right <= mid:
return self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)
left_sum = self._sumRange(tree_idx * 2, seg_left, mid, query_left, query_right)
right_sum = self._sumRange(tree_idx * 2 + 1, mid + 1, seg_right, query_left, query_right)
return self.merge(left_sum, right_sum)
def sumRange(self, left: int, right: int) -> int:
return self._sumRange(1, 1, self.n - 1, left, right)
我确实有一个完全递归的 0 索引版本,但它使用了双倍的空间
class NumArray:
# Classic Segment Tree
# 0-indexed recursive
def __init__(self, nums: List[int]):
self.n = len(nums)
self.tree = [0] * self.n * 4
self.build(nums, 0, 0, self.n - 1)
def build(self, nums, tree_idx, left, right):
# leaf
if left == right:
self.tree[tree_idx] = nums[left]
return
mid = (left + right) // 2
self.build(nums, tree_idx * 2 + 1, left, mid)
self.build(nums, tree_idx * 2 + 2, mid + 1, right)
self.tree[tree_idx] = self.tree[tree_idx * 2 + 1] + self.tree[tree_idx * 2 + 2]
def merge(self, left, right):
return left + right
def _update(self, tree_idx, seg_left, seg_right, i, val):
# leaf
if seg_left == seg_right:
self.tree[tree_idx] = val
return
mid = (seg_left + seg_right) // 2
if i > mid:
self._update(tree_idx * 2 + 2, mid + 1, seg_right, i, val)
else:
self._update(tree_idx * 2 + 1, seg_left, mid, i, val)
self.tree[tree_idx] = self.merge(self.tree[tree_idx * 2 + 1], self.tree[tree_idx * 2 + 2])
def update(self, index: int, val: int) -> None:
self._update(0, 0, self.n - 1, index, val)
def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
# segment out of query bounds
if seg_left > query_right or seg_right < query_left:
return 0
# segment fully in bounds
if seg_left >= query_left and seg_right <= query_right:
return self.tree[tree_idx]
# segment partially in bounds
mid = (seg_left + seg_right) // 2
# this is not necessary for correctness, but helps with efficiency (we only go down 1 path if 2 is unnecessary)
if query_left > mid:
return self._sumRange(tree_idx * 2 + 2, mid + 1, seg_right, query_left, query_right)
elif query_right <= mid:
return self._sumRange(tree_idx * 2 + 1, seg_left, mid, query_left, query_right)
left_sum = self._sumRange(tree_idx * 2 + 1, seg_left, mid, query_left, query_right)
right_sum = self._sumRange(tree_idx * 2 + 2, mid + 1, seg_right, query_left, query_right)
return self.merge(left_sum, right_sum)
def sumRange(self, left: int, right: int) -> int:
return self._sumRange(0, 0, self.n - 1, left, right)
这个网站验证线段树实现是否正确
我还知道递归使用更多的调用堆栈空间。这不是我要问的问题
事实上,很容易修改标准递归线段树以使用 2N 而不是 4N 节点。分配 4N 个节点的数组是很常见的,因为它很简单并且不需要太多思考,但可能会存在大量未使用的浪费空间。
注意,线段树是具有 N 个叶子的满二叉树,因此必须有 N - 1 个内部节点(这很容易通过归纳看出)。因此,我们实际上只需要 2N - 1 个节点,并且我们可以通过更改存储节点的顺序来优化内存使用。
在任何节点之后,我们可以存储其整个左子树,然后将整个右子树存储在表示线段树的数组中的连续位置中。对于索引
i
处的节点,其左子节点直接跟随在索引 i + 1
处的节点。它的右子节点直接跟随左子树的最后一个元素。左子树也是满二叉树,它代表范围[left, mid]
,所以有mid - left + 1
个叶子节点。因此,左子树中总共有 2 * (mid - left + 1) - 1
个节点。我们加 1 移动到右子节点,因此右子节点位于索引 i + 2 * (mid - left + 1)
。考虑到这一点,线段树的其余操作保持不变。
class NumArray:
def __init__(self, nums: List[int]):
self.n = len(nums)
self.tree = [0] * 2 * len(nums)
self.build(nums, 0, 0, self.n - 1)
def merge(self, left_val, right_val):
return left_val + right_val
def build(self, nums, tree_idx, left, right):
if left == right:
self.tree[tree_idx] = nums[left]
else:
mid = left + right >> 1
self.tree[tree_idx] = self.merge(self.build(nums, tree_idx + 1, left, mid),
self.build(nums, tree_idx + 2 * (mid - left + 1), mid + 1, right))
return self.tree[tree_idx]
def _update(self, tree_idx, left, right, i, val):
if left == right:
self.tree[tree_idx] = val
else:
mid = left + right >> 1
if i > mid:
self._update(tree_idx + 2 * (mid - left + 1), mid + 1, right, i, val)
else:
self._update(tree_idx + 1, left, mid, i, val)
self.tree[tree_idx] = self.merge(self.tree[tree_idx + 1], self.tree[tree_idx + 2 * (mid - left + 1)])
def update(self, index: int, val: int) -> None:
return self._update(0, 0, self.n - 1, index, val)
def _sumRange(self, tree_idx, seg_left, seg_right, query_left, query_right):
if query_left > query_right:
return 0
if seg_left == query_left and seg_right == query_right:
return self.tree[tree_idx]
mid = seg_left + seg_right >> 1
return self.merge(self._sumRange(tree_idx + 1, seg_left, mid, query_left, min(mid, query_right)),
self._sumRange(tree_idx + 2 * (mid - seg_left + 1), mid + 1, seg_right, max(mid + 1, query_left), query_right))
def sumRange(self, left: int, right: int) -> int:
return self._sumRange(0, 0, self.n - 1, left, right)