区间树

《由浅入深的聊聊广告检索》这篇文章中,聊到了构建广告索引需要用到区间树, 本文就来具体聊聊区间树这种数据结构。

问题引出

考虑这样一个问题:

给定N个不相交的区间,如: [[51, 55], [20, 36], [79, 81], [10, 20] ... ]
再给定一个点p,求点p命中了哪个区间。

这个问题很好解,遍历一遍所有区间即可,如果区间i包含了p,那么区间i即为所求。时间复杂度为N。

问题升级:

给定N个不相交的区间,如: [[51, 55], [20, 36], [79, 81], [10, 20] ... ]
再给定M个点: [p1, p2, p3 ... pn],分别求出每个点命中了哪个区间。

这个问题如果沿用之前遍历所有区间的方式解答,那么时间复杂度为N*M。

注意到,所有区间均不相交,这里可以先给区间按照左边界升序排序,然后遍历M个点,对N个区间进行二分查找,查到最后一个左边界小于等于当前点的区间。 如果该区间的右边界大于等于当前点,该区间即为所求,否则当前点不存在包含它的区间。这种解法,时间复杂度为M*log(N)。

问题再次升级:

给定N个 可能相交 的区间,如: [[21, 55], [20, 36], [79, 81], [70, 90] ... ]
再给定M个点: [p1, p2, p3 ... pn],分别求出每个点命中了 哪几个 区间。

因为区间可能相交,因此不能再使用上面那种二分查找的方法来解了,这时,就需要引入一种新的数据结构——区间树。

区间树的定义

区间树是一颗高度平衡二叉树,它的每个节点存储一个区间,并且还存储以该节点为根节点的子树中,右边界的最大值。 节点的左边界大于等于左子树每个节点的左边界,小于等于右子树每个节点的左边界。使用区间树查询点落在哪些区间,时间复杂度为k*log(N), k为点命中的区间个数,N为区间的总数。

go代码实现

这里只给出构造区间树,以及使用区间树进行查询的方法,没有给出动态对区间树进行增删节点的方法。

package interval_tree

import "sort"

type Node struct {
	Interval [2]int64
	MaxUpper int64
	Left     *Node
	Right    *Node
}

// Build 构造区间树
func Build(intervals [][2]int64) *Node {
	// 左边界升序排序
	sort.Slice(intervals, func(i, j int) bool {
		return intervals[i][0] < intervals[j][0]
	})
	return build(intervals)
}

func build(intervals [][2]int64) *Node {
	if len(intervals) == 0 {
		return nil
	}
	mid := len(intervals) / 2
	leftChild := build(intervals[:mid])
	rightChild := build(intervals[mid+1:])
	var leftMaxUpper, rightMaxUpper int64
	if leftChild != nil {
		leftMaxUpper = leftChild.MaxUpper
	}
	if rightChild != nil {
		rightMaxUpper = rightChild.MaxUpper
	}
	return &Node{
		Left:     leftChild,
		Right:    rightChild,
		MaxUpper: max(intervals[mid][1], max(leftMaxUpper, rightMaxUpper)),
		Interval: intervals[mid],
	}
}

func Search(root *Node, p int64) []*Node {
	res := []*Node{}
	var search func(nd *Node)
	search = func(nd *Node) {
		if nd == nil || nd.MaxUpper < p {
			return
		}
		if nd.Interval[0] <= p && nd.Interval[1] >= p {
			res = append(res, nd)
		}
		// 左子树是肯定要查的
		search(nd.Left)
		// 如果p小于当前节点的左边界,那么自然小于右子树的每个节点的左边界,右子树就不需要查了
		// 否则就需要查。
		if nd.Interval[0] <= p {
			search(nd.Right)
		}
	}
	search(root)
	return res
}

func max(i, j int64) int64 {
	if i > j {
		return i
	}
	return j
}