数据结构的扩张
由于数据结构理论已经很成熟,标准数据结构在性能与抽象上往往没有太大改动空间。所以在需要更多功能的数据结构时,往往只需找到相关的标准数据结构作为基础,然后在其上存储额外变量、构造额外运算。这样的流程就称为“数据结构的扩张”,数据结构的扩张往往不是一个单纯的数据结构与算法的问题,其还包含了一定的程序设计、面向对象设计方面的考量。下文将以广泛作为标准结构的红黑树为例,研究2种红黑树的扩张。数据结构的扩张可分为如下步骤:
1) 根据问题选择标准数据结构作为基础结构;
2) 确定基础结构上需额外维护的附加信息;
3) 验证基础结构上的运算能否维护附加信息;
4) 构造新的运算;
下面是之前写过的红黑树代码,用它作为面向对象编程的父类/基类,之后会用面向对象的继承来实现红黑树的扩张。
class Nil:
    """Sentinel leaf used for every empty child slot and for the root's parent.

    Each empty slot holds its own ``Nil`` instance so that a leaf can carry a
    ``parent`` link (needed by the deletion fix-up).  Sentinels are black.
    """

    nil = True    # class-level flag distinguishing sentinels from real nodes
    red = False   # sentinels are always black

    def __init__(self):
        self.parent = None

    @property
    def brother(self):
        """Return the other child of this node's parent (its sibling).

        Raises:
            Exception: if the parent references this node as neither child.
        """
        # Nodes define no __eq__, so identity is the comparison that is meant.
        if self is self.parent.left:
            return self.parent.right
        if self is self.parent.right:
            return self.parent.left
        raise Exception("node is not linked as a child of its parent")

    def discolor(self):
        """Flip the colour of this node (red <-> black)."""
        self.red = not self.red


class Node(Nil):
    """A real tree node holding ``key``; new nodes are red with two Nil leaves."""

    nil = False

    def __init__(self, key):
        self.key = key
        self.red = True
        self.left = Nil()
        self.left.parent = self
        self.right = Nil()
        self.right.parent = self
        self.parent = Nil()


class RedBlackTree:
    """Red-black tree keyed by ``key``; duplicate keys are allowed.

    The underscore-prefixed methods (``_leftRotate``, ``_rightRotate``,
    ``_bstInsert``, ``_bstFind``, ``_bstRemove``) are the extension points
    that subclasses override to maintain extra per-node information.
    """

    def __init__(self):
        self.root = Nil()

    def _leftRotate(self, node):
        """Rotate left around ``node``; its right child takes its place."""
        pivot = node.right
        node.right = pivot.left
        # Re-parent the moved subtree even when it is a Nil leaf; otherwise
        # the leaf keeps pointing at ``pivot`` and ``brother`` fails on it.
        pivot.left.parent = node
        pivot.parent = node.parent
        if node.parent.nil:
            self.root = pivot
        elif node is node.parent.left:
            node.parent.left = pivot
        elif node is node.parent.right:
            node.parent.right = pivot
        pivot.left = node
        node.parent = pivot

    def _rightRotate(self, node):
        """Rotate right around ``node``; its left child takes its place."""
        pivot = node.left
        node.left = pivot.right
        # Same reasoning as in _leftRotate: keep Nil parent links accurate.
        pivot.right.parent = node
        pivot.parent = node.parent
        if node.parent.nil:
            self.root = pivot
        elif node.parent.left is node:
            node.parent.left = pivot
        elif node.parent.right is node:
            node.parent.right = pivot
        pivot.right = node
        node.parent = pivot

    def _bstInsert(self, key):
        """Insert ``key`` as a red leaf using plain BST descent.

        Keys equal to an existing key descend to the right.

        Returns:
            The newly created node (not yet rebalanced).
        """
        nodeToInsert = Node(key)
        if self.root.nil:
            self.root = nodeToInsert
            return nodeToInsert
        node = self.root
        while True:
            if key < node.key:
                if node.left.nil:
                    node.left = nodeToInsert
                    nodeToInsert.parent = node
                    return nodeToInsert
                node = node.left
            else:
                if node.right.nil:
                    node.right = nodeToInsert
                    nodeToInsert.parent = node
                    return nodeToInsert
                node = node.right

    def insert(self, key):
        """Insert ``key`` and restore the red-black properties."""
        node = self._bstInsert(key)
        while True:
            if node.parent.nil:
                # node is the root and is red at this point: paint it black.
                node.discolor()
                return
            if not node.parent.red:
                return
            if node.parent.brother.red:
                # Red uncle: push the blackness down from the grandparent
                # and continue the fix-up from the grandparent.
                node.parent.brother.discolor()
                node.parent.discolor()
                node.parent.parent.discolor()
                node = node.parent.parent
                continue
            # Black uncle: one or two rotations finish the fix-up.
            if node.parent is node.parent.parent.left:
                if node.parent.left is node:
                    node.parent.discolor()
                    node.parent.parent.discolor()
                    self._rightRotate(node.parent.parent)
                else:
                    node.discolor()
                    node.parent.parent.discolor()
                    self._leftRotate(node.parent)
                    # node moved up one level; its parent is the old grandparent.
                    self._rightRotate(node.parent)
            else:
                if node.parent.right is node:
                    node.parent.discolor()
                    node.parent.parent.discolor()
                    self._leftRotate(node.parent.parent)
                else:
                    node.discolor()
                    node.parent.parent.discolor()
                    self._rightRotate(node.parent)
                    self._leftRotate(node.parent)
            return

    def _bstFind(self, key):
        """Return a node whose key equals ``key``, or ``False`` if none exists."""
        node = self.root
        while not node.nil:
            if key < node.key:
                node = node.left
            elif key > node.key:
                node = node.right
            else:
                return node
        return False

    def find(self, key):
        """Return ``True`` if ``key`` is stored in the tree, else ``False``."""
        return bool(self._bstFind(key))

    def _bstRemove(self, key):
        """Unlink one node for ``key`` using plain BST deletion.

        When the matching node has two children, its key is overwritten with
        the key of its in-order successor and the successor is unlinked
        instead.

        Returns:
            ``False`` if ``key`` is absent; otherwise ``(removed, replacement)``
            where ``removed`` is the node physically unlinked (its ``parent``
            attribute still refers to its former parent) and ``replacement``
            is the node or Nil leaf now occupying its slot.
        """
        node = self._bstFind(key)
        if not node:
            return False
        if not node.left.nil and not node.right.nil:
            target = node
            node = node.right
            while not node.left.nil:
                node = node.left
            target.key = node.key
        # From here on ``node`` has at most one real child.
        if not node.left.nil:
            child = node.left
        elif not node.right.nil:
            child = node.right
        else:
            child = Nil()
        parent = node.parent
        if parent.nil:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        # Set the link even for a Nil replacement: remove() walks up from it.
        child.parent = parent
        return node, child

    def remove(self, key):
        """Remove one occurrence of ``key`` and restore the red-black properties.

        Returns:
            ``True`` if a node was removed, ``False`` if ``key`` was absent.
        """
        results = self._bstRemove(key)
        if not results:
            return False
        deleted, node = results
        if deleted.red:
            # Removing a red node never changes any black height.
            return True
        if node.red:
            # A red replacement absorbs the missing black.
            node.red = False
            return True
        # ``node`` carries an extra black that must be pushed up or resolved.
        while True:
            if node.parent.nil:
                return True
            if node.parent.left is node:
                if node.brother.red:
                    node.parent.discolor()
                    node.brother.discolor()
                    self._leftRotate(node.parent)
                elif not node.brother.left.red and not node.brother.right.red:
                    node.brother.discolor()
                    if node.parent.red:
                        node.parent.discolor()
                        return True
                    node = node.parent
                elif not node.brother.right.red:
                    node.brother.discolor()
                    node.brother.left.discolor()
                    self._rightRotate(node.brother)
                else:
                    node.brother.right.discolor()
                    node.brother.red = node.parent.red
                    node.parent.red = False
                    self._leftRotate(node.parent)
                    return True
            else:
                if node.brother.red:
                    node.parent.discolor()
                    node.brother.discolor()
                    self._rightRotate(node.parent)
                elif not node.brother.left.red and not node.brother.right.red:
                    node.brother.discolor()
                    if node.parent.red:
                        node.parent.discolor()
                        return True
                    node = node.parent
                elif not node.brother.left.red:
                    node.brother.discolor()
                    node.brother.right.discolor()
                    self._leftRotate(node.brother)
                else:
                    node.brother.left.discolor()
                    node.brother.red = node.parent.red
                    node.parent.red = False
                    self._rightRotate(node.parent)
                    return True
动态顺序统计树
在乱序序列中查找第\(k\)小元素的问题被称为“TopK问题”或“选择问题”。动态顺序统计树是红黑树的扩张,它使红黑树支持查找第\(k\)小的元素(即按\(key\)升序排名为\(k\)的元素,排名从1开始),且最坏代价为\(O(\lg n)\)。实现该运算的步骤如下:
1) 在节点增设\(size\)变量记录该子树的总节点数,任意节点\(x\)满足\(x.size=x.left.size + 1 + x.right.size\);
2) 利用\(size\)可递归得到排序:在以\(x\)为根的子树中,左子树包含排序为\([1 \ldots x.left.size]\)的节点,根节点\(x\)的排序为\(x.left.size+1\),右子树包含排序为\([x.left.size+2 \ldots x.size]\)的节点;
3) 在红黑树的增删运算与旋转运算中,需要额外维护\(size\)的正确性;
4) 实现\(select\)运算,查找排序为\(k\)(即第\(k\)小)的节点并返回其\(key\);
class OrderStatisticTree(RedBlackTree):
    """Red-black tree whose nodes also store the size of their subtree.

    Every real node ``x`` satisfies
    ``x.size == size(x.left) + 1 + size(x.right)``, where a Nil leaf counts
    as 0.  This allows ``select(rank)`` to locate the rank-th smallest key
    by a single root-to-leaf descent.
    """

    @staticmethod
    def _sizeOf(node):
        """Return the subtree size of ``node``; Nil leaves have no ``size``."""
        return 0 if node.nil else node.size

    def _refreshSize(self, node):
        """Recompute ``node.size`` from the sizes of its two children."""
        node.size = self._sizeOf(node.left) + 1 + self._sizeOf(node.right)

    def _leftRotate(self, node):
        """Rotate left, then repair the two sizes the rotation invalidated."""
        super()._leftRotate(node)
        # ``node`` moved down, so refresh it before its new parent.
        self._refreshSize(node)
        self._refreshSize(node.parent)

    def _rightRotate(self, node):
        """Rotate right, then repair the two sizes the rotation invalidated."""
        super()._rightRotate(node)
        self._refreshSize(node)
        self._refreshSize(node.parent)

    def _bstInsert(self, key):
        """Insert ``key`` as a leaf and add one to the size of every ancestor."""
        inserted = super()._bstInsert(key)
        inserted.size = 1
        node = inserted
        while not node.parent.nil:
            node.parent.size += 1
            node = node.parent
        return inserted

    def _bstRemove(self, key):
        """Unlink a node for ``key`` and subtract one from every former ancestor."""
        results = super()._bstRemove(key)
        if results:
            # results[0] is the unlinked node; its ``parent`` attribute still
            # refers to the node it used to hang from.
            node = results[0]
            while not node.parent.nil:
                node.parent.size -= 1
                node = node.parent
        return results

    def _select(self, node, rank):
        """Return the node of 1-based ``rank`` within the subtree at ``node``.

        Returns ``False`` when ``rank`` lies outside ``1 .. size(node)``;
        this includes subtrees whose right child is a Nil leaf.
        """
        while not node.nil:
            leftSize = self._sizeOf(node.left)
            if rank <= leftSize:
                node = node.left
            elif rank == leftSize + 1:
                return node
            else:
                # Skip the left subtree and the current node.
                rank -= leftSize + 1
                node = node.right
        return False

    def select(self, rank):
        """Return the rank-th smallest key (1-based), or ``False`` if out of range."""
        node = self._select(self.root, rank)
        return node.key if node else False
测试数据:
def checkBST(root):
    """Assert the binary-search-tree ordering (ties allowed) below ``root``."""
    if root.nil:
        return
    if not root.left.nil:
        assert root.left.key <= root.key
    if not root.right.nil:
        assert root.right.key >= root.key
    checkBST(root.left)
    checkBST(root.right)


def checkRedBlack(root):
    """Assert that no red node below ``root`` has a red child."""
    if root.nil:
        return
    if not root.left.nil:
        assert not (root.red and root.left.red)
    if not root.right.nil:
        assert not (root.red and root.right.red)
    checkRedBlack(root.left)
    checkRedBlack(root.right)


def checkBlackHeight(root):
    """Assert equal black heights on both sides and return the black height."""
    if root.nil:
        return 0
    leftHeight = checkBlackHeight(root.left)
    rightHeight = checkBlackHeight(root.right)
    assert leftHeight == rightHeight
    if root.red:
        return leftHeight
    return leftHeight + 1


def checkSize(root):
    """Assert every stored ``size`` below ``root`` and return the node count."""
    if root.nil:
        return 0
    total = checkSize(root.left) + checkSize(root.right) + 1
    assert total == root.size
    return total


def isBlackRoot(tree):
    """Assert that the root of ``tree`` is black."""
    assert not tree.root.red


def check(tree):
    """Run every structural invariant check on ``tree``."""
    checkBST(tree.root)
    checkRedBlack(tree.root)
    checkBlackHeight(tree.root)
    checkSize(tree.root)
    isBlackRoot(tree)


def test(size):
    """Randomised test: compare ``select`` with a sorted list after each step."""
    from random import randint, shuffle

    tree = OrderStatisticTree()
    data = [randint(0, size // 5) for _ in range(size)]
    check(tree)

    # Insertion phase: after every insert all ranks must match the sorted list.
    tmp = []
    for key in data:
        tree.insert(key)
        check(tree)
        tmp.append(key)
        tmp.sort()
        for rank, expected in enumerate(tmp, start=1):
            assert tree.select(rank) == expected

    # Removal phase: delete in random order and re-verify every rank.
    shuffle(data)
    tmp = sorted(data)
    for key in data:
        tree.remove(key)
        check(tree)
        tmp.remove(key)
        for rank, expected in enumerate(tmp, start=1):
            assert tree.select(rank) == expected


test(3000)
区间树
区间树的节点记录了一个闭区间\([low,high]\),其运算结构如下:
1) \(insert(low,high)\):插入一个闭区间\([low,high]\);
2) \(delete(low,high)\):删除一个闭区间\([low,high]\);
3) \(find(low,high)\):返回一个与闭区间\([low,high]\)有“重叠部分”的闭区间;
下面考虑把区间树作为红黑树的扩张,其步骤为:
1) 把\(low\)对应到原红黑树节点的\(key\),把\(high\)作为原红黑树节点的新变量;
2) 在原红黑树节点增加新变量\(max\)表示该子树所有节点的最大\(high\)值;
3) 原有\(insert\)的流程不变,只需额外初始化\(high\)变量,最后回溯维护\(max\);
4) 原有\(delete\)流程中的查找流程需从\(low\)重复的节点找到\(high\)符合要求的节点,最后回溯维护\(max\);
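5) 当需要旋转时,旋转后新父节点所在子树的节点集合与旋转前原父节点的子树相同,故新父节点的\(max\)直接取原父节点旋转前的\(max\);原父节点下沉为子节点后,需根据自身的\(high\)与左右子树的\(max\)重新计算其\(max\)。旋转不改变更上层祖先的子树内容,故无需向上回溯更新\(max\);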
6) \(find(low,high)\)的流程如下:设当前节点为\(x\),从根节点开始循环。若\(x\)的区间与\([low,high]\)有重叠则返回该区间;否则若左子树非空且\(x.left.max \geq low\),则只要树中存在与\([low,high]\)重叠的区间,左子树中就必然存在一个,故记\(x = x.left\)继续循环;否则左子树中不可能有重叠区间,重叠区间只可能存在于右子树或根本不存在,记\(x = x.right\)继续循环。当\(x\)走到空节点时,说明不存在重叠区间。
class SegmentTree(RedBlackTree):
    """Interval tree built on the red-black tree.

    Each node stores a closed interval ``[key, high]`` (``key`` is the low
    endpoint and the BST key) plus ``max``, the largest ``high`` found in the
    subtree rooted at that node.  Intervals are passed as two-element
    sequences ``[low, high]``.
    """

    def _fixMax(self, node):
        """Recompute ``node.max`` from its own ``high`` and its children's ``max``."""
        m = node.high
        if not node.left.nil:
            m = max(m, node.left.max)
        if not node.right.nil:
            m = max(m, node.right.max)
        node.max = m

    def _overlap(self, seg1, seg2):
        """Return ``True`` if closed intervals ``seg1`` and ``seg2`` share a point."""
        # Two closed intervals overlap exactly when each one starts no later
        # than the other one ends.
        return seg1[0] <= seg2[1] and seg2[0] <= seg1[1]

    def _leftRotate(self, node):
        """Rotate left and repair ``max`` on the two nodes involved."""
        super()._leftRotate(node)
        # The new subtree root covers the same nodes the old root did.
        node.parent.max = node.max
        self._fixMax(node)

    def _rightRotate(self, node):
        """Rotate right and repair ``max`` on the two nodes involved."""
        super()._rightRotate(node)
        node.parent.max = node.max
        self._fixMax(node)

    def _bstInsert(self, seg):
        """Insert interval ``seg`` as a leaf and refresh ``max`` up to the root."""
        low, high = seg
        inserted = super()._bstInsert(low)
        inserted.high = high
        inserted.max = high
        node = inserted
        while not node.parent.nil:
            self._fixMax(node.parent)
            node = node.parent
        return inserted

    def _bstFind(self, seg):
        """Return a node storing exactly ``seg``, or ``False`` if none exists.

        Several nodes may share the same low endpoint, and rotations can
        place them on either side of one another, so on a key tie with a
        different ``high`` both subtrees have to be searched.
        """
        low, high = seg
        pending = [self.root]
        while pending:
            node = pending.pop()
            if node.nil:
                continue
            if low < node.key:
                pending.append(node.left)
            elif low > node.key:
                pending.append(node.right)
            elif high == node.high:
                return node
            else:
                pending.append(node.right)
                pending.append(node.left)
        return False

    def _bstRemove2(self, seg):
        """Unlink one node storing ``seg`` using plain BST deletion.

        When the matching node has two children, its interval is overwritten
        with that of its in-order successor and the successor is unlinked
        instead.  ``max`` values are not repaired here; ``_bstRemove`` does it.

        Returns:
            ``False`` if ``seg`` is absent; otherwise ``(removed, replacement)``
            where ``removed`` is the unlinked node (its ``parent`` attribute
            still refers to its former parent) and ``replacement`` is the
            node or Nil leaf now occupying its slot.
        """
        node = self._bstFind(seg)
        if not node:
            return False
        if not node.left.nil and not node.right.nil:
            target = node
            node = node.right
            while not node.left.nil:
                node = node.left
            target.key = node.key
            target.high = node.high
        # From here on ``node`` has at most one real child.
        if not node.left.nil:
            child = node.left
        elif not node.right.nil:
            child = node.right
        else:
            child = Nil()
        parent = node.parent
        if parent.nil:
            self.root = child
        elif parent.left is node:
            parent.left = child
        else:
            parent.right = child
        child.parent = parent
        return node, child

    def _bstRemove(self, seg):
        """Unlink a node for ``seg`` and refresh ``max`` on its former ancestors."""
        results = self._bstRemove2(seg)
        if results:
            # Walk up from the unlinked node's old parent; in the
            # two-children case this path passes through the node whose
            # interval was overwritten, so its ``max`` is recomputed too.
            node = results[0]
            while not node.parent.nil:
                self._fixMax(node.parent)
                node = node.parent
        return results

    def find(self, seg):
        """Return some stored interval overlapping ``seg``, or ``False`` if none.

        The result is a new list ``[low, high]``.
        """
        x = self.root
        while not x.nil and not self._overlap(seg, [x.key, x.high]):
            if not x.left.nil and x.left.max >= seg[0]:
                x = x.left
            else:
                x = x.right
        if x.nil:
            return False
        return [x.key, x.high]
测试数据:
def checkBST(root):
    """Assert the binary-search-tree ordering (ties allowed) below ``root``."""
    if root.nil:
        return
    if not root.left.nil:
        assert root.key >= root.left.key
    if not root.right.nil:
        assert root.key <= root.right.key
    checkBST(root.left)
    checkBST(root.right)


def checkRedBlack(root):
    """Assert that no red node below ``root`` has a red child."""
    if root.nil:
        return
    if not root.left.nil:
        assert not (root.red and root.left.red)
    if not root.right.nil:
        assert not (root.red and root.right.red)
    checkRedBlack(root.left)
    checkRedBlack(root.right)


def checkBlackHeight(root):
    """Assert equal black heights on both sides and return the black height."""
    if root.nil:
        return 0
    leftHeight = checkBlackHeight(root.left)
    rightHeight = checkBlackHeight(root.right)
    assert leftHeight == rightHeight
    return leftHeight if root.red else leftHeight + 1


def checkMax(root):
    """Assert every stored ``max`` below ``root``.

    Returns:
        The largest ``high`` in the subtree, or ``None`` for an empty subtree.
        ``None`` is used rather than ``False`` because ``0 == False`` in
        Python, so a genuine maximum of 0 would be mistaken for "empty".
    """
    if root.nil:
        return None
    expected = root.high
    for child in (root.left, root.right):
        childMax = checkMax(child)
        if childMax is not None:
            expected = max(expected, childMax)
    assert root.max == expected
    return root.max


def check(tree):
    """Run every structural invariant check on ``tree``."""
    assert not tree.root.red
    checkBST(tree.root)
    checkRedBlack(tree.root)
    checkBlackHeight(tree.root)
    checkMax(tree.root)


def test(size):
    """Randomised test of insert, overlap query and remove on the interval tree."""
    from random import randint, shuffle

    tree = SegmentTree()
    check(tree)
    left = [randint(0, size) for _ in range(size)]
    right = [randint(low, size) for low in left]
    data = [[low, high] for low, high in zip(left, right)]
    for seg in data:
        tree.insert(seg)
        check(tree)

    shuffle(data)
    for i in range(size):
        for j in range(i, size):
            found = tree.find([i, j])
            if found:
                # The reported interval must really intersect [i, j].
                x, y = found
                assert x <= j and i <= y
            else:
                # Nothing reported: no stored interval may intersect [i, j].
                for low, high in data:
                    assert high < i or j < low

    shuffle(data)
    for seg in data:
        # Every inserted interval must be found and removed exactly once.
        assert tree.remove(seg)
        check(tree)
    assert tree.root.nil


test(3000)