Written by
Shi Yinong
on
on
区间树
在《由浅入深的聊聊广告检索》这篇文章中,聊到了构建广告索引需要用到区间树, 本文就来具体聊聊区间树这种数据结构。
问题引出
考虑这样一个问题:
给定N个不相交的区间,如: [[51, 55], [20, 36], [79, 81], [10, 20] ... ]
再给定一个点p,求点p命中了哪个区间。
这个问题很好解,遍历一遍所有区间即可,如果区间i包含了p,那么区间i即为所求。时间复杂度为N。
问题升级:
给定N个不相交的区间,如: [[51, 55], [20, 36], [79, 81], [10, 20] ... ]
再给定M个点: [p1, p2, p3 ... pn],分别求出每个点命中了哪个区间。
这个问题如果沿用之前遍历所有区间的方式解答,那么时间复杂度为N*M。
注意到,所有区间均不相交,这里可以先给区间按照左边界升序排序,然后遍历M个点,对N个区间进行二分查找,查到最后一个左边界小于等于当前点的区间。 如果该区间的右边界大于等于当前点,该区间即为所求,否则当前点不存在包含它的区间。这种解法,时间复杂度为M*log(N)。
问题再次升级:
给定N个 可能相交 的区间,如: [[21, 55], [20, 36], [79, 81], [70, 90] ... ]
再给定M个点: [p1, p2, p3 ... pn],分别求出每个点命中了 哪几个 区间。
因为区间可能相交,因此不能再使用上面那种二分查找的方法来解了,这时,就需要引入一种新的数据结构——区间树。
区间树的定义
区间树是一颗高度平衡二叉树,它的每个节点存储一个区间,并且还存储以该节点为根节点的子树中,右边界的最大值。 节点的左边界大于等于左子树每个节点的左边界,小于等于右子树每个节点的左边界。使用区间树查询点落在哪些区间,时间复杂度为k*log(N), k为点命中的区间个数,N为区间的总数。
go代码实现
这里只给出构造区间树,以及使用区间树进行查询的方法,没有给出动态对区间树进行增删节点的方法。
package interval_tree
import "sort"
type Node struct {
Interval [2]int64
MaxUpper int64
Left *Node
Right *Node
}
// Build 构造区间树
func Build(intervals [][2]int64) *Node {
// 左边界升序排序
sort.Slice(intervals, func(i, j int) bool {
return intervals[i][0] < intervals[j][0]
})
return build(intervals)
}
func build(intervals [][2]int64) *Node {
if len(intervals) == 0 {
return nil
}
mid := len(intervals) / 2
leftChild := build(intervals[:mid])
rightChild := build(intervals[mid+1:])
var leftMaxUpper, rightMaxUpper int64
if leftChild != nil {
leftMaxUpper = leftChild.MaxUpper
}
if rightChild != nil {
rightMaxUpper = rightChild.MaxUpper
}
return &Node{
Left: leftChild,
Right: rightChild,
MaxUpper: max(intervals[mid][1], max(leftMaxUpper, rightMaxUpper)),
Interval: intervals[mid],
}
}
func Search(root *Node, p int64) []*Node {
res := []*Node{}
var search func(nd *Node)
search = func(nd *Node) {
if nd == nil || nd.MaxUpper < p {
return
}
if nd.Interval[0] <= p && nd.Interval[1] >= p {
res = append(res, nd)
}
// 左子树是肯定要查的
search(nd.Left)
// 如果p小于当前节点的左边界,那么自然小于右子树的每个节点的左边界,右子树就不需要查了
// 否则就需要查。
if nd.Interval[0] <= p {
search(nd.Right)
}
}
search(root)
return res
}
func max(i, j int64) int64 {
if i > j {
return i
}
return j
}